|
import zipfile |
|
import gradio as gr |
|
from PIL import Image |
|
from chatharuhi import ChatHaruhi |
|
import wget |
|
import os |
|
import openai |
|
import copy |
|
import random |
|
import string |
|
|
|
|
|
# Display name shown in the UI -> English identifier used for the character's
# zip file / data folder and as the key into ai_roles_obj.
# Insertion order matters: it is the order characters are listed in.
NAME_DICT = {
    '汤师爷': 'tangshiye',
    '慕容复': 'murongfu',
    '李云龙': 'liyunlong',
    'Luna': 'Luna',
    '王多鱼': 'wangduoyu',
    'Ron': 'Ron',
    '鸠摩智': 'jiumozhi',
    'Snape': 'Snape',
    '凉宫春日': 'haruhi',
    'Malfoy': 'Malfoy',
    '虚竹': 'xuzhu',
    '萧峰': 'xiaofeng',
    '段誉': 'duanyu',
    'Hermione': 'Hermione',
    'Dumbledore': 'Dumbledore',
    '王语嫣': 'wangyuyan',
    'Harry': 'Harry',
    'McGonagall': 'McGonagall',
    '白展堂': 'baizhantang',
    '佟湘玉': 'tongxiangyu',
    '郭芙蓉': 'guofurong',
    '旅行者': 'wanderer',
    '钟离': 'zhongli',
    '胡桃': 'hutao',
    'Sheldon': 'Sheldon',
    'Raj': 'Raj',
    'Penny': 'Penny',
    '韦小宝': 'weixiaobao',
    '乔峰': 'qiaofeng',
    '神里绫华': 'ayaka',
    '雷电将军': 'raidenShogun',
    '于谦': 'yuqian',
}
|
|
|
|
|
# Directory containing one "<english_name>.zip" archive per character.
# Defaults to the original Colab checkout location; set the
# HARUHI_CHARACTER_ZIP_DIR environment variable to run from anywhere else.
CHARACTER_ZIP_DIR = os.environ.get(
    "HARUHI_CHARACTER_ZIP_DIR", "/content/Haruhi-2-Dev/data/character_in_zip"
)

# English character name -> ChatHaruhi instance. Characters whose zip archive
# is missing are reported and skipped, so they are absent from this dict.
ai_roles_obj = {}

for ai_role_en in NAME_DICT.values():
    zip_file_path = os.path.join(CHARACTER_ZIP_DIR, f"{ai_role_en}.zip")
    if not os.path.exists(zip_file_path):
        print('unfound zip file ', zip_file_path)
        continue

    # Unpack into ./characters/<name>/. The two paths below assume the archive
    # contains content/<name>/ (story database) and content/system_prompt.txt
    # -- NOTE(review): confirm against the zip files.
    destination_folder = f"characters/{ai_role_en}"
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(destination_folder)

    db_folder = f"./characters/{ai_role_en}/content/{ai_role_en}"
    system_prompt = f"./characters/{ai_role_en}/content/system_prompt.txt"
    ai_roles_obj[ai_role_en] = ChatHaruhi(system_prompt=system_prompt,
                                          llm="openai",
                                          story_db=db_folder,
                                          verbose=True)
|
|
|
def format_chat(role, text):
    """Render one line of dialogue as a display string.

    Regular speakers are rendered as ``role:「text」``. Narrator-like roles
    (``旁白``, the empty string, ``scene``/``Scene``, ``narrator``/``Narrator``)
    are rendered as ``role:text`` without the 「」 brackets.
    """
    narration_roles = ('旁白', '', 'scene', 'Scene', 'narrator', 'Narrator')
    if role not in narration_roles:
        return f"{role}:「{text}」"
    return role + ":" + text
|
|
|
def deformat_chat(chat):
    """Split a formatted chat line back into ``(role, text)``.

    Reverses ``format_chat``. Surrounding quote characters are first stripped
    from the whole line. The role is everything before the first colon, which
    may be an ASCII ``:`` or a full-width ``\\uff1a``. The text is everything
    after that colon, with surrounding 「」 brackets and quote characters
    stripped.

    Returns:
        ``(role, text)``; or ``('', line)`` (outer quotes already stripped)
        when the line contains no colon of either width.
    """
    fullwidth_colon = '\uff1a'
    chat = chat.strip('\'"')

    # Split at the EARLIEST colon of either width. Role names are not expected
    # to contain a colon, whereas the spoken text may (e.g. "3:00"). Checking
    # one width before the other would split inside the text.
    positions = [i for i in (chat.find(':'), chat.find(fullwidth_colon)) if i != -1]
    if not positions:
        return '', chat
    colon_index = min(positions)

    role = chat[:colon_index]
    text = chat[colon_index + 1:].strip('「」"\'')
    return role, text
