# Hugging Face Spaces page header (scraped): "Spaces: Sleeping"
import os
import sys

import torch
import gradio as gr

# Put the parent of this file's directory on the import path so the local
# ``model``, ``data_loader`` and ``utils`` modules imported below resolve.
# (Appended, so it is searched after the existing sys.path entries.)
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_dir)

from model import MultiTaskBertModel
from data_loader import load_dataset
from utils import (
    bert_config,
    tokenizer,
    intent_ids_to_labels,
    intent_labels_to_ids,
    ner_labels_to_ids,
    ner_ids_to_labels,
)
# Build the model skeleton from the project config and training dataset
# (the dataset is passed to the constructor, presumably to size the task
# heads -- TODO confirm against model.MultiTaskBertModel), then load weights.
config = bert_config()
dataset = load_dataset("training_dataset")
model = MultiTaskBertModel(config, dataset)
# The app never moves the model or its inputs to a GPU, so deserialize the
# checkpoint straight onto the CPU. Without map_location, a checkpoint that
# was saved from CUDA tensors fails to load on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
# Evaluation mode: disables training-only behaviour such as dropout.
model.eval()

# NER label <-> id lookups; ``predict`` uses ner_id_to_label to decode
# per-token predictions.
ner_label_to_id = ner_labels_to_ids()
ner_id_to_label = ner_ids_to_labels(ner_label_to_id)
def predict(input_data):
    """Run the multi-task model on one utterance and format its predictions.

    The text is tokenized with character offsets (padded/truncated to 128
    tokens), passed through ``model`` as a batch of one, and decoded into:

    * per-token NER labels, as ``(token_text, label)`` pairs. Tokens whose
      offset span yields empty or whitespace-only text are skipped
      (presumably special and padding tokens -- depends on the tokenizer).
    * a single intent label for the whole utterance.

    Args:
        input_data: The raw user text (the Gradio "text" input).

    Returns:
        A string ``"Ner logits: [(token, label), ...], Intent Label: <label>"``.
        Despite the prefix, the NER part contains decoded labels, not logits;
        the wording is kept so the displayed output does not change.
    """
    tok = tokenizer()
    encoded = tok(input_data,
                  return_offsets_mapping=True,
                  padding='max_length',
                  truncation=True,
                  max_length=128)

    # Wrap each sequence in a list so the model receives a batch of one.
    input_ids = torch.tensor([encoded['input_ids']])
    attention_mask = torch.tensor([encoded['attention_mask']])
    # Offsets are only used to slice ``input_data``; keep them as plain
    # Python (start, end) pairs instead of converting them to a tensor.
    offsets = encoded['offset_mapping']

    with torch.no_grad():
        # Call the module itself (not ``.forward``) so registered hooks run.
        ner_logits, intent_logits = model(input_ids, attention_mask)

    # One NER id per token: argmax over the label dimension. The label count
    # is read from the logits rather than hard-coded (previously 9).
    # Assumes labels are the last dimension of ner_logits -- TODO confirm.
    ner_ids = ner_logits.reshape(-1, ner_logits.size(-1)).argmax(dim=1)
    # Argmax over the flattened intent logits; with a batch of one this is
    # the predicted intent id.
    intent_id = int(torch.argmax(intent_logits))

    aligned_predictions = []
    for prediction, (start, end) in zip(ner_ids, offsets):
        word = input_data[start:end]
        # An empty span (start == end) yields "", so this also skips tokens
        # that map to no source text.
        if not word.strip():
            continue
        aligned_predictions.append((word, ner_id_to_label[int(prediction)]))

    intent_label = intent_ids_to_labels(intent_labels_to_ids())[intent_id]
    return f"Ner logits: {aligned_predictions}, Intent Label: {intent_label}"
# Text shown at the top of the Gradio page.
title = "Multi Task Model"
# NOTE(review): the <img> uses a relative path; confirm the Gradio version in
# use actually serves bart.jpg from the app directory.
description = '''
This model is designed for a scheduler application, capable of handling various tasks such as setting
timers, scheduling meetings, appointments, and alarms. It provides Named Entity Recognition (NER) labels
to identify specific entities within the input text, along with an Intent label to determine the
overall task intention. The model's outputs facilitate efficient task management and organization,
enabling seamless interaction with the scheduler application.
<img src="bart.jpg" width=300px>
'''

# Single text-in / text-out interface wrapping ``predict``; each example
# is a one-element list because the interface has exactly one input.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["Remind me about the meeting at 3 PM tomorrow"],
        ["Set a timer for 10 minutes"],
    ],
)
# Starts the server at import time; share=True requests a public share link.
demo.launch(share=True)