"""Gradio app: binary spam/ham text classifier backed by a small Transformer encoder.

Loads a pickled vocabulary and trained weights from disk, then serves a
single-text-input Gradio interface that returns "Spam" or "Ham".
"""

import pickle
import re

import torch
import torch.nn as nn
import gradio as gr
import spaces

# Artifacts produced by the (separate) training script.
MODEL_PATH = "spam_model.pth"
VOCAB_PATH = "vocab.pkl"

# Must match the positional-encoding length the checkpoint was trained with.
MAX_SEQ_LEN = 100


class TransformerEncoder(nn.Module):
    """Transformer encoder with a sigmoid head for binary (spam vs. ham) classification."""

    def __init__(self, d_model=256, num_heads=1, d_ff=512, num_layers=1,
                 vocab_size=10000, max_seq_len=100, dropout=0.1):
        super().__init__()
        # Token embedding plus a *learned* (zero-initialized) positional encoding.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Single-logit classification head squashed to a probability.
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """x: (batch, seq) int64 token ids, seq <= max_seq_len.

        Returns a (batch, 1) tensor of spam probabilities in [0, 1].
        """
        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.encoder(x)
        x = x[:, 0, :]  # first token's hidden state stands in for a [CLS] token
        return self.sigmoid(self.fc(x))


with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

# FIX: fall back to CPU when CUDA is unavailable (was hard-coded to "cuda",
# which raises on CPU-only hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerEncoder(d_model=256, num_heads=1, num_layers=1,
                           vocab_size=len(vocab), max_seq_len=MAX_SEQ_LEN).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # inference only — disables dropout
print("✅ Model and vocabulary loaded successfully!")


def simple_tokenize(text):
    """Lowercase *text* and split it into word tokens."""
    return re.findall(r"\b\w+\b", text.lower())


@spaces.GPU
def predict(text):
    """Classify *text* and return a human-readable label string."""
    tokens = simple_tokenize(text)  # simple_tokenize lowercases; no need to pre-lower
    # NOTE(review): the '' vocab keys below look like '<unk>'/'<pad>' tokens whose
    # angle-bracket tags were stripped somewhere upstream — confirm against vocab.pkl.
    token_ids = [vocab.get(word, vocab['']) for word in tokens]
    # FIX: truncate before padding — inputs longer than MAX_SEQ_LEN previously
    # overran the (1, MAX_SEQ_LEN, d_model) positional encoding and crashed.
    token_ids = token_ids[:MAX_SEQ_LEN]
    token_ids += [vocab['']] * (MAX_SEQ_LEN - len(token_ids))  # pad if needed
    input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor).squeeze().item()
    predicted_label = "Spam" if output > 0.5 else "Ham"
    return f"Predicted Class : {predicted_label} "


gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Encoder Spam Classifier",
).launch()