|
|
|
def print_last_chat(chats):
    """Print the newest chat line next to a short preview of the first one.

    The preview is the first 30 characters of ``chats[0]``, with ``/``, ``.``,
    ``"`` and newlines each replaced by ``_``. Output format is
    ``<chats[-1]> ____ <preview>``. Raises IndexError if ``chats`` is empty.
    """
    to_underscore = str.maketrans({'/': '_', '.': '_', '"': '_', '\n': '_'})
    preview = chats[0][:30].translate(to_underscore)
    newest = chats[-1]
    print(newest, '____', preview)
|
|
|
|
|
from gradio.components import clear_button |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
# Stand-alone plot of y = sin(2*pi*f*x/Fs) for x = 0..7999. With f=5 and
# Fs=8000 this draws 5 full periods.
# NOTE(review): nothing else in the visible code reads Fs/f/sample/x/y or
# the figure drawn here. This looks like leftover scratch code that runs at
# import time; consider removing it.
Fs, f, sample = 8000, 5, 8000
x = np.arange(sample)
y = np.sin(2 * np.pi * f * x / Fs)
plt.plot(x, y)
|
|
|
|
|
def user_response(user_role, user_text, chatbot):
    """Append the user's line to the chat and snapshot the result.

    Returns the values for the Gradio outputs
    ``[user_text, chatbot, reserved_chatbot]``:
      * ``""`` to clear the text box,
      * ``chatbot`` with ``(formatted_line, None)`` appended in place,
      * a shallow copy of that list, which the remove/recover callbacks use as
        the reference transcript.
    """
    new_pair = (format_chat(user_role, user_text), None)
    chatbot.append(new_pair)
    snapshot = chatbot.copy()
    return "", chatbot, snapshot
|
|
|
def extract_chats(chatbot):
    """Flatten ``(user, bot)`` chat pairs into one ordered list of messages.

    Within each pair the user side comes before the bot side. ``None`` slots
    (the empty side of a pair) are skipped. Each entry must unpack into
    exactly two items.
    """
    return [message
            for q, a in chatbot
            for message in (q, a)
            if message is not None]
|
|
|
|
|
def _build_history(chats, role_en):
    """Turn flat chat lines into the pair list assigned to ``dialogue_history``.

    Lines whose speaker maps (via ``NAME_DICT``, or verbatim if unmapped) to
    ``role_en`` become ``(None, line)``. Every other line becomes
    ``(line, None)``.
    """
    history = []
    for chat in chats:
        speaker, _ = deformat_chat(chat)
        speaker_en = NAME_DICT.get(speaker, speaker)
        history.append((None, chat) if speaker_en == role_en else (chat, None))
    return history


def ai_response(ai_role, chatbot):
    """Have the selected character produce the next line and append it.

    Rebuilds the character's ``dialogue_history`` from every chat line except
    the newest. The newest line is deformatted and sent as the current
    ``chat()`` input. If the chat is empty, the character is prompted with
    ``role='scene', text=''``.

    Args:
        ai_role: Display name chosen in the character picker (a ``NAME_DICT``
            key).
        chatbot: List of ``(user, bot)`` message pairs; mutated in place.

    Returns:
        ``(chatbot, reserved_chatbot)``: the chat with ``(None, reply)``
        appended, and a shallow copy of it used as the recover snapshot.

    Raises:
        gr.Error: If no character is selected, or the selected character's
            data was not loaded at startup. The loader skips characters whose
            zip file is missing.
    """
    role_en = NAME_DICT.get(ai_role)
    if role_en is None:
        raise gr.Error(f"Unknown or unselected character: {ai_role!r}")
    agent = ai_roles_obj.get(role_en)
    if agent is None:
        raise gr.Error(
            f"Character {ai_role!r} is not available: its data ({role_en}) was not loaded."
        )

    chats = extract_chats(chatbot)
    history = _build_history(chats, role_en)

    # NOTE(review): ai_roles_obj is module-global, so concurrent sessions
    # overwrite each other's dialogue_history here. Confirm the demo is
    # single-user.
    if history:
        agent.dialogue_history = history[:-1]
        last_role, last_text = deformat_chat(chats[-1])
        response = agent.chat(role=last_role, text=last_text)
    else:
        agent.dialogue_history = []
        response = agent.chat(role='scene', text='')

    chatbot.append((None, response))
    reserved_chatbot = chatbot.copy()

    print_last_chat(extract_chats(chatbot))
    return chatbot, reserved_chatbot
