# app.py ──────────────────────────────────────────────────────────────────────
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
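# Note: on a Gradio-SDK Hugging Face Space only `gradio` comes pre-installed, so the
# repo also needs a requirements.txt next to app.py covering the other imports.
# A minimal, unpinned sketch (my assumption, not copied from the original repo):
#   torch
#   transformers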
TITLE = "Talk To Me Morty"

DESCRIPTION = """
<p style='text-align:center'>
The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.
</p>
<center>
<img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png"
     alt="Rick"
     width="150">
</center>
"""

ARTICLE = """
<p style='text-align:center'>
<a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701"
   target="_blank">Complete Tutorial</a> ·
<a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot"
   target="_blank">Project on DAGsHub</a>
</p>
"""
# ─── Load model once at start ─────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
model = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
# ─── Chat handler ──────────────────────────────────────────────────────────────
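# DialoGPT treats a conversation as one token stream in which every turn ends with
# the <|endoftext|> token: each new user turn is appended to the running id history,
# the model generates a continuation of the stream, and splitting the decoded text
# on that token recovers the alternating user/bot turns shown in the UI.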
def chat(user_msg: str, history_ids: list[list[int]] | None):
    if not user_msg:
        # nothing to send: leave the chat window as-is and keep the history
        return gr.update(), history_ids or []
    new_ids = tokenizer.encode(user_msg + tokenizer.eos_token,
                               return_tensors="pt")
    bot_input = (
        torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1)
        if history_ids else new_ids
    )
    history_ids = model.generate(
        bot_input,
        # DialoGPT-medium is GPT-2 based, so cap generation at its 1024-token context
        max_length=min(1024, bot_input.shape[-1] + 200),
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
    ).tolist()
    turns = tokenizer.decode(history_ids[0], skip_special_tokens=False) \
                     .split("<|endoftext|>")
    # pack alternating turns into (user, bot) pairs for the Chatbot component
    pairs = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return pairs, history_ids
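# Illustration of the turn packing above (values are made up; real replies vary
# because sampling is enabled):
#   decoded = "How are you, Rick?<|endoftext|>*burp* Never better, Morty.<|endoftext|>"
#   turns   = ["How are you, Rick?", "*burp* Never better, Morty.", ""]
#   pairs   = [("How are you, Rick?", "*burp* Never better, Morty.")]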
# ─── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(height=450)
    state = gr.State([])
    with gr.Row(equal_height=True):
        prompt = gr.Textbox(placeholder="Ask Rick anything…", scale=9, show_label=False)
        send = gr.Button("Send", scale=1, variant="primary")
    # send on click or on ↵ (Enter)
    send.click(chat, inputs=[prompt, state], outputs=[chatbot, state])
    prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state])
    gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
    gr.Markdown(ARTICLE)
if __name__ == "__main__":
    demo.launch()