kingabzpro's picture
Update app.py
c2056b1 verified
raw
history blame
3.3 kB
# app.py ──────────────────────────────────────────────────────────────────────
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Text blocks shown around the chat UI; each one is passed to gr.Markdown.
TITLE = "Talk To Me Morty"
DESCRIPTION = """
<p style='text-align:center'>
The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.
</p>
<center>
<img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png"
alt="Rick"
width="150">
</center>
"""
# The link separator uses the "&middot;" HTML entity rather than a literal
# middle-dot character: the raw character had been mangled into mojibake
# ("Β·"), while the pure-ASCII entity renders identically and cannot be
# mis-decoded.
ARTICLE = """
<p style='text-align:center'>
<a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701"
target="_blank">Complete Tutorial</a> &middot;
<a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot"
target="_blank">Project on DAGsHub</a>
</p>
"""
# ─── Load model once at start ────────────────────────────────────────────────
# Checkpoint id shared by the tokenizer and the model so the two cannot drift.
MODEL_ID = "ericzhou/DialoGPT-Medium-Rick_v2"
# Loaded once at import time; `chat` reuses these module-level objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_ID),
    AutoModelForCausalLM.from_pretrained(MODEL_ID),
)
# ─── Chat handler ────────────────────────────────────────────────────────────
_MAX_NEW_TOKENS = 200  # upper bound on the length of one generated reply


def _history_to_pairs(history_ids: list[list[int]] | None) -> list[tuple[str, str]]:
    """Decode the token-id history into ``(user, bot)`` pairs for ``gr.Chatbot``.

    Every turn stored in the history is terminated by ``tokenizer.eos_token``,
    so splitting the decoded text on that string yields the turns in order
    (user, bot, user, bot, ...). A trailing unpaired piece (e.g. the empty
    string after the final EOS) is dropped by ``zip``.
    """
    if not history_ids:
        return []
    text = tokenizer.decode(history_ids[0], skip_special_tokens=False)
    turns = text.split(tokenizer.eos_token)
    return list(zip(turns[0::2], turns[1::2]))


def chat(user_msg: str, history_ids: list[list[int]] | None):
    """Generate Rick's reply to *user_msg* and return the updated conversation.

    Args:
        user_msg: Text submitted by the user. An empty message generates
            nothing and just re-renders the existing conversation.
        history_ids: State returned by the previous call: token ids of the
            whole conversation as a single-row nested list (``[[id, ...]]``),
            or ``None``/``[]`` for a fresh conversation.

    Returns:
        ``(pairs, history_ids)``: ``pairs`` is a list of ``(user, bot)``
        strings for ``gr.Chatbot``; ``history_ids`` is the new state to pass
        back on the next call.
    """
    history_ids = history_ids or []
    if not user_msg:
        # Keep showing the current conversation instead of clearing the chat.
        return _history_to_pairs(history_ids), history_ids

    new_ids = tokenizer.encode(user_msg + tokenizer.eos_token, return_tensors="pt")
    conversation = (
        torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1)
        if history_ids
        else new_ids
    )

    # The model has a fixed number of position embeddings
    # (max_position_embeddings; 1024 for GPT-2-based DialoGPT). A longer input
    # crashes generation, so only the most recent tokens are fed to the model,
    # leaving room for the reply. The full conversation is still kept in the
    # returned state so the whole chat stays visible.
    context_size = getattr(model.config, "max_position_embeddings", 1024)
    keep = max(context_size - _MAX_NEW_TOKENS, 1)
    bot_input = conversation[:, -keep:]

    output = model.generate(
        bot_input,
        # pad == eos here, so HF cannot infer a mask; every input token is real.
        attention_mask=torch.ones_like(bot_input),
        max_new_tokens=_MAX_NEW_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
    )
    # generate() returns prompt + continuation; keep only the new reply tokens.
    reply_ids = output[:, bot_input.shape[-1]:]

    # A reply cut off by max_new_tokens has no EOS; add one so the next user
    # turn is not glued onto it when the history is split on eos_token.
    if reply_ids.shape[-1] == 0 or reply_ids[0, -1].item() != tokenizer.eos_token_id:
        eos = torch.full((1, 1), tokenizer.eos_token_id, dtype=reply_ids.dtype)
        reply_ids = torch.cat([reply_ids, eos], dim=-1)

    history_ids = torch.cat([conversation, reply_ids], dim=-1).tolist()
    return _history_to_pairs(history_ids), history_ids
# ─── Gradio UI ───────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(height=450)
    # Conversation token ids produced by `chat`; initially empty.
    state = gr.State([])
    with gr.Row(equal_height=True):
        prompt = gr.Textbox(placeholder="Ask Rick anything…", scale=9, show_label=False)
        send = gr.Button("Send", scale=1, variant="primary")
    # Send on button click or on Enter. Once the reply is in, empty the
    # textbox so the next message can be typed straight away.
    for trigger in (send.click, prompt.submit):
        trigger(chat, inputs=[prompt, state], outputs=[chatbot, state]).then(
            lambda: "", inputs=None, outputs=prompt
        )
    gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
    gr.Markdown(ARTICLE)

if __name__ == "__main__":
    demo.launch()