|
import gradio as gr |
|
import os |
|
import httpx |
|
import openai |
|
from openai import OpenAI |
|
from openai import AsyncOpenAI |
|
|
|
from datasets import load_dataset |
|
|
|
# Character corpus, addressed by its Hugging Face dataset id.
dataset = load_dataset("silk-road/50-Chinese-Novel-Characters")

# Novel titles, in order of first appearance in the train split.
novel_list = []

# novel title -> list of role names, in order of first appearance.
novel2roles = {}

# (novel, role) -> list of dataset rows belonging to that character.
role2datas = {}
|
|
|
from tqdm import tqdm |
|
# Group every training row by novel and by (novel, role).
for data in tqdm(dataset['train']):
    novel = data['book']
    role = data['role']

    # novel_list and novel2roles are only ever extended together, so the
    # dict lookup is an O(1) stand-in for scanning novel_list.
    if novel not in novel2roles:
        novel_list.append(novel)
        novel2roles[novel] = []

    role_tuple = (novel, role)

    # A role is appended to novel2roles[novel] exactly when its
    # (novel, role) key is first created, so one dict lookup covers both.
    if role_tuple not in role2datas:
        novel2roles[novel].append(role)
        role2datas[role_tuple] = []

    role2datas[role_tuple].append(data)
|
|
|
|
|
from ChatHaruhi.utils import base64_to_float_array |
|
|
|
from tqdm import tqdm |
|
|
|
# Attach a decoded embedding to every row under the "vec" key.
# NOTE(review): "bge_zh_s15" is presumably a base64-encoded BGE embedding,
# judging by the helper's name - confirm against the dataset card.
for novel in tqdm(novel_list):
    for role in novel2roles[novel]:
        for data in role2datas[(novel, role)]:
            data["vec"] = base64_to_float_array(data["bge_zh_s15"])
|
|
|
def conv2story(role, conversations):
    """Flatten a conversation into a single newline-separated story string.

    Args:
        role: Name used as the speaker prefix for non-human turns.
        conversations: Iterable of dicts with ``"from"`` and ``"value"`` keys.

    Returns:
        One line per turn. Turns whose ``"from"`` is ``"human"`` are emitted
        verbatim; every other turn is rendered as ``"<role>: <value>"``.
        An empty conversation yields ``""``.
    """
    return "\n".join(
        turn["value"] if turn["from"] == "human" else role + ": " + turn["value"]
        for turn in conversations
    )
|
|
|
# Pre-render every row's conversation as plain text under the "story" key.
for novel in tqdm(novel_list):
    for role in novel2roles[novel]:
        for data in role2datas[(novel, role)]:
            data["story"] = conv2story(role, data["conversations"])
|
|
|
|
|
from ChatHaruhi import ChatHaruhi |
|
from ChatHaruhi.response_openai import get_response as get_response_openai |
|
from ChatHaruhi.response_zhipu import get_response as get_response_zhipu |
|
from ChatHaruhi.response_erniebot import get_response as get_response_erniebot |
|
from ChatHaruhi.response_spark import get_response as get_response_spark |
|
|
|
|
|
# Backend used when chatbots are first constructed.
# NOTE: a later `from ChatHaruhi.response_openai import get_response`
# in this file rebinds this name to the OpenAI backend.
get_response = get_response_zhipu
|
|
|
# Role names that are treated as a narrator rather than a character.
narrators = ["叙述者", "旁白", "文章作者", "作者", "Narrator", "narrator"]
|
|
|
|
|
def package_persona(role_name, world_name):
    """Build the base system prompt for a character.

    Args:
        role_name: Character name; names listed in ``narrators`` are routed
            to ``package_persona_for_narrator`` instead.
        world_name: Title of the novel the character comes from.

    Returns:
        The persona prompt text, without any retrieval placeholders.
    """
    if role_name in narrators:
        return package_persona_for_narrator(role_name, world_name)

    return f"""I want you to act like {role_name} from {world_name}.
If others‘ questions are related with the novel, please try to reuse the original lines from the novel.
I want you to respond and answer like {role_name} using the tone, manner and vocabulary {role_name} would use."""
|
|
|
def package_persona_for_narrator(role_name, world_name):
    """Build the base system prompt for a narrator-type role.

    Args:
        role_name: Narrator label as it appears in the dataset.
        world_name: Title of the novel being narrated.

    Returns:
        A two-line prompt: the act-as instruction followed by a Chinese
        instruction to keep advancing the plot after characters act.
    """
    return f"""I want you to act like narrator {role_name} from {world_name}.
当角色行动之后,继续交代和推进新的剧情."""
|
|
|
# Cache: (novel, role) -> ChatHaruhi instance, filled by initialize_chatbot.
role_tuple2chatbot = {}
|
|
|
|
|
def initialize_chatbot(novel, role):
    """Create and cache the ChatHaruhi instance for ``(novel, role)``.

    Does nothing when the pair is already cached in ``role_tuple2chatbot``.
    The persona is the base prompt plus three ``{{RAG对话}}`` placeholders;
    stories and their vectors come from the rows in ``role2datas``.

    Args:
        novel: Novel title, a key of ``novel2roles``.
        role: Role name within that novel.

    Raises:
        KeyError: If ``(novel, role)`` has no rows in ``role2datas``.
    """
    key = (novel, role)
    if key in role_tuple2chatbot:
        return

    persona = package_persona(role, novel)
    persona += "\n{{RAG对话}}\n{{RAG对话}}\n{{RAG对话}}\n"

    datas = role2datas[key]
    stories = [data["story"] for data in datas]
    vecs = [data["vec"] for data in datas]

    # get_response is resolved at call time, so the backend is whatever the
    # module-level name is bound to when this function runs.
    chatbot = ChatHaruhi(
        role_name=role,
        persona=persona,
        stories=stories,
        story_vecs=vecs,
        llm=get_response,
    )
    chatbot.verbose = False

    role_tuple2chatbot[key] = chatbot
