silk-road commited on
Commit
e74a2eb
1 Parent(s): 3bf1705

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zipfile
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from chatharuhi import ChatHaruhi
5
+ import wget
6
+ import os
7
+ import openai
8
+ import copy
9
+ import random
10
+ import string
11
+
12
+
13
+ NAME_DICT = {'汤师爷': 'tangshiye', '慕容复': 'murongfu', '李云龙': 'liyunlong', 'Luna': 'Luna', '王多鱼': 'wangduoyu',
14
+ 'Ron': 'Ron', '鸠摩智': 'jiumozhi', 'Snape': 'Snape',
15
+ '凉宫春日': 'haruhi', 'Malfoy': 'Malfoy', '虚竹': 'xuzhu', '萧峰': 'xiaofeng', '段誉': 'duanyu',
16
+ 'Hermione': 'Hermione', 'Dumbledore': 'Dumbledore', '王语嫣': 'wangyuyan',
17
+ 'Harry': 'Harry', 'McGonagall': 'McGonagall', '白展堂': 'baizhantang', '佟湘玉': 'tongxiangyu',
18
+ '郭芙蓉': 'guofurong', '旅行者': 'wanderer', '钟离': 'zhongli',
19
+ '胡桃': 'hutao', 'Sheldon': 'Sheldon', 'Raj': 'Raj', 'Penny': 'Penny', '韦小宝': 'weixiaobao',
20
+ '乔峰': 'qiaofeng', '神里绫华': 'ayaka', '雷电将军': 'raidenShogun', '于谦': 'yuqian'}
21
+
22
+
23
+ ai_roles_obj = {}
24
+
25
+ for ai_role_en in NAME_DICT.values():
26
+ zip_file_path = f"/content/Haruhi-2-Dev/data/character_in_zip/{ai_role_en}.zip"
27
+ if not os.path.exists(zip_file_path):
28
+ # os.remove(zip_file_path)
29
+ print('unfound zip file ', zip_file_path)
30
+ continue
31
+
32
+ destination_folder = f"characters/{ai_role_en}"
33
+
34
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
35
+ zip_ref.extractall(destination_folder)
36
+
37
+ db_folder = f"./characters/{ai_role_en}/content/{ai_role_en}"
38
+ system_prompt = f"./characters/{ai_role_en}/content/system_prompt.txt"
39
+ ai_roles_obj[ai_role_en] = ChatHaruhi(system_prompt=system_prompt,
40
+ llm="openai",
41
+ story_db=db_folder,
42
+ verbose=True)
43
+ # break
44
+ def format_chat( role, text ):
45
+ narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
46
+ if role in narrator:
47
+ return role + ":" + text
48
+ else:
49
+ return f"{role}:「{text}」"
50
+
51
+ def deformat_chat(chat):
52
+
53
+ chat = chat.strip('\'"')
54
+ if ':' in chat:
55
+ colon_index = chat.index(':')
56
+ elif ':' in chat:
57
+ colon_index = chat.index(':')
58
+ else:
59
+ return '', chat
60
+
61
+ role = chat[:colon_index]
62
+ text = chat[colon_index+1:]
63
+
64
+ text = text.strip('「」"\'')
65
+
66
+ return role, text
67
+
68
+ def print_last_chat( chats ):
69
+ shorten_chat = chats[0]
70
+ if len(shorten_chat) > 30:
71
+ shorten_chat = shorten_chat[:30]
72
+ shorten_chat = shorten_chat.replace('/', '_')
73
+ shorten_chat = shorten_chat.replace('.', '_')
74
+ shorten_chat = shorten_chat.replace('"', '_')
75
+ shorten_chat = shorten_chat.replace('\n', '_')
76
+
77
+ final_chat = chats[-1]
78
+
79
+ print( final_chat , '____', shorten_chat )
80
+
81
+
82
+ from gradio.components import clear_button
83
+ # import gradio as gr
84
+
85
+ import matplotlib.pyplot as plt
86
+ import numpy as np
87
+
88
+ Fs = 8000
89
+ f = 5
90
+ sample = 8000
91
+ x = np.arange(sample)
92
+ y = np.sin(2 * np.pi * f * x / Fs)
93
+ plt.plot(x, y)
94
+
95
+
96
+ def user_response(user_role, user_text, chatbot):
97
+
98
+ user_msg = format_chat( user_role, user_text )
99
+ chatbot.append((user_msg, None ))
100
+
101
+ reserved_chatbot = chatbot.copy()
102
+
103
+ return "", chatbot, reserved_chatbot
104
+
105
+ def extract_chats( chatbot ):
106
+ chats = []
107
+ for q,a in chatbot:
108
+ if q is not None:
109
+ chats.append(q)
110
+ if a is not None:
111
+ chats.append(a)
112
+ return chats
113
+
114
+
115
+ def ai_response(ai_role, chatbot):
116
+ role_en = NAME_DICT[ai_role]
117
+
118
+ # 我们需要构造history
119
+ history = []
120
+
121
+ chats = extract_chats(chatbot)
122
+
123
+ # 解析roles和texts
124
+ for chat in chats:
125
+ role, text = deformat_chat(chat)
126
+ if role in NAME_DICT.keys():
127
+ current_en = NAME_DICT[role]
128
+ else:
129
+ current_en = role
130
+
131
+ if current_en == role_en:
132
+ history.append((None, chat))
133
+ else:
134
+ history.append((chat, None))
135
+
136
+ if len(history) >= 1:
137
+ ai_roles_obj[ role_en ].dialogue_history = history[:-1]
138
+ last_role, last_text = deformat_chat(chats[-1])
139
+ response = ai_roles_obj[ role_en ].chat(role = last_role, text = last_text)
140
+ else:
141
+ ai_roles_obj[ role_en ].dialogue_history = []
142
+ response = ai_roles_obj[ role_en ].chat(role = 'scene', text = '')
143
+
144
+ # ai_msg = format_chat(ai_role, response)
145
+ ai_msg = response
146
+
147
+ chatbot.append( (None, ai_msg ) )
148
+
149
+ reserved_chatbot = chatbot.copy()
150
+
151
+ chats = extract_chats( chatbot )
152
+ # save_dialogue( chats )
153
+ print_last_chat( chats )
154
+ return chatbot, reserved_chatbot
155
+
156
+ def callback_remove_one_chat(chatbot, reserved_chatbot):
157
+ if len(chatbot) > 1:
158
+ chatbot.pop()
159
+ return chatbot
160
+
161
+ def callback_recover_one_chat(chatbot, reserved_chatbot):
162
+ if len(chatbot) < len(reserved_chatbot):
163
+ chatbot.append( reserved_chatbot[len(chatbot)] )
164
+ return chatbot
165
+
166
+ def callback_clean():
167
+ return [], []
168
+
169
+
170
+ with gr.Blocks() as demo:
171
+ gr.Markdown(
172
+ """
173
+ # Story Teller Demo
174
+
175
+ implemented by [Cheng Li](https://github.com/LC1332) and [Weishi MI](https://github.com/hhhwmws0117)
176
+
177
+ 本项目是ChatHaruhi的子项目,原项目链接 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)
178
+
179
+ 如果觉得好玩可以去点个star
180
+
181
+ 这个Gradio是一个初步的尝试,之后考虑做一套更正式的story-teller的算法
182
+ """
183
+ )
184
+ with gr.Row():
185
+ with gr.Column():
186
+ with gr.Row(height = 800):
187
+ # 给米唯实一个艰巨的任务,把这东西弄高一点
188
+ chatbot = gr.Chatbot(height = 800)
189
+ with gr.Row():
190
+ user_role = gr.Textbox(label="user_role", scale=1)
191
+ user_text = gr.Textbox(label="user_text", scale=20)
192
+ with gr.Row():
193
+ user_submit = gr.Button("User Submit")
194
+
195
+
196
+
197
+ with gr.Column():
198
+ with gr.Row():
199
+ ai_role = gr.Radio(['汤师爷', '慕容复', '李云龙',
200
+ 'Luna', '王多鱼', 'Ron', '鸠摩智',
201
+ 'Snape', '凉宫春日', 'Malfoy', '虚竹',
202
+ '萧峰', '段誉', 'Hermione', 'Dumbledore',
203
+ '王语嫣',
204
+ 'Harry', 'McGonagall',
205
+ '白展堂', '佟湘玉', '郭芙蓉',
206
+ '旅行者', '钟离', '胡桃',
207
+ 'Sheldon', 'Raj', 'Penny',
208
+ '韦小宝', '乔峰', '神里绫华',
209
+ '雷电将军', '于谦'], label="characters", value='凉宫春日')
210
+ with gr.Row():
211
+ ai_submit = gr.Button("AI Submit")
212
+
213
+ with gr.Row():
214
+ remove_one_chat = gr.Button("Remove One Chat")
215
+ recover_one_chat = gr.Button("Recover One Chat")
216
+ with gr.Row():
217
+ clean = gr.Button("Clean")
218
+
219
+ reserved_chatbot = gr.Chatbot(visible = False)
220
+
221
+ user_submit.click(fn = user_response, inputs = [user_role, user_text, chatbot], outputs = [user_text, chatbot,reserved_chatbot] )
222
+ ai_submit.click(fn = ai_response, inputs = [ai_role, chatbot], outputs = [chatbot,reserved_chatbot] )
223
+
224
+ remove_one_chat.click(fn = callback_remove_one_chat, inputs = [chatbot, reserved_chatbot], outputs = [chatbot] )
225
+ recover_one_chat.click(fn = callback_recover_one_chat, inputs = [chatbot, reserved_chatbot], outputs = [chatbot] )
226
+
227
+ clean.click(fn = callback_clean, inputs = [], outputs = [chatbot,reserved_chatbot] )
228
+
229
+
230
+ demo.launch(debug=True, share=True)