|
|
import gradio as gr |
|
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face checkpoint used for both the tokenizer and the classifier.
model_name = "bert-base-multilingual-cased"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained(model_name)

# NOTE(review): model_name looks like a base (pre-training) checkpoint, so the
# 4-way classification head requested via num_labels=4 is presumably freshly
# initialised rather than trained; predictions would then be arbitrary until the
# model is fine-tuned on labelled documents. Confirm, or load a fine-tuned
# checkpoint instead.
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=4)
model.to(device)
|
|
|
|
|
|
|
|
# Category names, indexed by the classifier's output position (argmax index).
# NOTE(review): the order must match the label ids the classification head was
# trained with; the model is created with num_labels=4, matching this length.
CATEGORY_LABELS = ("Business", "Finance", "Marketing", "HR")

# Message shown instead of a prediction when there is nothing to classify.
EMPTY_INPUT_MESSAGE = "Please enter some text to classify."


def predict_category(document):
    """Classify a business document into one of ``CATEGORY_LABELS``.

    Args:
        document: Raw text from the Gradio textbox. ``None``, the empty
            string, and whitespace-only text are all treated as "no input".

    Returns:
        The predicted category name, or ``EMPTY_INPUT_MESSAGE`` when the
        input is empty.
    """
    # Guard against None as well as blank text: Gradio may pass None for an
    # empty component, and None.strip() would raise AttributeError.
    if not document or not document.strip():
        return EMPTY_INPUT_MESSAGE

    # Tokenise; anything longer than 512 tokens is truncated.
    inputs = tokenizer(
        document,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # The tensors must live on the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label = torch.argmax(logits, dim=-1).item()
    return CATEGORY_LABELS[predicted_label]
|
|
|
|
|
|
|
|
# Input component: a 10-line text area for pasting the document to classify.
document_input = gr.Textbox(
    lines=10,
    placeholder="Upload or paste a business document here...",
)

# Wire predict_category into a simple text-in / text-out web UI.
iface = gr.Interface(
    fn=predict_category,
    inputs=document_input,
    outputs="text",
    title="Multilingual Business Document Classifier",
    description="This tool classifies business documents into categories using a pre-trained multilingual BERT model.",
    # live=True: Gradio re-runs fn as the input changes instead of waiting for
    # a Submit click. NOTE(review): this runs a BERT forward pass per edit.
    live=True,
)

# Start the web server (runs at import time, as in the original script).
iface.launch()
|
|
|