Alexvatti's picture
Update app.py
a7fdab0 verified
raw
history blame
2.64 kB
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import re
import torch
import pickle
import json
import gradio as gr
# Define paths to the artifacts produced at training time.
MODEL_PATH = "spam_model.pth"  # trained TransformerEncoder weights (state_dict)
VOCAB_PATH = "vocab.pkl"  # pickled token -> integer-id mapping (includes <UNK> and <PAD>)
class TransformerEncoder(nn.Module):
    """Transformer-encoder binary classifier (spam vs. ham).

    Token ids are embedded, offset by a learned positional table, passed
    through a stack of ``nn.TransformerEncoderLayer``s, and classified from
    the first position's representation via a single sigmoid unit, yielding
    a probability in (0, 1).
    """

    def __init__(self, d_model=256, num_heads=1, d_ff=512, num_layers=1,
                 vocab_size=10000, max_seq_len=100, dropout=0.1):
        super().__init__()
        # Token embedding plus a learned, zero-initialized positional table.
        # NOTE: attribute names must stay stable — they are the state-dict keys.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        # One layer template, replicated num_layers times by the encoder stack.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Single-logit head squashed to a probability.
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return P(spam) with shape (batch, 1) for a (batch, seq) id tensor."""
        seq_len = x.size(1)
        hidden = self.embedding(x) + self.positional_encoding[:, :seq_len, :]
        hidden = self.encoder(hidden)
        cls_repr = hidden[:, 0, :]  # first position serves as the CLS summary
        return self.sigmoid(self.fc(cls_repr))
# Load the vocabulary (token -> id mapping) saved alongside the model.
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

# Load the model on CPU; change to "cuda" if a GPU is available.
# (The original assigned `device` twice and called .to(device) twice —
# collapsed to a single assignment and a single move.)
device = torch.device("cpu")
model = TransformerEncoder(d_model=256, num_heads=1, num_layers=1,
                           vocab_size=len(vocab), max_seq_len=100).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # inference mode: disables dropout
print("βœ… Model and vocabulary loaded successfully!")
def simple_tokenize(text):
    """Lower-case *text* and return its word tokens (maximal runs of \\w chars)."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\b\w+\b", lowered)]
def predict(text):
    """Classify *text* as "Spam" or "Ham" using the loaded transformer model.

    Tokenizes the input, maps tokens to vocabulary ids (<UNK> for unknown
    words), truncates/pads to the model's maximum sequence length, and
    thresholds the sigmoid output at 0.5.

    Returns a human-readable label string for the Gradio output box.
    """
    max_len = 100  # must match the max_seq_len the model was constructed with
    model.eval()
    tokens = simple_tokenize(text)  # simple_tokenize lower-cases internally
    token_ids = [vocab.get(word, vocab['<UNK>']) for word in tokens]
    # BUG FIX: truncate long inputs. Without this, any text with more than
    # max_len tokens produced a sequence longer than the positional-encoding
    # table and crashed the forward pass.
    token_ids = token_ids[:max_len]
    token_ids += [vocab['<PAD>']] * (max_len - len(token_ids))  # pad if needed
    input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor).squeeze().item()
    predicted_label = "Spam" if output > 0.5 else "Ham"
    return f"Predicted Class : {predicted_label} "
# Build and launch the Gradio UI: a single text box in, a text label out,
# wired to predict() above.
gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Encoder Spam Classifier"
).launch(share=True)  # share=True also exposes a temporary public URL