# NOTE: scrape artifact removed — the original paste began with a Hugging Face
# Spaces page header ("Runtime error", file size 18,176 bytes, commit aef3deb,
# and a run of gutter line numbers) that is not part of the module source and
# made the file unparseable as Python.
from .utils import base64_to_float_array, base64_to_string
def get_text_from_data( data ):
    """Return the plain text carried by one dataset record.

    A record stores its text either directly under ``"text"`` or
    base64-encoded under ``"enc_text"``.  When neither key is present,
    a warning is printed and the empty string is returned.
    """
    if "text" in data:
        return data['text']
    if "enc_text" in data:
        # Encoded records must be decoded before use.
        return base64_to_string( data['enc_text'] )
    print("warning! failed to get text from data ", data)
    return ""
def parse_rag(text):
    """Scan *text* line by line for RAG placeholder tags and collect
    retrieval requests.

    Each request is a dict with:
      - "n":         number of stories to retrieve
      - "max_token": token budget for the retrieved stories (-1 = unlimited)
      - "query":     retrieval query ("default" means "substitute the user input")
      - "lid":       index of the line holding the tag (later replaced in place)
    """
    requests = []
    for lid, line in enumerate(text.split("\n")):
        if "{{RAG对话}}" in line:
            # Bare tag: one story, unlimited tokens, query filled in later.
            requests.append(
                {"n": 1, "max_token": -1, "query": "default", "lid": lid})
        elif "{{RAG对话|" in line:
            # Tag with an explicit query: {{RAG对话|<query>}}
            # (rstrip removes the trailing '}' characters, not a literal suffix)
            query = line.split("|")[1].rstrip("}}")
            requests.append(
                {"n": 1, "max_token": -1, "query": query, "lid": lid})
        elif "{{RAG多对话|" in line:
            # Multi-story tag: {{RAG多对话|token<=<budget>|n<=<count>}}
            fields = line.split("|")
            token_budget = int(fields[1].split("<=")[1])
            story_count = int(fields[2].split("<=")[1].rstrip("}}"))
            requests.append(
                {"n": story_count, "max_token": token_budget,
                 "query": "default", "lid": lid})
    return requests
