import torch
import torch.nn as nn
import re
import pickle
import gradio as gr
import spaces

MODEL_PATH = "spam_model.pth"
VOCAB_PATH = "vocab.pkl"


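# A single-layer Transformer encoder for binary spam/ham classification: learned
# positional encodings, one self-attention head, and a sigmoid head that turns the
# first token's representation into a spam probability.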
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, num_heads=1, d_ff=512, num_layers=1, vocab_size=10000, max_seq_len=100, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        # Token embeddings plus a learned (zero-initialized) positional encoding.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head: a single logit squashed to a probability by the sigmoid.
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.encoder(x)
        x = x[:, 0, :]  # pool by taking the first token's hidden state
        x = self.fc(x)
        return self.sigmoid(x)


# Load the saved vocabulary (token -> index mapping that includes '<PAD>' and '<UNK>').
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

# Use the GPU when it is available and fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerEncoder(d_model=256, num_heads=1, num_layers=1, vocab_size=len(vocab), max_seq_len=100).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

print("β
Model and vocabulary loaded successfully!") |
|
|
|
def simple_tokenize(text):
    # Lowercase the text and keep only word tokens; punctuation is discarded.
    return re.findall(r"\b\w+\b", text.lower())


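# Inference: tokenize -> map to vocabulary ids -> truncate/pad to max_len -> forward pass.
# The @spaces.GPU decorator requests a GPU from Hugging Face ZeroGPU for the duration of each call.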
@spaces.GPU
def predict(text):
    max_len = 100
    model.eval()
    tokens = simple_tokenize(text)
    token_ids = [vocab.get(word, vocab['<UNK>']) for word in tokens]
    # Truncate to the model's maximum sequence length, then pad to a fixed size.
    token_ids = token_ids[:max_len]
    token_ids += [vocab['<PAD>']] * (max_len - len(token_ids))
    input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(input_tensor).squeeze().item()

    predicted_label = "Spam" if output > 0.5 else "Ham"
    return f"Predicted Class: {predicted_label}"


gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Encoder Spam Classifier"
).launch()
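
# Quick local check without the web UI (hypothetical inputs; requires spam_model.pth
# and vocab.pkl to be present):
#   print(predict("Congratulations! You have won a free prize, click here now"))
#   print(predict("Are we still meeting for lunch tomorrow?"))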