import gradio as gr
import torch
import gc
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("towing/viachat-t5-large-v0.95")
model = AutoModelForSeq2SeqLM.from_pretrained("towing/viachat-t5-large-v0.95")
model = model.to(device)

# Quick smoke test: run a one-off generation before wiring up the Gradio UI.
input_text = 'Translate from English to German: How old are you?'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
print('input_text', input_text)
print('input_ids', input_ids)
outputs = model.generate(input_ids, max_length=500)
print('outputs', outputs)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("outputs text:", text)

# Non-streaming variant, kept for reference:
# def greet(human_inputs):
#     system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
#     input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
#     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
#     outputs = model.generate(input_ids, max_length=500)
#     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return text


# Streaming output: run the encoder once, then decode one token at a time with
# the decoder's key/value cache, yielding the partial text to Gradio.


def greet_stream(human_inputs):
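    """Stream the assistant's reply: decode greedily and yield partial text."""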
    system_message = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    input_text = f"{system_message}\n### Human: {human_inputs}\n### Assistant: "
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    max_length = 200
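    # Encode the prompt once; the encoder states are reused at every decode step.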
    encoder_output = model.encoder(
        input_ids=input_ids
    )['last_hidden_state']  # (1, len, d_model)

    # T5 decoding starts from the pad token (decoder_start_token_id == 0).
    gen_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)
    past_key_values = None
    generated_ids = []  # accumulated output token ids
    for i in range(max_length):
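        # Feed only the newest token; past_key_values carries all earlier steps.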
        out = model.decoder(
            input_ids=gen_ids,
            encoder_hidden_states=encoder_output,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = out.past_key_values
        last_hidden_state = out.last_hidden_state  # (1, 1, d_model)
        lm_logits = model.lm_head(last_hidden_state)  # (1, 1, vocab_size)
        # Greedy decoding: pick the highest-scoring token. (T5's tied-embedding
        # rescaling is omitted here; since lm_head has no bias, a constant
        # positive factor cannot change the argmax.)
        gen_ids = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        next_id = gen_ids[0, 0].item()
        if next_id == model.config.eos_token_id:
            break
        generated_ids.append(next_id)
        # Decode the accumulated ids each step so SentencePiece restores word
        # boundaries; decoding tokens one at a time drops the leading spaces.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Free cached memory once per request rather than once per token.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Gradio shows the last yielded value as the final output.
    yield tokenizer.decode(generated_ids, skip_special_tokens=True)


iface = gr.Interface(
    fn=greet_stream,
    inputs=[gr.Textbox(
        label="Human input",
        info="Your message to the assistant",
        lines=3,
        value="Who are you?",
    )],
    outputs="text",
)

iface.queue()  # queue() is required so Gradio can stream generator outputs
iface.launch()