class ChatHaruhi:
    """Retrieval-augmented role-playing chat bot.

    A persona text may contain ``{{RAG对话}}`` / ``{{RAG多对话|...}}``
    placeholder tags; at message-build time these are replaced with stories
    retrieved from a vector DB, and the augmented persona becomes the system
    prompt.  Roles can be supplied directly (persona + stories [+ vectors]),
    loaded from Hugging Face / a jsonl file, or picked from built-in
    ("sugar") presets.
    """
    def __init__(self,
                 role_name = None,
                 user_name = None,
                 persona = None,
                 stories = None,
                 story_vecs = None,
                 role_from_hf = None,
                 role_from_jsonl = None,
                 llm = None,            # default message -> response function
                 llm_async = None,      # default async message -> response function
                 user_name_in_message = "default",
                 verbose = None,
                 embed_name = None,
                 embedding = None,
                 db = None,
                 token_counter = "default",
                 max_input_token = 1800,
                 max_len_story_haruhi = 1000,
                 max_story_n_haruhi = 5
                 ):
        # verbose defaults to True when not given explicitly.
        self.verbose = True if verbose is None or verbose else False
        self.db = db
        self.embed_name = embed_name
        # These two limits only affect the legacy Haruhi "sugar" roles.
        self.max_len_story_haruhi = max_len_story_haruhi
        self.max_story_n_haruhi = max_story_n_haruhi
        self.last_query_msg = None
        # NOTE(review): when an `embedding` callable IS passed, it is ignored
        # and self.embedding is never assigned here — later calls to
        # self.embedding would raise AttributeError. Confirm intent.
        if embedding is None:
            self.embedding = self.set_embedding_with_name( embed_name )
        if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
            # Fully externally supplied; story_vecs must align 1:1 with stories.
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif persona and role_name and stories:
            # Stories given without vectors: re-embed them with self.embedding.
            story_vecs = self.extract_story_vecs(stories)
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif role_from_hf:
            # Load the role from a Hugging Face dataset.
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
            if new_role_name:
                self.role_name = new_role_name
            else:
                self.role_name = role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif role_from_jsonl:
            # Load the role from a local jsonl file.
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
            if new_role_name:
                self.role_name = new_role_name
            else:
                self.role_name = role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif persona and role_name:
            # Persona only — no RAG at all, so no DB is built.
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.db = None
        elif role_name and self.check_sugar( role_name ):
            # Built-in ("sugar") preset role.
            self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar( role_name )
            self.build_db(self.stories, self.story_vecs)
            # Per discussion: every loading path must call
            # add_rag_prompt_after_persona() externally, to avoid confusion.
            # self.add_rag_prompt_after_persona()
        else:
            raise ValueError("persona和role_name必须同时设置,或者role_name是ChatHaruhi的预设人物")
        self.llm, self.llm_async = llm, llm_async
        if not self.llm and self.verbose:
            print("warning, llm没有设置,仅get_message起作用,调用chat将回复idle message")
        self.user_name_in_message = user_name_in_message
        self.previous_user_pool = set([user_name]) if user_name else set()
        self.current_user_name_in_message = user_name_in_message.lower() == "add"
        # NOTE(review): "idel" is a typo for "idle" in this runtime string.
        self.idle_message = "idel message, you see this because self.llm has not been set."
        # NOTE(review): if token_counter is None, .lower() raises
        # AttributeError before the `== None` branch below can be reached —
        # the None check should come first. Confirm and fix upstream.
        if token_counter.lower() == "default":
            # TODO change load from util
            from .utils import tiktoken_counter
            self.token_counter = tiktoken_counter
        elif token_counter == None:
            self.token_counter = lambda x: 0
        else:
            self.token_counter = token_counter
            if self.verbose:
                print("user set costomized token_counter")
        self.max_input_token = max_input_token
        self.history = []

    def check_sugar(self, role_name):
        """Return True if *role_name* is one of the built-in preset roles."""
        from .sugar_map import sugar_role_names, enname2zhname
        return role_name in sugar_role_names

    def load_role_from_sugar(self, role_name):
        """Resolve a preset role name and load it from the shared HF dataset.

        Returns (persona, display_role_name, stories, story_vecs).
        """
        from .sugar_map import sugar_role_names, enname2zhname
        en_role_name = sugar_role_names[role_name]
        new_role_name = enname2zhname[en_role_name]
        role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name
        persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf)
        return persona, new_role_name, stories, story_vecs

    def add_rag_prompt_after_persona( self ):
        """Append a {{RAG多对话|...}} tag to the persona using the
        instance's story token/count limits (legacy Haruhi behavior)."""
        rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}"
        self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n"

    def set_embedding_with_name(self, embed_name):
        """Return the embedding callable for *embed_name*.

        None defaults to "bge_zh" (and also normalizes self.embed_name).
        NOTE(review): an unrecognized name falls through every branch and
        implicitly returns None — confirm whether that should raise instead.
        """
        if embed_name is None or embed_name == "bge_zh":
            from .embeddings import get_bge_zh_embedding
            self.embed_name = "bge_zh"
            return get_bge_zh_embedding
        elif embed_name == "foo":
            from .embeddings import foo_embedding
            return foo_embedding
        elif embed_name == "bce":
            from .embeddings import foo_bce
            return foo_bce
        elif embed_name == "openai" or embed_name == "luotuo_openai":
            from .embeddings import foo_openai
            return foo_openai

    def set_new_user(self, user):
        """Register *user* as the current speaker.

        When a second distinct user appears and user_name_in_message is
        "default", speaker names start being included in the conversation.
        """
        if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool:
            if self.user_name_in_message.lower() == "default":
                if self.verbose:
                    print(f'new user {user} included in conversation')
                self.current_user_name_in_message = True
        self.user_name = user
        self.previous_user_pool.add(user)

    def chat(self, user, text):
        """Run one chat turn: build the message list, call self.llm, record
        the response in history. Returns None when no llm is set."""
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm:
            response = self.llm(message)
            self.append_message(response)
            return response
        return None

    async def async_chat(self, user, text):
        """Async variant of chat() using self.llm_async.
        NOTE(review): implicitly returns None when llm_async is unset."""
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm_async:
            response = await self.llm_async(message)
            self.append_message(response)
            return response

    def parse_rag_from_persona(self, persona, text = None):
        """Parse RAG tags out of *persona* and fill "default" queries with
        *text*. Returns (query_rags, persona_token_count).

        Each query_rag contains:
          "n"         — number of stories wanted
          "max_token" — token budget (-1 = unlimited)
          "query"     — retrieval query ("default" replaced by *text*)
          "lid"       — line index to replace (the whole line is overwritten)
        """
        query_rags = parse_rag( persona )
        if text is not None:
            for rag in query_rags:
                if rag['query'] == "default":
                    rag['query'] = text
        return query_rags, self.token_counter(persona)

    def append_message( self, response , speaker = None ):
        """Append *response* to history; first flush the pending user query
        saved by get_message(). With no speaker, the utterance is attributed
        to the bot placeholder "{{role}}"."""
        if self.last_query_msg is not None:
            self.history.append(self.last_query_msg)
            self.last_query_msg = None
        if speaker is None:
            # No speaker means the sentence was produced by the role itself.
            self.history.append({"speaker":"{{role}}","content":response})
            # The key is "speaker" (not "role") to avoid clashing with the
            # OpenAI-style "role" field used in messages.
        else:
            self.history.append({"speaker":speaker,"content":response})

    def check_recompute_stories_token(self):
        # NOTE(review): this triggers recompute when the lengths MATCH, so
        # token counts are recomputed on every retrieval once metas is
        # populated — the condition looks inverted (should likely be `!=`).
        # Verify against NaiveDB.metas initialization before changing.
        return len(self.db.metas) == len(self.db.stories)

    def recompute_stories_token(self):
        """Cache each story's token count into db.metas."""
        self.db.metas = [self.token_counter(story) for story in self.db.stories]

    def rag_retrieve( self, query, n, max_token, avoid_ids = [] ):
        # Returns a list of story ids for one query, respecting max_token.
        # NOTE(review): mutable default `avoid_ids=[]` — harmless here (it is
        # never mutated), but a None sentinel would be safer.
        query_vec = self.embedding(query)
        self.db.clean_flag()
        self.db.disable_story_with_ids( avoid_ids )
        retrieved_ids = self.db.search( query_vec, n )
        if self.check_recompute_stories_token():
            self.recompute_stories_token()
        sum_token = 0
        ans = []
        for i in range(0, len(retrieved_ids)):
            if i == 0:
                # The first (best) story is always kept, even if it alone
                # exceeds max_token.
                sum_token += self.db.metas[retrieved_ids[i]]
                ans.append(retrieved_ids[i])
                continue
            else:
                sum_token += self.db.metas[retrieved_ids[i]]
                if sum_token <= max_token:
                    ans.append(retrieved_ids[i])
                else:
                    break
        return ans

    def rag_retrieve_all( self, query_rags, rest_limit ):
        # Runs every query in query_rags, avoiding duplicate stories across
        # queries; returns a list of id-lists (one per query).
        retrieved_ids = []
        rag_ids = []
        for query_rag in query_rags:
            query = query_rag['query']
            n = query_rag['n']
            # Per-query budget: the smaller of rest_limit and the query's own
            # max_token (when positive).
            max_token = rest_limit
            if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
                max_token = query_rag['max_token']
            rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
            rag_ids.append( rag_id )
            retrieved_ids += rag_id
        return rag_ids

    def append_history_under_limit(self, message, rest_limit):
        """Append as much recent history to *message* as fits in rest_limit
        tokens (counted newest-first, emitted oldest-first).

        A "{{role}}" speaker becomes an "assistant" message; everyone else
        becomes a "user" message. Returns the extended message list.
        """
        current_limit = rest_limit
        history_list = []
        for item in reversed(self.history):
            current_token = self.token_counter(item['content'])
            current_limit -= current_token
            if current_limit < 0:
                break
            else:
                history_list.append(item)
        history_list = list(reversed(history_list))
        # TODO: for multi-user chat, content will additionally carry
        # "speaker: content" information here.
        for item in history_list:
            if item['speaker'] == "{{role}}":
                message.append({"role":"assistant","content":item['content']})
            else:
                message.append({"role":"user","content":item['content']})
        return message

    def get_message(self, user, text):
        """Build the full OpenAI-style message list for one user query:
        system prompt (RAG-augmented persona) + history + the query itself.
        Side effect: stores the query in self.last_query_msg so that
        append_message() can commit it to history after the reply arrives.
        """
        query_token = self.token_counter(text)
        # First work out how many RAG stories are requested by the persona.
        query_rags, persona_token = self.parse_rag_from_persona( self.persona, text )
        # Each query_rag contains:
        #   "n"         — number of stories wanted
        #   "max_token" — token budget (-1 = unlimited)
        #   "query"     — retrieval query ("default" already replaced by text)
        #   "lid"       — line index to replace in the persona
        rest_limit = self.max_input_token - persona_token - query_token
        if self.verbose:
            print(f"query_rags: {query_rags} rest_limit = { rest_limit }")
        rag_ids = self.rag_retrieve_all( query_rags, rest_limit )
        # Substitute the retrieved stories into the persona.
        augmented_persona = self.augment_persona( self.persona, rag_ids, query_rags )
        system_prompt = self.package_system_prompt( self.role_name, augmented_persona )
        token_for_system = self.token_counter( system_prompt )
        rest_limit = self.max_input_token - token_for_system - query_token
        message = [{"role":"system","content":system_prompt}]
        message = self.append_history_under_limit( message, rest_limit )
        # TODO: for multi-user chat, content will additionally carry
        # "speaker: content" information here.
        message.append({"role":"user","content":text})
        self.last_query_msg = {"speaker":user,"content":text}
        return message

    def package_system_prompt(self, role_name, augmented_persona):
        """Wrap the augmented persona in the fixed role-play system prompt."""
        bot_name = role_name
        return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:
{augmented_persona}
You will stay in-character whenever possible, and generate responses as if you were {bot_name}"""

    def augment_persona(self, persona, rag_ids, query_rags):
        """Replace each tagged line of *persona* (by "lid") with the stories
        retrieved for it, '###'-separated. The rest of the line is discarded."""
        lines = persona.split("\n")
        for rag_id, query_rag in zip(rag_ids, query_rags):
            lid = query_rag['lid']
            new_text = ""
            for id in rag_id:
                new_text += "###\n" + self.db.stories[id].strip() + "\n"
            new_text = new_text.strip()
            lines[lid] = new_text
        return "\n".join(lines)

    def load_role_from_jsonl( self, role_from_jsonl ):
        """Load a role from a local jsonl file.
        Returns (persona, None, stories, story_vecs).
        NOTE(review): the bare `except` silently skips malformed lines."""
        import json
        datas = []
        with open(role_from_jsonl, 'r') as f:
            for line in f:
                try:
                    datas.append(json.loads(line))
                except:
                    continue
        column_name = ""
        from .embeddings import embedname2columnname
        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unkown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'
        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)
        return persona, None, stories, story_vecs

    def load_role_from_hf(self, role_from_hf):
        """Load a role from Hugging Face.

        "org/dataset" loads the default train split; "org/dataset/split"
        loads "<split>.jsonl" from that dataset.
        Returns (persona, None, stories, story_vecs).
        NOTE(review): a name with no "/" leaves `datas` unbound and raises
        NameError below — confirm callers always pass at least one "/".
        """
        # self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
        from datasets import load_dataset
        if role_from_hf.count("/") == 1:
            dataset = load_dataset(role_from_hf)
            datas = dataset["train"]
        elif role_from_hf.count("/") >= 2:
            split_index = role_from_hf.index('/')
            second_split_index = role_from_hf.index('/', split_index+1)
            dataset_name = role_from_hf[:second_split_index]
            split_name = role_from_hf[second_split_index+1:]
            fname = split_name + '.jsonl'
            dataset = load_dataset(dataset_name,data_files={'train':fname})
            datas = dataset["train"]
        column_name = ""
        from .embeddings import embedname2columnname
        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unkown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'
        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)
        return persona, None, stories, story_vecs

    def extract_text_vec_from_datas(self, datas, column_name):
        """Split dataset records into (texts, vecs, system_prompt).

        A record whose *column_name* field equals 'system_prompt' provides
        the persona; 'config' records are skipped; any other record's field
        is a base64-encoded embedding vector.
        NOTE(review): `system_prompt` is unbound (NameError on return) if no
        'system_prompt' record exists — confirm datasets always include one.
        """
        texts = []
        vecs = []
        for data in datas:
            if data[column_name] == 'system_prompt':
                system_prompt = get_text_from_data( data )
            elif data[column_name] == 'config':
                pass
            else:
                vec = base64_to_float_array( data[column_name] )
                text = get_text_from_data( data )
                vecs.append( vec )
                texts.append( text )
        return texts, vecs, system_prompt

    def extract_story_vecs(self, stories):
        """Embed every story, batched on GPU when the embedding model
        supports it, otherwise one by one via self.embedding."""
        if self.verbose:
            print(f"re-extract vector for {len(stories)} stories")
        story_vecs = []
        from .embeddings import embedshortname2model_name
        from .embeddings import device
        if device.type != "cpu" and self.embed_name in embedshortname2model_name:
            # model_name = "BAAI/bge-small-zh-v1.5"
            model_name = embedshortname2model_name[self.embed_name]
            from .utils import get_general_embeddings_safe
            story_vecs = get_general_embeddings_safe( stories, model_name = model_name )
            # Batched embedding — much faster than per-story calls.
        else:
            from tqdm import tqdm
            for story in tqdm(stories):
                story_vecs.append(self.embedding(story))
        return story_vecs

    def build_db(self, stories, story_vecs):
        """(Re)build the vector DB; a NaiveDB is created lazily if none was
        injected through the constructor."""
        if self.db is None:
            from .NaiveDB import NaiveDB
            self.db = NaiveDB()
        self.db.build_db(stories, story_vecs)
|