# chat2 / app.py (author: ethanrom)
# Last commit: "Update app.py" (be17324)
import gradio as gr
import torch
from transformers import pipeline
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# Choose the GPU when available; both models below are placed on this device.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Extractive QA model (SQuAD 2.0 checkpoint). `device` is passed explicitly so
# the QA model follows `torch_device` like the paraphraser does; the pipeline
# accepts a CUDA index (0) or -1 for CPU.
classifier = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
    tokenizer="deepset/roberta-base-squad2",
    device=0 if torch_device == "cuda" else -1,
)

# Pegasus paraphrasing model used to rewrite the answer-bearing sentence.
model_name = "tuner007/pegasus_paraphrase"
tokenizer3 = PegasusTokenizer.from_pretrained(model_name)
model3 = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
def _sentence_containing_answer(text, answer, start, end):
    """Return the '.'-delimited sentence of *text* that contains the answer span.

    The span is located by the character offsets (*start*, *end*) reported by
    the QA pipeline. If those offsets do not reproduce *answer*, fall back to
    the first occurrence of *answer* in *text*; if that also fails, use the
    whole text. The result is stripped and always ends with a period.
    """
    if text[start:end] != answer:
        found = text.find(answer)
        if found == -1:
            # Answer cannot be located at all: fall back to the whole passage
            # instead of failing.
            start, end = 0, len(text)
        else:
            start, end = found, found + len(answer)
    # rfind returns -1 when there is no earlier period, so +1 yields index 0.
    left = text.rfind(".", 0, start) + 1
    # Search for the closing period only after the span, so periods inside the
    # answer itself (e.g. "3.5", "U.S.") do not cut the sentence short.
    right = text.find(".", end)
    if right == -1:
        right = len(text)
    sentence = text[left:right].strip()
    return sentence if sentence.endswith(".") else sentence + "."


def qa_paraphrase(text_input, question):
    """Answer *question* from *text_input* and return a paraphrased long-form answer.

    Runs the extractive QA pipeline to get a short answer span, takes the
    sentence of *text_input* that contains that span, and paraphrases that
    sentence with the Pegasus model.

    Args:
        text_input: Passage to search for the answer.
        question: Question to answer.

    Returns:
        A string of the form "Answer: <span>\\nLong Form Answer: <paraphrase>".
    """
    # NOTE(review): the question-answering pipeline manages long contexts via
    # its own max_seq_len/doc_stride options; `truncation`/`max_length` may be
    # ignored depending on the transformers version — confirm.
    prediction = classifier(
        context=text_input,
        question=question,
        truncation=True,
        max_length=512,
        padding=True,
    )
    answer = prediction["answer"]
    # Use the model's own character offsets rather than a substring search, so
    # the sentence the model actually picked is used and answers containing
    # periods no longer leave `sentence` unbound.
    sentence = _sentence_containing_answer(
        text_input, answer, prediction["start"], prediction["end"]
    )

    batch = tokenizer3(
        [sentence],
        truncation=True,
        padding="longest",
        max_length=60,
        return_tensors="pt",
    ).to(torch_device)
    # NOTE(review): with num_beams>1 and no do_sample=True, `temperature` has no
    # effect on beam search; kept as-is to avoid changing generation settings.
    translated = model3.generate(
        **batch,
        max_length=60,
        num_beams=10,
        num_return_sequences=1,
        temperature=1.5,
    )
    paraphrase = tokenizer3.batch_decode(translated, skip_special_tokens=True)[0]
    return f"Answer: {answer}\nLong Form Answer: {paraphrase}"
# Gradio UI: passage + question in, combined short/long answer text out.
# gr.Textbox replaces the deprecated gr.inputs/gr.outputs namespaces, which
# were removed in Gradio 4.
iface = gr.Interface(
    fn=qa_paraphrase,
    inputs=[
        gr.Textbox(label="Text Input"),
        gr.Textbox(label="Question"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="Long Form Question Answering",
    description="mimics long form question answering by extracting the sentence containing the answer and paraphrasing it",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch the UI as a side effect.
if __name__ == "__main__":
    iface.launch()