kingabzpro commited on
Commit
573fbce
Β·
verified Β·
1 Parent(s): a431c7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -44
app.py CHANGED
@@ -1,50 +1,83 @@
1
- import os
 
2
  import gradio as gr
 
 
 
3
 
4
- title = "Talk To Me Morty"
5
- description = """
6
- <p>
7
- <center>
8
- The bot was trained on Rick and Morty dialogues Kaggle Dataset using DialoGPT.
9
- <img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png" alt="rick" width="200"/>
10
- </center>
11
  </p>
12
  """
13
- article = "<p style='text-align: center'><a href='https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701' target='_blank'>Complete Tutorial</a></p><p style='text-align: center'><a href='https://dagshub.com/kingabzpro/DailoGPT-RickBot' target='_blank'>Project is Available at DAGsHub</a></p></center><center><img src='https://visitor-badge.glitch.me/badge?page_id=kingabzpro/Rick_and_Morty_Bot' alt='visitor badge'></center></p>"
14
- examples = [["How are you Rick?"]]
15
- from transformers import AutoModelForCausalLM, AutoTokenizer
16
- import torch
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
19
- model = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
20
-
21
- def predict(input, history=[]):
22
- # tokenize the new input sentence
23
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
24
-
25
- # append the new user input tokens to the chat history
26
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
27
-
28
- # generate a response
29
- history = model.generate(bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id).tolist()
30
-
31
- # convert the tokens to text, and then split the responses into lines
32
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
33
- #print('decoded_response-->>'+str(response))
34
- response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
35
- #print('response-->>'+str(response))
36
- return response, history
37
-
38
- gr.Interface(fn=predict,
39
- title=title,
40
- description=description,
41
- examples=examples,
42
- inputs=["text", "state"],
43
- outputs=["chatbot", "state"],
44
- theme='gradio/seafoam').launch()
45
-
46
- #theme ="grass",
47
- #title = title,
48
- #flagging_callback=hf_writer,
49
- #description = description,
50
- #article = article
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py ──────────────────────────────────────────────────────────────────────
2
+ import torch
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ TITLE = "Talk To Me Morty"
7
 
8
+ DESCRIPTION = """
9
+ <p style='text-align:center'>
10
+ The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.<br>
11
+ <img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png"
12
+ alt="Rick"
13
+ width="200">
 
14
  </p>
15
  """
 
 
 
 
16
 
17
+ ARTICLE = """
18
+ <p style='text-align:center'>
19
+ <a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701"
20
+ target="_blank">Complete Tutorial</a> Β·
21
+ <a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot"
22
+ target="_blank">Project on DAGsHub</a>
23
+ </p>
24
+ <p style='text-align:center'>
25
+ <img src="https://visitor-badge.glitch.me/badge?page_id=kingabzpro/Rick_and_Morty_Bot"
26
+ alt="visitor badge">
27
+ </p>
28
+ """
29
+
30
+ # ─── Load model once at start ────────────────────────────────────────────────
31
  tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
32
+ model = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
33
+
34
+ # ─── Chat handler ────────────────────────────────────────────────────────────
35
+ def chat(user_msg: str, history_ids: list[int] | None):
36
+ if not user_msg:
37
+ return [], history_ids or []
38
+
39
+ new_ids = tokenizer.encode(user_msg + tokenizer.eos_token,
40
+ return_tensors="pt")
41
+
42
+ bot_input = (
43
+ torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1)
44
+ if history_ids else new_ids
45
+ )
46
+
47
+ history_ids = model.generate(
48
+ bot_input,
49
+ max_length=min(4096, bot_input.shape[-1] + 200),
50
+ pad_token_id=tokenizer.eos_token_id,
51
+ do_sample=True,
52
+ temperature=0.7,
53
+ top_p=0.92,
54
+ top_k=50,
55
+ ).tolist()
56
+
57
+ turns = tokenizer.decode(history_ids[0], skip_special_tokens=False) \
58
+ .split("<|endoftext|>")
59
+ # pack into (user, bot) pairs for Chatbot component
60
+ pairs = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
61
+ return pairs, history_ids
62
+
63
+ # ─── Gradio UI ───────────────────────────────────────────────────────────────
64
+ with gr.Blocks(theme=gr.themes.Seafoam()) as demo:
65
+ gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>")
66
+ gr.Markdown(DESCRIPTION)
67
+
68
+ chatbot = gr.Chatbot(height=450)
69
+ state = gr.State([])
70
+
71
+ with gr.Row():
72
+ prompt = gr.Textbox(placeholder="Ask Rick anything…", scale=4, show_label=False)
73
+ send = gr.Button("Send", variant="primary")
74
+
75
+ # send on click or ↡
76
+ send.click(chat, inputs=[prompt, state], outputs=[chatbot, state])
77
+ prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state])
78
+
79
+ gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
80
+ gr.Markdown(ARTICLE)
81
+
82
+ if __name__ == "__main__":
83
+ demo.launch()