|
|
|
from tqdm import tqdm |
|
# Eagerly build a chatbot for every (novel, role) pair at start-up.
for novel in tqdm(novel_list):
    for role in novel2roles[novel]:
        initialize_chatbot(novel, role)
|
|
|
# Markdown shown in the README tab (user-facing text, kept verbatim).
readme_text = """# 使用说明

选择小说角色

如果你有什么附加信息,添加到附加信息里面就可以

比如"韩立会炫耀自己刚刚学会了Python"

然后就可以开始聊天了

因为这些角色还没有增加Greeting信息,所以之后再开发个随机乱聊功能

# 开发细节

- 采用ChatHaruhi3.0的接口进行prompting
- 这里的数据是用一个7B的tuned qwen模型进行抽取的
- 想看数据可以去看第三个tab
- 抽取模型用了40k左右的GLM蒸馏数据
- 抽取模型是腾讯大哥BPSK训练的

# 总结人物性格

第三个Tab里面,可以显示一个prompt总结人物的性格

复制到openai或者GLM或者Claude进行人物总结


# 这些小说数据从HaruhiZero 0.4模型开始,被加入训练

openai太慢了 今天试试GLM的

不过当前demo是openai的

"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ChatHaruhi.response_openai import get_response, async_get_response |
|
import gradio as gr |
|
|
|
def get_role_list(novel):
    """Return a Gradio update that repopulates the role dropdown.

    Args:
        novel: Currently selected novel title, a key of ``novel2roles``.

    Returns:
        ``gr.update`` setting the choices to that novel's roles and the
        selected value to the first role.
    """
    roles = novel2roles[novel]
    return gr.update(choices=roles, value=roles[0])
|
|
|
# File that submit_chat appends one JSON line to per exchange.
save_log = "/content/output.txt"
|
|
|
def get_chatbot(novel, role):
    """Return the cached chatbot for ``(novel, role)``, creating it on demand.

    Args:
        novel: Novel title.
        role: Role name within that novel.

    Returns:
        The ChatHaruhi instance stored in ``role_tuple2chatbot``.
    """
    key = (novel, role)
    if key not in role_tuple2chatbot:
        initialize_chatbot(novel, role)

    return role_tuple2chatbot[key]
|
|
|
import json |
|
|
|
def random_chat_callback(novel, role, chat_history):
    """Append one not-yet-shown (query, response) pair from the corpus.

    Picks up to five random rows for the role and scans each row's
    conversation in (query, response) pairs, appending the first pair whose
    response is not already in the chat history.

    Args:
        novel: Selected novel title.
        role: Selected role name.
        chat_history: List of ``(query, response)`` pairs from the Chatbot
            widget; extended in place. ``None`` is treated as empty.

    Returns:
        The chat history, with at most one pair appended. Unchanged if no
        unseen response was found in five attempts.
    """
    if chat_history is None:
        chat_history = []

    datas = role2datas[(novel, role)]

    seen_responses = {turn[1] for turn in chat_history if turn[1] is not None}

    for _ in range(5):
        convs = random.choice(datas)["conversations"]

        # Stop at len - 1 so a trailing query with no reply is skipped
        # instead of raising IndexError on convs[i + 1].
        for i in range(0, len(convs) - 1, 2):
            query = convs[i]['value']
            response = convs[i + 1]['value']
            if response not in seen_responses:
                chat_history.append((query, response))
                return chat_history

    return chat_history
|
|
|
|
|
|
|
async def submit_chat(novel, role, user_name, user_text, chat_history, persona_addition_info, model_sel):
    """Send one user message to the selected character and record the reply.

    Args:
        novel: Selected novel title.
        role: Selected role name.
        user_name: Display name of the user; prefixed to the message text.
        user_text: Message typed by the user; truncated to 400 characters.
        chat_history: List of ``(query, response)`` pairs from the Chatbot
            widget; extended in place. ``None`` is treated as empty.
        persona_addition_info: Extra persona text inserted into the prompt.
        model_sel: Backend name: ``"openai"``, ``"Zhipu"`` or ``"spark"``;
            any other value selects the erniebot backend.

    Returns:
        The chat history with the new ``(input_text, response)`` appended.

    NOTE(review): ``chatbot.chat`` is called synchronously inside this
    coroutine, and the chatbot object is shared per (novel, role) and
    reconfigured on every call. Moving the call to a worker thread would
    need locking - confirm before changing.
    """
    if chat_history is None:
        chat_history = []

    # Cap the prompt size; slicing is a no-op for shorter input.
    user_text = user_text[:400]

    # When True, the user name is embedded in the message text itself and
    # the chatbot's own `user` field is left empty.
    if_user_in_text = True

    chatbot = get_chatbot(novel, role)
    chatbot.persona = initialize_persona(novel, role, persona_addition_info)

    llm_by_name = {
        "openai": get_response_openai,
        "Zhipu": get_response_zhipu,
        "spark": get_response_spark,
    }
    chatbot.llm = llm_by_name.get(model_sel, get_response_erniebot)

    # Rebuild the chatbot's history from the widget state each call.
    history = []
    for chat_tuple in chat_history:
        if chat_tuple[0] is not None:
            history.append({"speaker": "{{user}}", "content": chat_tuple[0]})
        if chat_tuple[1] is not None:
            history.append({"speaker": "{{role}}", "content": chat_tuple[1]})
    chatbot.history = history

    if if_user_in_text:
        input_text = user_name + " : " + user_text
        response = chatbot.chat(user="", text=input_text)
    else:
        input_text = user_text
        response = chatbot.chat(user=user_name, text=input_text)

    chat_history.append((input_text, response))

    print_data = {"novel": novel, "role": role, "user_text": input_text, "response": response}
    log_line = json.dumps(print_data, ensure_ascii=False)
    print(log_line)

    # Logging is best-effort: an unwritable log path must not discard a
    # reply that has already been generated.
    try:
        with open(save_log, "a", encoding="utf-8") as f:
            f.write(log_line + "\n")
    except OSError as err:
        print(f"failed to append chat log to {save_log}: {err}")

    return chat_history
|
|
|
|
|
def initialize_persona(novel, role, persona_addition_info):
    """Assemble the full system prompt for a character.

    Args:
        novel: Novel title.
        role: Role name within that novel.
        persona_addition_info: Extra persona text supplied by the user;
            placed on its own line after the base persona.

    Returns:
        Base persona, a newline, the additional info, then three
        ``{{RAG对话}}`` placeholder lines.
    """
    base_persona = package_persona(role, novel)
    rag_slots = "\n{{RAG对话}}\n{{RAG对话}}\n{{RAG对话}}\n"

    return base_persona + "\n" + persona_addition_info + rag_slots
|
|
|
def clean_history():
    """Return a fresh empty list, used to reset the Chatbot widget."""
    return []
|
|
|
def clean_input():
    """Return an empty string, used to clear the user text box."""
    return ""
|
|
|
import random |
|
|
|
def generate_summarize_prompt(novel, role_name):
    """Build a prompt asking an LLM to summarize a character's personality.

    The prompt is a fixed Chinese instruction header followed by up to five
    randomly sampled stories for the role, each followed by a blank line.

    Args:
        novel: Novel title.
        role_name: Role name within that novel.

    Returns:
        The prompt text with leading and trailing whitespace stripped.
    """
    header = f'''
你在分析小说{novel}中的角色{role_name}
结合小说{novel}中的内容,以及下文中角色{role_name}的对话
判断{role_name}的人物设定、人物特点以及语言风格

{role_name}的对话:
'''
    stories = [data["story"] for data in role2datas[(novel, role_name)]]

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the sample size for roles with fewer than five stories.
    sample_n = min(5, len(stories))
    sample_stories = random.sample(stories, sample_n)

    whole_prompt = header + "".join(story + "\n\n" for story in sample_stories)

    return whole_prompt.strip()
|
|
|
|
|
# Gradio UI: a chat tab, a README tab and a persona-summary helper tab.
with gr.Blocks() as demo:
    gr.Markdown("""# 50本小说的人物测试

