MCBD / app.py
Anushree1's picture
Update app.py
782dbf8 verified
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Step 1: Choose the compute device (GPU when CUDA is available, otherwise CPU)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Step 2: Load the tokenizer and the sequence-classification model from the
# multilingual BERT checkpoint, then place the model on the chosen device
model_name = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)
# num_labels must equal the number of categories the classifier predicts.
# NOTE(review): this is the base pre-trained checkpoint, so the classification
# head is presumably freshly initialised rather than fine-tuned -- confirm that
# predictions are meaningful, or point model_name at a fine-tuned checkpoint.
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=4).to(device)
# Step 3: Define a function to predict the class of the document
def predict_category(document):
    """Classify a document and return the name of its predicted category.

    Args:
        document: Text to classify. ``None``, an empty string, or
            whitespace-only text is treated as "no input".

    Returns:
        The category name for the highest-scoring class, or a prompt asking
        for input when there is nothing to classify.
    """
    # Guard before touching the model: the UI may hand over None instead of
    # a string, and calling .strip() on None would raise AttributeError.
    if document is None or not document.strip():
        return "Please enter some text to classify."

    # Category names indexed by class id (update this with your own labels).
    labels = ['Business', 'Finance', 'Marketing', 'HR']

    # Step 3a: Tokenize the input; text beyond 512 tokens is truncated.
    encoded = tokenizer(
        document,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Step 3b: Move every input tensor to the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Step 3c: Run the model without tracking gradients (inference only).
    with torch.no_grad():
        logits = model(**encoded).logits

    # Step 3d: The index of the largest logit is the predicted class id.
    predicted_label = torch.argmax(logits, dim=-1).item()
    return labels[predicted_label]
# Step 4: Build the Gradio interface
# Multi-line text area in which the user supplies the document text.
document_box = gr.Textbox(
    lines=10,
    placeholder="Upload or paste a business document here...",
)
# NOTE(review): live=True re-runs the classifier whenever the input changes
# instead of waiting for a submit click -- confirm this is intended given that
# every run is a full BERT forward pass.
iface = gr.Interface(
    fn=predict_category,
    inputs=document_box,
    outputs="text",
    title="Multilingual Business Document Classifier",
    description="This tool classifies business documents into categories using a pre-trained multilingual BERT model.",
    live=True,
)
# Step 5: Start the web app
iface.launch()