import json
from typing import Optional

from jinja2 import Template

from .DataBase import ChromaDB
from .Models import GLM, GLM_api
from .utils import *


class ChatWorld:
    """Role-play chat front end backed by a story store and a GLM model.

    Replies are produced either by the remote ``GLM_api`` client or by a
    locally loaded ``GLM`` model. For character chat, a system prompt is
    rendered from a Jinja2 template using role information plus story
    snippets retrieved from ``ChromaDB`` by similarity to the user text.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        """Create the API client, the story database and the prompt template.

        Args:
            pretrained_model_name_or_path: Stored as ``self.model_name``.
                NOTE(review): it is not forwarded to ``GLM()`` here -- confirm
                whether the local model is expected to use this name.
            embedding_model_name_or_path: Passed to ``ChromaDB``.
            global_batch_size: Accepted but currently unused in this class.
            model_load: When true, instantiate the local ``GLM`` model.
                When false, ``self.model`` is ``None`` and only the API
                client can be used.
        """
        self.model_name = pretrained_model_name_or_path

        self.client = GLM_api()
        # Always define the attribute so that a later ``use_local_model=True``
        # call fails with a clear message instead of an AttributeError.
        self.model = GLM() if model_load else None

        self.db = ChromaDB(embedding_model_name_or_path)
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def _require_local_model(self):
        """Return the local model, or raise if it was never loaded.

        Raises:
            RuntimeError: If no local model is attached to this instance
                (for example when constructed with ``model_load=False``).
        """
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError(
                "local model is not loaded; construct ChatWorld with "
                "model_load=True or call with use_local_model=False"
            )
        return model

    def setStory(self, **stories_kargs):
        """Replace the stored stories that match the given metadata.

        Existing stories are first deleted via
        ``db.deleteStoriesByMeta(metas=stories_kargs["metas"])``; then all
        keyword arguments are forwarded unchanged to ``db.addStories``.

        Raises:
            KeyError: If ``metas`` is not among the keyword arguments.
        """
        self.db.deleteStoriesByMeta(metas=stories_kargs["metas"])
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Build the system message for a character chat turn.

        Args:
            text: User text used as the similarity-search query.
            top_k: Passed to ``db.searchBySim`` as the second argument.
            metas: Passed to ``db.searchBySim`` as the third argument.
            **role_info: Template variables (the template reads
                ``model_role_name``, ``model_role_nickname``, ``role_name``
                and ``role_nickname``).

        Returns:
            A dict with ``"role": "system"`` and the rendered prompt as
            ``"content"``.
        """
        rag = self.db.searchBySim(text, top_k, metas)
        content = self.prompt.render(**role_info, RAG=rag)
        return {"role": "system", "content": content}

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: Optional[dict[str, str]] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Send one user turn to the model while it plays a character.

        Args:
            text: What the user says.
            system_prompt: Ready-made system message. When falsy, one is
                built from the story database and ``role_info``.
            use_local_model: Use the local model instead of the API client.
            top_k: Retrieval setting used only when the system prompt is
                built here.
            metas: Retrieval filter used only when the system prompt is
                built here.
            **role_info: Role template variables; ``role_name`` is required
                and is used to prefix the user message.

        Returns:
            Whatever the selected backend returns for the message list.

        Raises:
            ValueError: If ``role_name`` is missing or empty.
            RuntimeError: If ``use_local_model`` is true but no local model
                is loaded.
        """
        # Validate first so a bad call does not pay for a vector search and a
        # template render before failing.
        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, top_k=top_k, metas=metas, **role_info
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]
        logging_info(f"message: {message}")

        if use_local_model:
            return self._require_local_model().get_response(message)
        return self.client.chat(message)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: Optional[dict[str, str]] = None,
        use_local_model: bool = False,
    ):
        """Send a plain user message with no role-play system prompt.

        Args:
            text: The user message.
            system_prompt: Accepted for signature parity with
                ``chatWithCharacter`` but currently not used.
            use_local_model: Use the local model instead of the API client.

        Returns:
            Whatever the selected backend returns.

        Raises:
            RuntimeError: If ``use_local_model`` is true but no local model
                is loaded.
        """
        logging_info(f"text: {text}")

        if use_local_model:
            # NOTE(review): the local model receives the raw text here, while
            # chatWithCharacter passes it a list of message dicts -- confirm
            # which input GLM.get_response expects.
            return self._require_local_model().get_response(text)

        message = [
            {"role": "user", "content": f"{text}"},
        ]
        return self.client.chat(message)

    @staticmethod
    def _extract_json_payload(response: str) -> str:
        """Return the text inside the first ```json fence of *response*.

        Falls back to the whole (stripped) response when there is no opening
        fence, and to the end of the response when the fence is never closed.
        """
        opening = "```json"
        start = response.find(opening)
        if start == -1:
            return response.strip()

        start += len(opening)
        end = response.find("```", start)
        if end == -1:
            end = len(response)
        return response[start:end].strip()

    @staticmethod
    def _parse_role_entries(payload: str) -> dict[str, str]:
        """Parse a JSON list of ``{"name", "nickname"}`` into name -> nickname.

        Entries that are not objects or lack a non-empty string ``name`` are
        skipped; a missing or non-string ``nickname`` becomes ``""``.
        """
        try:
            entries = json.loads(payload)
        except ValueError as err:  # json.JSONDecodeError subclasses ValueError
            logging_info(f"failed to parse role JSON: {err}")
            return {}

        if isinstance(entries, dict):
            # Tolerate a single object where a list was requested.
            entries = [entries]
        if not isinstance(entries, list):
            return {}

        roles: dict[str, str] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name")
            if not isinstance(name, str) or not name:
                continue
            nickname = entry.get("nickname")
            roles[name] = nickname if isinstance(nickname, str) else ""
        return roles

    def getRoleNameFromFile(self, input_file: str) -> tuple[list[str], dict[str, str]]:
        """Ask the API model to extract the people mentioned in a text.

        Args:
            input_file: The text to analyse (file *content*, not a path --
                it is embedded directly in the prompt).

        Returns:
            ``(role_name_list, role_name_dict)`` where the list holds each
            distinct name in first-seen order and the dict maps name to
            nickname (``""`` when there is none). Both are empty when the
            model response cannot be parsed.
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )
        response = self.chatWithoutCharacter(prompt, use_local_model=False)
        if not isinstance(response, str):
            logging_info(f"unexpected role extraction response: {response!r}")
            return [], {}

        payload = self._extract_json_payload(response)
        logging_info(f"role json: {payload}")

        role_name_dict = self._parse_role_entries(payload)
        # Derive the list from the dict so both views agree and names are
        # unique, as the prompt requests.
        role_name_list = list(role_name_dict)
        return role_name_list, role_name_dict