这个interface由李鲁鲁实现,主要是用来看语料的

增加了随机聊天,支持GLM,openai切换

米唯实接入了qwen1.8B并布置于huggingface上""")

    # Preferred defaults; fall back to the first available entry so the UI
    # still builds if the dataset does not contain them.
    default_novel = "悟空传" if "悟空传" in novel2roles else novel_list[0]
    default_roles = novel2roles[default_novel]
    default_role = "孙悟空" if "孙悟空" in default_roles else default_roles[0]

    with gr.Tab("聊天"):
        with gr.Row():
            novel_sel = gr.Dropdown(novel_list, label="小说", value=default_novel, interactive=True)
            role_sel = gr.Dropdown(default_roles, label="角色", value=default_role, interactive=True)

        with gr.Row():
            chat_history = gr.Chatbot(height=600)

        with gr.Row():
            user_name = gr.Textbox(label="user_name", scale=1, value="鲁鲁", interactive=True)
            user_text = gr.Textbox(label="user_text", scale=20)
            submit = gr.Button("submit", scale=1)

        with gr.Row():
            random_chat = gr.Button("随机聊天", scale=1)
            clean_message = gr.Button("清空聊天", scale=1)

        with gr.Row():
            persona_addition_info = gr.TextArea(label="额外人物设定", value="", interactive=True)

        with gr.Row():
            update_persona = gr.Button("补充人物设定到prompt", scale=1)
            model_sel = gr.Radio(["Zhipu", "openai", "spark", "erniebot"], interactive=True, scale=5, value="Zhipu", label="模型选择")

        with gr.Row():
            whole_persona = gr.TextArea(label="完整的system prompt", value="", interactive=False)

        # Input lists shared by several event bindings.
        persona_inputs = [novel_sel, role_sel, persona_addition_info]
        chat_inputs = [novel_sel, role_sel, user_name, user_text, chat_history, persona_addition_info, model_sel]

        # Changing the novel refreshes the role list, then the prompt preview.
        novel_sel.change(fn=get_role_list, inputs=[novel_sel], outputs=[role_sel]).then(fn=initialize_persona, inputs=persona_inputs, outputs=[whole_persona])

        role_sel.change(fn=initialize_persona, inputs=persona_inputs, outputs=[whole_persona])

        update_persona.click(fn=initialize_persona, inputs=persona_inputs, outputs=[whole_persona])

        random_chat.click(fn=random_chat_callback, inputs=[novel_sel, role_sel, chat_history], outputs=[chat_history])

        # Enter key and submit button do the same thing: chat, then clear the box.
        user_text.submit(fn=submit_chat, inputs=chat_inputs, outputs=[chat_history]).then(fn=clean_input, inputs=[], outputs=[user_text])
        submit.click(fn=submit_chat, inputs=chat_inputs, outputs=[chat_history]).then(fn=clean_input, inputs=[], outputs=[user_text])

        clean_message.click(fn=clean_history, inputs=[], outputs=[chat_history])

    with gr.Tab("README"):
        gr.Markdown(readme_text)

    with gr.Tab("辅助人物总结"):
        with gr.Row():
            generate_prompt = gr.Button("生成人物总结prompt", scale=1)

        with gr.Row():
            whole_prompt = gr.TextArea(label="复制这个prompt到Openai或者GLM或者Claude进行总结", value="", interactive=False)

        generate_prompt.click(fn=generate_summarize_prompt, inputs=[novel_sel, role_sel], outputs=[whole_prompt])

demo.launch(share=True, debug=True)