Spaces:
Runtime error
Runtime error
import zipfile | |
import gradio as gr | |
from PIL import Image | |
from ChatHaruhi import ChatHaruhi | |
import wget | |
import os | |
import openai | |
import copy | |
import random | |
import string | |
# Display name (the names offered by the "characters" radio and looked up in
# ai_response) -> English id used in each character's zip URL and folder name.
NAME_DICT = {
    '汤师爷': 'tangshiye',
    '慕容复': 'murongfu',
    '李云龙': 'liyunlong',
    'Luna': 'Luna',
    '王多鱼': 'wangduoyu',
    'Ron': 'Ron',
    '鸠摩智': 'jiumozhi',
    'Snape': 'Snape',
    '凉宫春日': 'haruhi',
    'Malfoy': 'Malfoy',
    '虚竹': 'xuzhu',
    '萧峰': 'xiaofeng',
    '段誉': 'duanyu',
    'Hermione': 'Hermione',
    'Dumbledore': 'Dumbledore',
    '王语嫣': 'wangyuyan',
    'Harry': 'Harry',
    'McGonagall': 'McGonagall',
    '白展堂': 'baizhantang',
    '佟湘玉': 'tongxiangyu',
    '郭芙蓉': 'guofurong',
    '旅行者': 'wanderer',
    '钟离': 'zhongli',
    '胡桃': 'hutao',
    'Sheldon': 'Sheldon',
    'Raj': 'Raj',
    'Penny': 'Penny',
    '韦小宝': 'weixiaobao',
    '乔峰': 'qiaofeng',
    '神里绫华': 'ayaka',
    '雷电将军': 'raidenShogun',
    '于谦': 'yuqian',
}
# Working directories: downloaded character zips, and their extracted contents.
# exist_ok=True replaces the former bare ``try/except: pass``, which also
# silently hid real failures such as permission errors.
os.makedirs("characters_zip", exist_ok=True)
os.makedirs("characters", exist_ok=True)
# English character id -> ChatHaruhi agent, built once at startup.
ai_roles_obj = {}
for ai_role_en in NAME_DICT.values():
    # Each character ships as a zip in the Haruhi-2-Dev repository.
    file_url = f"https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{ai_role_en}.zip"
    destination_file = f"characters_zip/{ai_role_en}.zip"
    destination_folder = f"characters/{ai_role_en}"
    os.makedirs(destination_folder, exist_ok=True)

    db_folder = f"./characters/{ai_role_en}/content/{ai_role_en}"
    system_prompt = f"./characters/{ai_role_en}/content/system_prompt.txt"

    # A direct existence check replaces listing the whole zip directory on
    # every iteration.
    downloaded = False
    if not os.path.exists(destination_file):
        wget.download(file_url, destination_file)
        downloaded = True

    # Extract after a fresh download (as before), and also whenever the
    # unpacked prompt is missing. Before, a restart after an interrupted
    # extraction left the folder empty because the zip was already present.
    # Assumes the archive provides content/system_prompt.txt, as the paths
    # above expect.
    if downloaded or not os.path.exists(system_prompt):
        with zipfile.ZipFile(destination_file, 'r') as zip_ref:
            zip_ref.extractall(destination_folder)

    ai_roles_obj[ai_role_en] = ChatHaruhi(system_prompt=system_prompt,
                                          llm="openai",
                                          story_db=db_folder,
                                          verbose=True)
def format_chat( role, text ):
    """Render one chat line for display.

    Narrator-like roles, including the empty role, are rendered as
    ``role:text``. Any other speaker gets the text wrapped in 「」 brackets.
    """
    if role not in ('旁白', '', 'scene', 'Scene', 'narrator', 'Narrator'):
        return f"{role}:「{text}」"
    return role + ":" + text
def deformat_chat(chat):
    """Split a rendered chat line back into ``(role, text)``.

    Surrounding quote characters on the whole line are stripped first. The
    role is everything before the first colon, either ASCII ``:`` or
    full-width ``：``. The text after it has surrounding 「」 brackets and quotes
    removed. A line with no colon returns ``('', chat)``.
    """
    chat = chat.strip('\'"')
    # The previous second branch repeated the ASCII ':' test and could never
    # run. Accept the full-width colon common in Chinese text too, and split at
    # whichever separator appears first (assumes role names contain no colon).
    positions = [i for i in (chat.find(':'), chat.find('\uff1a')) if i != -1]
    if not positions:
        return '', chat
    colon_index = min(positions)
    role = chat[:colon_index]
    text = chat[colon_index + 1:].strip('「」"\'')
    return role, text
def print_last_chat( chats ):
    """Log the newest message alongside a short tag derived from the first one.

    The tag is the first message cut to 30 characters, with '/', '.', '"' and
    newlines replaced by '_'.
    """
    sanitize = str.maketrans('/."\n', '____')
    tag = chats[0][:30].translate(sanitize)
    newest = chats[-1]
    print( newest , '____', tag )
from gradio.components import clear_button | |
# import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# NOTE(review): nothing else in this file reads Fs / f / sample / x / y, and the
# figure drawn here is never shown or saved. This looks like leftover scratch
# code and is a candidate for removal.
# One second (sample / Fs) of a 5 Hz sine sampled at 8 kHz, drawn on the
# current matplotlib figure.
Fs, f, sample = 8000, 5, 8000
x = np.arange(sample)
y = np.sin(2 * np.pi * f * x / Fs)
plt.plot(x, y)
def user_response(user_role, user_text, chatbot):
    """Append the user's formatted line to the chat.

    Returns ``("", chatbot, snapshot)``: an empty string that clears the input
    textbox, the chat list mutated in place, and a shallow copy of it kept as
    the reserve used by "Recover One Chat".
    """
    chatbot.append((format_chat(user_role, user_text), None))
    snapshot = chatbot.copy()
    return "", chatbot, snapshot
