Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Modified from https://github.com/geekan/MetaGPT/blob/main/metagpt/memory/memory.py | |
| from collections import defaultdict | |
| from typing import Iterable, Type | |
| from autoagents.actions import Action | |
| from autoagents.system.schema import Message | |
class Memory:
    """The most basic memory: an in-process list of messages plus an action index.

    ``storage`` keeps stored messages in insertion order; a message equal
    (``==``) to one already stored is not added again. ``index`` groups the
    same messages by their ``cause_by`` value so they can be looked up by the
    Action that produced them.
    """

    def __init__(self):
        """Initialize an empty storage list and an empty index dictionary."""
        self.storage: list[Message] = []
        # Maps an Action class to the stored messages whose ``cause_by`` is it.
        self.index: dict[Type[Action], list[Message]] = defaultdict(list)

    def add(self, message: Message):
        """Add ``message`` to storage unless an equal message is already stored.

        Messages with a truthy ``cause_by`` are also appended to ``index``.
        """
        if message in self.storage:
            return
        self.storage.append(message)
        if message.cause_by:
            self.index[message.cause_by].append(message)

    def add_batch(self, messages: Iterable[Message]):
        """Add each message of ``messages`` in order via :meth:`add`."""
        for message in messages:
            self.add(message)

    def get_by_role(self, role: str) -> list[Message]:
        """Return all stored messages whose ``role`` equals ``role``."""
        return [message for message in self.storage if message.role == role]

    def get_by_content(self, content: str) -> list[Message]:
        """Return all stored messages whose ``content`` contains ``content``."""
        return [message for message in self.storage if content in message.content]

    def delete(self, message: Message):
        """Remove ``message`` from storage and from the index.

        Raises:
            ValueError: if ``message`` is not in storage (raised by ``list.remove``).
        """
        self.storage.remove(message)
        if not message.cause_by:
            return
        # .get() instead of [] so deleting never inserts a key into the defaultdict.
        bucket = self.index.get(message.cause_by)
        if bucket is None or message not in bucket:
            return
        bucket.remove(message)
        if not bucket:
            # Drop empty buckets so a present key always means "has messages".
            del self.index[message.cause_by]

    def clear(self):
        """Clear storage and index."""
        self.storage = []
        self.index = defaultdict(list)

    def count(self) -> int:
        """Return the number of messages in storage."""
        return len(self.storage)

    def try_remember(self, keyword: str) -> list[Message]:
        """Return all stored messages whose ``content`` contains ``keyword``.

        Equivalent to :meth:`get_by_content`.
        """
        return self.get_by_content(keyword)

    def get(self, k=0) -> list[Message]:
        """Return the most recent ``k`` messages, or all of them when ``k`` is 0.

        The result is a new list (a slice of storage). ``k`` is expected to be
        non-negative.
        """
        return self.storage[-k:]

    def remember(self, observed: list[Message], k=10) -> list[Message]:
        """Return the messages in ``observed`` not among the ``k`` most recent stored ones.

        With ``k=0`` every stored message is compared against. The order of
        ``observed`` is preserved.
        """
        already_observed = self.get(k)
        return [message for message in observed if message not in already_observed]

    def get_by_action(self, action: Type[Action]) -> list[Message]:
        """Return the messages whose ``cause_by`` is ``action`` (``[]`` if none).

        The result is a copy, so callers cannot mutate the index. ``.get`` is
        used so that a lookup never registers ``action`` as an index key.
        """
        return list(self.index.get(action, []))

    def get_by_actions(self, actions: Iterable[Type[Action]]) -> list[Message]:
        """Return the messages triggered by ANY of ``actions`` (OR semantics).

        The per-action lists are concatenated in the order of ``actions``;
        actions without messages contribute nothing.
        """
        rsp: list[Message] = []
        for action in actions:
            rsp.extend(self.index.get(action, []))
        return rsp

    def get_by_and_actions(self, actions: Iterable[Type[Action]]) -> list[Message]:
        """Return the messages triggered by ``actions`` only if EVERY action has some.

        If any action in ``actions`` has no indexed messages, ``[]`` is
        returned; otherwise the per-action lists are concatenated in order.
        An empty ``actions`` yields ``[]``.
        """
        rsp: list[Message] = []
        for action in actions:
            bucket = self.index.get(action)
            # The index is a defaultdict, so an indexing access elsewhere can
            # leave an empty bucket; it must fail the AND like a missing key.
            if not bucket:
                return []
            rsp.extend(bucket)
        return rsp