from .ChromaDB import ChromaDB
import os

from .utils import luotuo_openai_embedding, tiktokenizer
from .utils import response_postprocess


class ChatHaruhi:
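    """Retrieval-augmented role-playing chatbot.

    A character is defined by a system prompt plus a vector database of
    classic scenes. On each chat turn, the scenes most similar to the
    current query are retrieved and packed into the context, followed by
    as much recent dialogue history as the token budget allows.
    """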
    def __init__(self, system_prompt=None,
                 role_name=None,
                 story_db=None, story_text_folder=None,
                 llm='openai',
                 embedding='luotuo_openai',
                 max_len_story=None, max_len_history=None,
                 verbose=False):
        super(ChatHaruhi, self).__init__()

        self.verbose = verbose

        # constants
        self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
        self.k_search = 19
        self.narrator = ['旁白', '', 'scene', 'Scene', 'narrator', 'Narrator']
        self.dialogue_divide_token = '\n###\n'
        self.dialogue_bra_token = '「'
        self.dialogue_ket_token = '」'

        # system_prompt is expected to be set either directly or via
        # role_name below; default to None so the attribute always exists
        self.system_prompt = None
        if system_prompt:
            self.system_prompt = self.check_system_prompt(system_prompt)
        # TODO: embedding should be defined separately, so refactor this part later
        if llm == 'openai':
            # self.llm = LangChainGPT()
            self.llm, self.tokenizer = self.get_models('openai')
        elif llm == 'debug':
            self.llm, self.tokenizer = self.get_models('debug')
        elif llm == 'spark':
            self.llm, self.tokenizer = self.get_models('spark')
        elif llm == 'GLMPro':
            self.llm, self.tokenizer = self.get_models('GLMPro')
        elif llm == 'ChatGLM2GPT':
            self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
            # the default English story prefix is dropped for ChatGLM2GPT
            self.story_prefix_prompt = '\n'
        else:
            print(f'warning! undefined llm {llm}, using openai instead.')
            self.llm, self.tokenizer = self.get_models('openai')
        if embedding == 'luotuo_openai':
            self.embedding = luotuo_openai_embedding
        else:
            print(f'warning! undefined embedding {embedding}, using luotuo_openai instead.')
            self.embedding = luotuo_openai_embedding
        if role_name:
            from .role_name_to_file import get_folder_role_name

            # map the user-facing role_name to its folder name and download url
            role_name, url = get_folder_role_name(role_name)

            unzip_folder = f'./temp_character_folder/temp_{role_name}'
            db_folder = os.path.join(unzip_folder, f'content/{role_name}')
            system_prompt = os.path.join(unzip_folder, 'content/system_prompt.txt')

            if not os.path.exists(unzip_folder):
                # character archive not yet downloaded, fetch and unzip it
                # url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
                import requests, zipfile, io
                r = requests.get(url)
                z = zipfile.ZipFile(io.BytesIO(r.content))
                z.extractall(unzip_folder)

            if self.verbose:
                print(f'loading pre-defined character {role_name}...')

            self.db = ChromaDB()
            self.db.load(db_folder)
            self.system_prompt = self.check_system_prompt(system_prompt)
        elif story_db:
            self.db = ChromaDB()
            self.db.load(story_db)
        elif story_text_folder:
            # print("Building story database from texts...")
            self.db = self.build_story_db(story_text_folder)
        else:
            self.db = None
            print('warning! no story database: neither story_db nor story_text_folder was given.')
            # raise ValueError("Either story_db or story_text_folder must be provided")
        self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')

        # user settings override the model defaults
        if max_len_history is not None:
            self.max_len_history = max_len_history
        if max_len_story is not None:
            self.max_len_story = max_len_story

        self.dialogue_history = []
    def check_system_prompt(self, system_prompt):
        # if system_prompt ends with .txt, read that file with utf-8 encoding;
        # otherwise return the string itself as the prompt
        if system_prompt.endswith('.txt'):
            with open(system_prompt, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            return system_prompt
    def get_models(self, model_name):
        # TODO: when only the tokenizer is needed, skip initializing the llm
        # returns the (llm, tokenizer) pair for the given model name
        if model_name == 'openai':
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        elif model_name == 'debug':
            from .PrintLLM import PrintLLM
            return (PrintLLM(), tiktokenizer)
        elif model_name == 'spark':
            from .SparkGPT import SparkGPT
            return (SparkGPT(), tiktokenizer)
        elif model_name == 'GLMPro':
            from .GLMPro import GLMPro
            return (GLMPro(), tiktokenizer)
        elif model_name == "ChatGLM2GPT":
            from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
            return (ChatGLM2GPT(), GLM_tokenizer)
        else:
            print(f'warning! undefined model {model_name}, using openai instead.')
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
    def get_tokenlen_setting(self, model_name):
        # returns the default (max_len_story, max_len_history) token budgets
        if model_name == 'openai':
            return (1500, 1200)
        else:
            print(f'warning! undefined model {model_name}, using openai instead.')
            return (1500, 1200)
    def build_story_db_from_vec(self, texts, vecs):
        self.db = ChromaDB()
        self.db.init_from_docs(vecs, texts)

    def build_story_db(self, text_folder):
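        """Read every .txt file in text_folder, embed each file's content,
        and return a ChromaDB built from the resulting (vector, text) pairs."""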
        # read the text folder and extract one embedding vector per file
        db = ChromaDB()

        strs = []
        # scan all txt files in text_folder
        for file in os.listdir(text_folder):
            if file.endswith(".txt"):
                file_path = os.path.join(text_folder, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    strs.append(f.read())

        if self.verbose:
            print(f'extracting embeddings for {len(strs)} files...')

        vecs = []

        ## TODO: add a unit test for batched embedding
        ## new embedding code that supports batched (list) input
        ## should replace the for loop below
        ## Luotuo-bert-en has also been released, so openai can be avoided
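        ## A hedged sketch of that batched path, assuming a hypothetical
        ## list-in/list-out variant of the embedding function (not a real
        ## API of luotuo_openai_embedding, which takes one string at a time):
        ##     vecs = luotuo_openai_embedding_batch(strs)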
        for mystr in strs:
            vecs.append(self.embedding(mystr))

        db.init_from_docs(vecs, strs)

        return db
    def save_story_db(self, db_path):
        self.db.save(db_path)

    def chat(self, text, role):
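        """Send one in-character utterance and return the postprocessed reply.

        role is the speaker's name and text is what they say. The context
        sent to the llm is: system prompt, retrieved story scenes, recent
        dialogue history, then the current query.
        """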
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)

        # add story
        query = self.get_query_string(text, role)
        self.add_story(query)

        # add history
        self.add_history()

        # add query
        self.llm.user_message(query)

        # get response
        response_raw = self.llm.get_response()
        response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)

        # record dialogue history
        self.dialogue_history.append((query, response))

        return response
    def get_query_string(self, text, role):
        # narrator lines are passed through verbatim; character lines are
        # wrapped in dialogue brackets, e.g. role:「text」
        if role in self.narrator:
            return role + ":" + text
        else:
            return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"

    def add_story(self, query):
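        """Retrieve the scenes most similar to query and pack them into one
        user message, separated by dialogue_divide_token, stopping once the
        max_len_story token budget would be exceeded."""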
        if self.db is None:
            return

        query_vec = self.embedding(query)

        stories = self.db.search(query_vec, self.k_search)

        story_string = self.story_prefix_prompt
        sum_story_token = self.tokenizer(story_string)

        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            else:
                sum_story_token += story_token
                story_string += story + self.dialogue_divide_token

        self.llm.user_message(story_string)
    def add_history(self):
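        """Append as much recent dialogue history as fits in max_len_history.

        Walks the history backwards counting tokens, then replays the
        surviving turns in chronological order as user/ai messages.
        """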
        if len(self.dialogue_history) == 0:
            return

        sum_history_token = 0
        flag = 0
        for query, response in reversed(self.dialogue_history):
            current_count = 0
            if query is not None:
                current_count += self.tokenizer(query)
            if response is not None:
                current_count += self.tokenizer(response)
            sum_history_token += current_count
            if sum_history_token > self.max_len_history:
                break
            else:
                flag += 1

        if flag == 0:
            print('warning! no history added. the last dialogue is too long.')
            # [-0:] would replay the entire history, so return early instead
            return

        for (query, response) in self.dialogue_history[-flag:]:
            if query is not None:
                self.llm.user_message(query)
            if response is not None:
                self.llm.ai_message(response)
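

# A minimal usage sketch, not part of the library itself. It assumes an
# OPENAI_API_KEY is set in the environment and that 'haruhi' is one of the
# pre-defined role names resolvable by role_name_to_file.
if __name__ == '__main__':
    chatbot = ChatHaruhi(role_name='haruhi', llm='openai', verbose=True)
    # a character line is wrapped in 「」 brackets automatically
    reply = chatbot.chat(role='阿虚', text='Haruhi, do you know the SOS Brigade?')
    print(reply)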