import gc

import gradio as gr
import torch
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("towing/viachat-t5-large-v3.0")
model = AutoModelForSeq2SeqLM.from_pretrained("towing/viachat-t5-large-v3.0")
model = model.to(device)

# Startup sanity check: one full (non-streaming) generation.
input_text = 'Translate from english to german: How old are you'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
print('input_text', input_text)
print('input_ids', input_ids)
outputs = model.generate(input_ids, max_length=500)
print('outputs', outputs)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("outputs text:", text)


# Non-streaming variant, kept for reference:
# def greet(human_inputs):
#     system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
#     input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
#     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
#     outputs = model.generate(input_ids, max_length=500)
#     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return text


# stream out
def greet_stream(human_inputs):
    """Greedy token-by-token decoding, yielding the partial answer so Gradio can stream it."""
    system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    max_length = 200

    # Encode the prompt once; every decoder step cross-attends to this.
    encoder_output = model.encoder(input_ids=input_ids)['last_hidden_state']  # (1, len, d_model)

    # T5 decoding starts from the pad token (decoder_start_token_id == 0).
    gen_ids = torch.tensor([[0]], device=device)
    past_key_values = None
    generated_ids = []

    for _ in range(max_length):
        out = model.decoder(
            input_ids=gen_ids,
            encoder_hidden_states=encoder_output,
            use_cache=True,  # reuse cached keys/values for past positions
            past_key_values=past_key_values,
        )
        past_key_values = out.past_key_values
        last_hidden_state = out.last_hidden_state     # (1, 1, d_model)
        lm_logits = model.lm_head(last_hidden_state)  # (1, 1, vocab_size)

        values, indices = lm_logits[0].topk(2)  # top-2 so the debug print below can show the runner-up
        # print(tokenizer.convert_ids_to_tokens(indices[0]))
        gen_ids = indices[:, :1]  # greedy: keep only the best token, shape (1, 1)

        # Compare plain ids rather than tensors; the original torch.equal check
        # mixed a CUDA tensor with a CPU tensor, which raises at runtime.
        next_id = gen_ids[0, 0].item()
        if next_id == tokenizer.eos_token_id:  # eos token (id 1 for T5)
            break
        generated_ids.append(next_id)

        gc.collect()
        torch.cuda.empty_cache()  # note: clearing the cache every step is slow; kept from the original

        # Decode the accumulated ids each step; decoding tokens one at a time and
        # concatenating them drops the SentencePiece word-boundary spaces.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)  # stream

    yield tokenizer.decode(generated_ids, skip_special_tokens=True)


iface = gr.Interface(
    fn=greet_stream,
    inputs=[
        gr.Textbox(
            label="Text 1",
            info="Initial text",
            lines=3,
            value="Who are you?",
        )
    ],
    outputs="text",
)
iface.queue()  # queuing is required for generator (streaming) outputs
iface.launch()
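
# Alternative streaming sketch (not from the original app): newer transformers
# releases ship TextIteratorStreamer, which lets model.generate() run the
# decoding loop while we consume decoded text chunks from a queue, avoiding the
# manual encoder/decoder stepping above. This assumes a transformers version
# that includes TextIteratorStreamer; the function name greet_stream_v2 is
# made up for illustration.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def greet_stream_v2(human_inputs):
#     system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
#     input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
#     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
#     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
#     # generate() blocks, so run it in a worker thread and read from the streamer here.
#     thread = Thread(target=model.generate,
#                     kwargs=dict(input_ids=input_ids, max_length=200, streamer=streamer))
#     thread.start()
#     text = ""
#     for chunk in streamer:
#         text += chunk
#         yield text
#     thread.join()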