def extract_chats( chatbot ):
    """Flatten (user, bot) message pairs into one list in display order.

    ``None`` slots (a pair with only one side filled) are skipped.
    """
    return [message for q, a in chatbot for message in (q, a) if message is not None]
def ai_response(ai_role, chatbot):
    """Ask the selected character for its next line and append it to the chat.

    The agent's ``dialogue_history`` is rebuilt from the visible transcript on
    every call. Lines spoken by the selected character fill the response slot;
    all other lines fill the query slot. The newest line is sent as the prompt.
    With an empty transcript, the agent is prompted with role 'scene' and empty
    text instead.

    Returns ``(chatbot, reserved_chatbot)``: the chat list mutated in place and
    a shallow copy of it.
    """
    role_en = NAME_DICT[ai_role]

    chats = extract_chats(chatbot)
    history = []
    for chat in chats:
        speaker, _ = deformat_chat(chat)
        # Speakers shown by display name map to their English id. Anything
        # else, such as a user-typed role, is compared as-is.
        if NAME_DICT.get(speaker, speaker) == role_en:
            history.append((None, chat))
        else:
            history.append((chat, None))

    agent = ai_roles_obj[role_en]
    if chats:
        # Everything but the newest line is context; the newest is the prompt.
        agent.dialogue_history = history[:-1]
        last_role, last_text = deformat_chat(chats[-1])
        response = agent.chat(role = last_role, text = last_text)
    else:
        agent.dialogue_history = []
        response = agent.chat(role = 'scene', text = '')

    # The agent's reply is appended verbatim (no format_chat wrapping).
    chatbot.append( (None, response) )
    reserved_chatbot = chatbot.copy()
    print_last_chat( extract_chats( chatbot ) )
    return chatbot, reserved_chatbot
def callback_remove_one_chat(chatbot, reserved_chatbot):
    """Drop the newest chat entry in place, but never the last remaining one.

    ``reserved_chatbot`` is unused; it is accepted because the event wiring
    passes it.
    """
    if len(chatbot) <= 1:
        return chatbot
    del chatbot[-1]
    return chatbot
def callback_recover_one_chat(chatbot, reserved_chatbot):
    """Restore the next entry from the reserved copy, if the chat is shorter.

    The entry at position ``len(chatbot)`` of ``reserved_chatbot`` is appended
    in place. Nothing happens once the chat is as long as the reserve.
    """
    shown = len(chatbot)
    if shown >= len(reserved_chatbot):
        return chatbot
    chatbot.append( reserved_chatbot[shown] )
    return chatbot
def callback_clean():
    """Return two fresh empty lists: the cleared chat and its cleared reserve."""
    cleared_chat, cleared_reserve = [], []
    return cleared_chat, cleared_reserve
# Gradio UI: the chat transcript and user input on the left; character choice
# and chat-editing buttons on the right.
with gr.Blocks() as demo:
    gr.Markdown(
"""
# Story Teller Demo
implemented by [Cheng Li](https://github.com/LC1332) and [Weishi MI](https://github.com/hhhwmws0117)
本项目是ChatHaruhi的子项目,原项目链接 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)
如果觉得好玩可以去点个star
这个Gradio是一个初步的尝试,之后考虑做一套更正式的story-teller的算法
"""
    )
    with gr.Row():
        with gr.Column():
            # Not every Gradio release accepts a `height` argument on gr.Row.
            # The Chatbot's own height=800 is what makes the transcript pane
            # tall, so the Row argument was dropped for portability.
            with gr.Row():
                chatbot = gr.Chatbot(height = 800)
            with gr.Row():
                user_role = gr.Textbox(label="user_role", scale=1)
                user_text = gr.Textbox(label="user_text", scale=20)
            with gr.Row():
                user_submit = gr.Button("User Submit")
        with gr.Column():
            with gr.Row():
                # Choices come straight from NAME_DICT (same names in the same
                # order as the hand-copied list this replaces). The radio can
                # therefore never offer a name that ai_response cannot look up.
                ai_role = gr.Radio(list(NAME_DICT), label="characters", value='凉宫春日')
            with gr.Row():
                ai_submit = gr.Button("AI Submit")
            with gr.Row():
                remove_one_chat = gr.Button("Remove One Chat")
                recover_one_chat = gr.Button("Recover One Chat")
            with gr.Row():
                clean = gr.Button("Clean")
    # Per-session snapshot of the transcript, read by "Recover One Chat".
    # gr.State is Gradio's mechanism for hidden per-session values; it replaces
    # an invisible gr.Chatbot. It starts as [] because
    # callback_recover_one_chat calls len() on it.
    reserved_chatbot = gr.State([])

    user_submit.click(fn = user_response, inputs = [user_role, user_text, chatbot], outputs = [user_text, chatbot, reserved_chatbot] )
    ai_submit.click(fn = ai_response, inputs = [ai_role, chatbot], outputs = [chatbot, reserved_chatbot] )
    remove_one_chat.click(fn = callback_remove_one_chat, inputs = [chatbot, reserved_chatbot], outputs = [chatbot] )
    recover_one_chat.click(fn = callback_recover_one_chat, inputs = [chatbot, reserved_chatbot], outputs = [chatbot] )
    clean.click(fn = callback_clean, inputs = [], outputs = [chatbot, reserved_chatbot] )

demo.launch(debug=True)