|
|
|
def callback_remove_one_chat(chatbot, reserved_chatbot):
    """Drop the newest message pair from the visible chat, in place.

    At least one pair is always kept. ``reserved_chatbot`` is accepted only
    because the Gradio event passes it. It is left untouched so that
    ``callback_recover_one_chat`` can restore the removed pair.
    """
    if len(chatbot) <= 1:
        return chatbot
    del chatbot[-1]
    return chatbot
|
|
|
def callback_recover_one_chat(chatbot, reserved_chatbot):
    """Restore the next pair from the snapshot onto the visible chat.

    The pair is picked by position: ``reserved_chatbot[len(chatbot)]``. It is
    appended in place, but only while the snapshot is longer than the visible
    chat.
    """
    next_index = len(chatbot)
    if next_index >= len(reserved_chatbot):
        return chatbot
    chatbot.append(reserved_chatbot[next_index])
    return chatbot
|
|
|
def callback_clean():
    """Reset both the visible chat and its snapshot to fresh empty lists."""
    empty_chat = []
    empty_snapshot = []
    return empty_chat, empty_snapshot
|
|
|
|
|
# Two-column story-teller UI. The left column holds the transcript and manual
# user input; the right column holds the character picker and the editing
# controls.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Story Teller Demo

        implemented by [Cheng Li](https://github.com/LC1332) and [Weishi MI](https://github.com/hhhwmws0117)

        本项目是ChatHaruhi的子项目,原项目链接 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)

        如果觉得好玩可以去点个star

        这个Gradio是一个初步的尝试,之后考虑做一套更正式的story-teller的算法
        """
    )

    with gr.Row():
        with gr.Column():
            with gr.Row(height=800):
                chatbot = gr.Chatbot(height=800)
            with gr.Row():
                user_role = gr.Textbox(label="user_role", scale=1)
                user_text = gr.Textbox(label="user_text", scale=20)
            with gr.Row():
                user_submit = gr.Button("User Submit")

        with gr.Column():
            with gr.Row():
                # Choices come straight from NAME_DICT (same names, same
                # order). The picker therefore cannot drift out of sync with
                # the name mapping that ai_response looks values up in.
                ai_role = gr.Radio(list(NAME_DICT), label="characters", value='凉宫春日')
            with gr.Row():
                ai_submit = gr.Button("AI Submit")
            with gr.Row():
                remove_one_chat = gr.Button("Remove One Chat")
                recover_one_chat = gr.Button("Recover One Chat")
            with gr.Row():
                clean = gr.Button("Clean")

    # Per-session copy of the latest full transcript. "Remove One Chat" only
    # shortens the visible chat; "Recover One Chat" re-appends from this copy.
    # gr.State keeps it server-side, so no hidden component is rendered or
    # round-tripped through the browser.
    reserved_chatbot = gr.State([])

    user_submit.click(fn=user_response,
                      inputs=[user_role, user_text, chatbot],
                      outputs=[user_text, chatbot, reserved_chatbot])
    ai_submit.click(fn=ai_response,
                    inputs=[ai_role, chatbot],
                    outputs=[chatbot, reserved_chatbot])

    remove_one_chat.click(fn=callback_remove_one_chat,
                          inputs=[chatbot, reserved_chatbot],
                          outputs=[chatbot])
    recover_one_chat.click(fn=callback_recover_one_chat,
                           inputs=[chatbot, reserved_chatbot],
                           outputs=[chatbot])

    clean.click(fn=callback_clean, inputs=[], outputs=[chatbot, reserved_chatbot])
|
|
|
|
|
# Start the server only when this file is executed directly, so importing the
# module (e.g. to reuse its callbacks) does not block on launch() or open a
# public share link.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)