Spaces:
Runtime error
Runtime error
from tqdm import tqdm | |
from util import float_array_to_base64, base64_to_float_array | |
from util import get_bge_embedding_zh | |
import json | |
import torch | |
# Run tensor math on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_cosine_similarity(v1, v2):
    """Return the cosine similarity between two vectors as a Python float.

    ``v1`` and ``v2`` may be anything ``torch.as_tensor`` accepts (lists,
    NumPy arrays, tensors). They are placed on ``device`` and compared along
    dimension 0; because the result is read with ``.item()``, they must be
    1-D vectors of equal length. Integer inputs and mixed float precisions
    are promoted to a common floating-point dtype before comparison.
    """
    # as_tensor avoids the redundant copy (and UserWarning) that
    # torch.tensor() produces when given an existing tensor.
    t1 = torch.as_tensor(v1, device=device)
    t2 = torch.as_tensor(v2, device=device)
    # Promote to one floating-point dtype so integer vectors (whose norm is
    # undefined for integer tensors) and float32/float64 mixes both work.
    common = torch.promote_types(t1.dtype, t2.dtype)
    if not common.is_floating_point:
        common = torch.get_default_dtype()
    return torch.cosine_similarity(t1.to(common), t2.to(common), dim=0).item()
class MemoryPool:
    """Store of named memories with embedding-based retrieval.

    Each memory is a dict with the keys ``name``, ``text``, ``embedding``,
    ``memory_attribute`` and ``emoji``, stored in ``self.memories`` under its
    ``name``. Memories can be built from events, saved to / loaded from a
    UTF-8 JSON Lines file, edited, and retrieved by cosine similarity to a
    query text.
    """

    def __init__(self):
        # name -> memory dict
        self.memories = {}
        # A memory is eligible for retrieval only when the agent's value for
        # the memory's attribute is within this distance of the stored value.
        self.diff_threshold = 20
        # Maximum number of memories returned by ``retrieve``.
        self.top_k = 7
        self.set_embedding(get_bge_embedding_zh)

    def set_embedding(self, embedding):
        """Set the callable used to turn a text into an embedding vector."""
        self.embedding = embedding

    def load_from_events(self, events):
        """Build one memory per event and store it under the event's name.

        If the event has options, text and emoji come from
        ``event.most_neutral_output()``; otherwise the event's ``prefix`` and
        ``prefix_emoji`` are used. An existing memory with the same name is
        overwritten.
        """
        for event in tqdm(events):
            if len(event["options"]) > 0:
                text, emoji = event.most_neutral_output()
            else:
                text = event["prefix"]
                emoji = event["prefix_emoji"]
            embedding = self.embedding(text)
            condition = event["condition"]
            if condition is None:
                # Unconditioned events are anchored at ("Stress", 10).
                memory_attribute = ("Stress", 10)
            else:
                # NOTE(review): condition looks like (attribute, low, high);
                # the memory is anchored at the integer midpoint — confirm.
                memory_attribute = (condition[0], (condition[1] + condition[2]) // 2)
            name = event["name"]
            self.memories[name] = {
                "name": name,
                "text": text,
                "embedding": embedding,
                "memory_attribute": memory_attribute,
                "emoji": emoji,  # TODO
            }

    def save(self, file_name):
        """Write every memory to ``file_name`` as UTF-8 JSON Lines.

        The ``embedding`` field is written as a base64 string under the key
        ``bge_zh_base64`` (via ``float_array_to_base64``); non-ASCII text is
        kept as-is (``ensure_ascii=False``). The in-memory records are NOT
        modified, so the pool stays usable for ``retrieve`` after saving.
        """
        with open(file_name, 'w', encoding='utf-8') as file:
            for memory in tqdm(self.memories.values()):
                # Serialize a shallow copy: mutating ``memory`` itself would
                # strip the embedding from the live pool.
                record = dict(memory)
                if 'embedding' in record:
                    embedding = record.pop('embedding')
                    record['bge_zh_base64'] = float_array_to_base64(embedding)
                file.write(json.dumps(record, ensure_ascii=False) + '\n')

    def load(self, file_name):
        """Read memories from a UTF-8 JSON Lines file written by ``save``.

        ``bge_zh_base64`` is decoded back into ``embedding`` with
        ``base64_to_float_array``. Blank lines are skipped. Loaded memories
        are merged into ``self.memories`` by ``name``, replacing any existing
        entry with the same name.
        """
        with open(file_name, 'r', encoding='utf-8') as file:
            for line in tqdm(file):
                line = line.strip()
                if not line:
                    continue
                memory = json.loads(line)
                if 'bge_zh_base64' in memory:
                    memory['embedding'] = base64_to_float_array(memory.pop('bge_zh_base64'))
                # JSON has no tuple type; restore the tuple form that
                # load_from_events produces.
                if isinstance(memory.get('memory_attribute'), list):
                    memory['memory_attribute'] = tuple(memory['memory_attribute'])
                self.memories[memory['name']] = memory

    def change_memory(self, memory_name, new_text, new_emoji=None):
        """Replace a memory's text (re-embedding it) and optionally its emoji.

        Does nothing if ``memory_name`` is not in the pool. The emoji is only
        replaced when ``new_emoji`` is truthy.
        """
        if memory_name in self.memories:
            memory = self.memories[memory_name]
            memory["text"] = new_text
            memory["embedding"] = self.embedding(new_text)
            if new_emoji:
                memory["emoji"] = new_emoji

    def retrieve(self, agent, query_text):
        """Return up to ``top_k`` memories most similar to ``query_text``.

        A memory is a candidate only if ``agent[attribute]`` is within
        ``diff_threshold`` of the memory's stored attribute value (``agent``
        is indexed by attribute name, e.g. ``"Stress"``). Candidates are
        ranked by cosine similarity between their embedding and the query's
        embedding, highest first; equally similar memories keep pool order.
        A non-positive ``top_k`` yields an empty list.
        """
        query_embedding = self.embedding(query_text)
        scored = []
        for key, memory in self.memories.items():
            attribute, value = memory["memory_attribute"]
            if abs(agent[attribute] - value) <= self.diff_threshold:
                similarity = get_cosine_similarity(query_embedding, memory["embedding"])
                scored.append((similarity, key))
        # list.sort is stable even with reverse=True, so ties keep pool order.
        scored.sort(key=lambda item: item[0], reverse=True)
        limit = max(self.top_k, 0)
        return [self.memories[key] for _, key in scored[:limit]]