# app.py — Gradio Space: Transformer-encoder spam classifier
import torch
import torch.nn as nn
import re
import pickle
import gradio as gr
import spaces
# Define paths
MODEL_PATH = "spam_model.pth"
VOCAB_PATH = "vocab.pkl"
class TransformerEncoder(nn.Module):
    """Transformer-encoder binary classifier for spam detection.

    Embeds token ids, adds a learned positional encoding, runs the
    sequence through a stack of standard encoder layers, and produces a
    spam probability from the first token's representation.
    """

    def __init__(self, d_model=256, num_heads=1, d_ff=512, num_layers=1, vocab_size=10000, max_seq_len=100, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        # Token embedding plus a learned (not sinusoidal) positional encoding.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        # Encoder stack operating on (batch, seq, dim) tensors.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Single-logit classification head squashed to a probability.
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map token ids of shape (batch, seq) to probabilities (batch, 1)."""
        seq_len = x.size(1)
        hidden = self.embedding(x) + self.positional_encoding[:, :seq_len, :]
        hidden = self.encoder(hidden)
        cls_repr = hidden[:, 0, :]  # first position acts as a CLS token
        return self.sigmoid(self.fc(cls_repr))
# Load the vocabulary (token -> id mapping) produced at training time.
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

# Load model. Fall back to CPU when no GPU is available so the app still
# starts (the original hard-coded "cuda" and crashed on CPU-only hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerEncoder(d_model=256, num_heads=1, num_layers=1, vocab_size=len(vocab), max_seq_len=100).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # inference mode: disables dropout
print("βœ… Model and vocabulary loaded successfully!")
def simple_tokenize(text):
    """Lower-case *text* and return its word tokens (``\\b\\w+\\b`` matches)."""
    lowered = text.lower()
    return re.findall(r"\b\w+\b", lowered)
@spaces.GPU
def predict(text):
    """Classify *text* as "Spam" or "Ham" with the globally loaded model.

    Tokenizes, maps tokens to ids (unknowns -> ``<UNK>``), truncates and
    pads to the model's maximum sequence length, then thresholds the
    sigmoid output at 0.5.
    """
    max_len = 100  # must match max_seq_len used when the model was built
    model.eval()
    tokens = simple_tokenize(text)  # simple_tokenize already lower-cases
    token_ids = [vocab.get(word, vocab['<UNK>']) for word in tokens]
    # Truncate first: the learned positional encoding only covers max_len
    # positions, so longer inputs would crash the forward pass.
    token_ids = token_ids[:max_len]
    token_ids += [vocab['<PAD>']] * (max_len - len(token_ids))  # pad if needed
    input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor).squeeze().item()
    predicted_label = "Spam" if output > 0.5 else "Ham"
    return f"Predicted Class : {predicted_label} "
# Build and serve the web UI: one text box in, one label string out.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Encoder Spam Classifier",
)
demo.launch()