import gradio as gr
import torch
import gc

# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("towing/viachat-t5-large-v0.95")
model = AutoModelForSeq2SeqLM.from_pretrained("towing/viachat-t5-large-v0.95")
model = model.to(device)
# Quick smoke test at startup: greedy generation via the standard generate() API.
input_text = 'Translate from English to German: How old are you'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
print('input_text', input_text)
print('input_ids', input_ids)
outputs = model.generate(input_ids, max_length=500)
print('outputs', outputs)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("outputs text:", text)
# Non-streaming variant, kept for reference:
# def greet(human_inputs):
#     system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
#     input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
#     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
#     print('input_text', input_text)
#     print('input_ids', input_ids)
#     outputs = model.generate(input_ids, max_length=500)
#     print('outputs', outputs)
#     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print("outputs text:", text)
#     return text
# Streaming variant: run the encoder once, then decode token by token,
# reusing the decoder's past_key_values so each step only feeds the new token.
def greet_stream(human_inputs):
    system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    max_length = 200

    with torch.no_grad():
        # Encoder output is computed once and reused at every decoding step.
        encoder_output = model.encoder(
            input_ids=input_ids
        )['last_hidden_state']  # (1, len, d_model)

        # T5 starts decoding from the pad token (decoder_start_token_id == 0).
        gen_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)
        past_key_values = None
        output_ids = []
        for _ in range(max_length):
            out = model.decoder(
                input_ids=gen_ids,
                encoder_hidden_states=encoder_output,
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = out.past_key_values
            last_hidden_state = out.last_hidden_state  # (1, 1, d_model)
            # t5-large ties its input/output embeddings, so the decoder output
            # must be rescaled before the LM head, exactly as
            # T5ForConditionalGeneration.forward does.
            if model.config.tie_word_embeddings:
                last_hidden_state = last_hidden_state * (model.model_dim ** -0.5)
            lm_logits = model.lm_head(last_hidden_state)  # (1, 1, vocab_size)
            # Greedy decoding: keep the highest-scoring token.
            gen_ids = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
            if gen_ids.item() == model.config.eos_token_id:
                break
            output_ids.append(gen_ids.item())
            # Re-decode the whole prefix each step: decoding tokens one at a
            # time and joining the strings would drop SentencePiece word
            # boundaries.
            yield tokenizer.decode(output_ids, skip_special_tokens=True)

    # Free scratch memory once generation finishes; doing this every step
    # would slow decoding considerably.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
iface = gr.Interface(
    fn=greet_stream,
    inputs=[gr.Textbox(
        label="Text 1",
        info="Initial text",
        lines=3,
        value="Who are you?",
    )],
    outputs="text",
)
iface.queue()  # queuing must be enabled for streaming (generator) outputs
iface.launch()