# viachat-v0.95 / app.py
import gradio as gr
import torch
import gc
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("towing/viachat-t5-large-v0.95")
model = AutoModelForSeq2SeqLM.from_pretrained("towing/viachat-t5-large-v0.95")
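# (If GPU memory is tight, from_pretrained also accepts
# torch_dtype=torch.float16 to load the weights in half precision.)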
model = model.to(device)
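model.eval()  # inference only: disable dropout

# Startup sanity check: run one translation so model/tokenizer problems
# surface in the Space logs before the UI comes up.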
input_text = 'Translate from English to German: How old are you?'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
print('input_text', input_text)
print('input_ids', input_ids)
outputs = model.generate(input_ids, max_length=500)
print('outputs', outputs)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("outputs text:", text)
# def greet(human_inputs):
#     system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
#     input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
#     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
#     print('input_text', input_text)
#     print('input_ids', input_ids)
#     outputs = model.generate(input_ids, max_length=500)
#     print('outputs', outputs)
#     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print("outputs text:", text)
#     return text
# Streaming generation: run the encoder once, then step the decoder one
# token at a time with its key/value cache so partial text can be yielded
# to the UI as it is produced.
def greet_stream(human_inputs):
    system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    max_length = 200
    with torch.no_grad():
        # Encode the prompt once; every decoder step attends to this output.
        encoder_output = model.encoder(
            input_ids=input_ids
        )['last_hidden_state']  # (1, len, d_model)
        # T5 uses the pad token (id 0) as the decoder start token.
        gen_ids = torch.tensor([[0]]).to(device)
        past_key_values = None
        generated_ids = []
        for _ in range(max_length):
            out = model.decoder(
                input_ids=gen_ids,
                encoder_hidden_states=encoder_output,
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = out.past_key_values
            last_hidden_state = out.last_hidden_state  # (1, 1, d_model)
            # (T5 rescales hidden states before lm_head when embeddings are
            # tied; a positive scale doesn't change the argmax, so greedy
            # decoding is unaffected by skipping it here.)
            lm_logits = model.lm_head(last_hidden_state)  # (1, 1, vocab_size)
            # Greedy decoding: take the top-1 token; topk(2) keeps the
            # runner-up around for debugging.
            values, indices = lm_logits[0].topk(2)
            # print(tokenizer.convert_ids_to_tokens(indices[0]))
            gen_ids = indices[:, :1]  # (1, 1), fed back in as the next input
            if gen_ids.item() == tokenizer.eos_token_id:
                break
            generated_ids.append(gen_ids.item())
            # Aggressive per-step memory housekeeping, at some cost in speed.
            gc.collect()
            torch.cuda.empty_cache()
            # Re-decode the full sequence each step so SentencePiece word
            # boundaries survive, then stream the partial text to Gradio.
            yield tokenizer.decode(generated_ids, skip_special_tokens=True)
iface = gr.Interface(
    fn=greet_stream,
    inputs=[gr.Textbox(
        label="Text 1",
        info="Initial text",
        lines=3,
        value="Who are you?",
    )],
    outputs="text",
)
iface.queue()  # queueing is required for streaming (generator) handlers
iface.launch()