File size: 8,067 Bytes
bfb6e70 46d0705 bfb6e70 5a12fca 0b084fa bfb6e70 46d0705 bfb6e70 e5a24c6 bfb6e70 c415e05 0b084fa c415e05 66127a4 fc0b4d8 66127a4 c415e05 46d0705 bfb6e70 e5a24c6 46d0705 e5a24c6 46d0705 c415e05 46d0705 c415e05 46d0705 bfb6e70 46d0705 e5a24c6 c415e05 e5a24c6 ef61dcf e5a24c6 fc0b4d8 e5a24c6 c415e05 5a12fca c415e05 5a12fca c415e05 e5a24c6 5a12fca e5a24c6 bfb6e70 e5a24c6 bfb6e70 b922a7a e5a24c6 bfb6e70 e5a24c6 bfb6e70 e5a24c6 bfb6e70 6244aa1 5a12fca d15fe5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
from typing import Dict, List, Any
import uuid
from copy import deepcopy
from langchain.embeddings import OpenAIEmbeddings
from chromadb import Client as ChromaClient
from aiflows.messages import FlowMessage
from aiflows.base_flows import AtomicFlow
import hydra
import os
from typing import Dict, List, Any
import uuid
from copy import deepcopy
from langchain.embeddings import OpenAIEmbeddings
from aiflows.messages import FlowMessage
from aiflows.base_flows import AtomicFlow
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain.vectorstores import Chroma
import hydra
class ChromaDBFlow(AtomicFlow):
    """ A flow that uses the ChromaDB model to write and read memories stored in a database
    *Configuration Parameters*:
    - `name` (str): The name of the flow. Default: "chroma_db"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
    Default: "ChromaDB is a document store that uses vector embeddings to store and retrieve documents."
    - `backend` (Dict[str, Any]): The configuration of the backend which is used to fetch api keys. Default: LiteLLMBackend with the
    default parameters of LiteLLMBackend (see aiflows.backends.LiteLLMBackend). Except for the following parameter whose default value is overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of api infos. Default: No default value, this parameter is required.
        - `model_name` (str): The name of the model. Default: "". In the current implementation, this parameter is not used.
    - `similarity_search_kwargs` (Dict[str, Any]): The parameters to pass to the similarity search method of the ChromaDB. Default:
        - `k` (int): The number of documents to retrieve. Default: 2
        - `filter` (str): The filter to apply to the documents. Default: null
    - `paths_to_data` (List[str]): The paths to the data to store in the database at instantiation. Default: []
    - `chunk_size` (int): The size of the chunks to split the documents into. Default: 700
    - `separator` (str): The separator to use to split the documents. Default: "\n"
    - `chunk_overlap` (int): The overlap between the chunks. Default: 0
    - `persist_directory` (str): The directory to persist the database. Default: "./demo_db_dir"
    - Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)
    *Input Interface*:
    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If operation is "write", it must be a string or a list of strings. If operation is "read", it must be a string.
    *Output Interface*:
    - `retrieved` (str or List[str]): The retrieved content. If operation is "write", it is an empty string. If operation is "read", it is a string or a list of strings.
    :param backend: The backend of the flow (used to retrieve the API key)
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow.
    """

    def __init__(self, backend, **kwargs):
        """ Initialize the flow and keep a handle on the backend used for API-key retrieval. """
        super().__init__(**kwargs)
        self.backend = backend

    def set_up_flow_state(self):
        """ Initialize the flow state. `db_created` tracks whether the on-disk
        database has already been built from `paths_to_data` (so it is only built once). """
        super().set_up_flow_state()
        self.flow_state["db_created"] = False

    @classmethod
    def _set_up_backend(cls, config):
        """ This instantiates the backend of the flow from a configuration file.
        :param config: The configuration of the backend.
        :type config: Dict[str, Any]
        :return: The backend of the flow.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}
        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration file
        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChromaDBFlow
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}
        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))
        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def get_embeddings_model(self):
        """ Build the embeddings model used by the database.
        Uses the API key from the backend when it is an OpenAI backend; otherwise
        falls back to the OPENAI_API_KEY environment variable.
        :return: The embeddings model.
        :rtype: OpenAIEmbeddings
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        else:
            # ToDo: Add support for Azure
            embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings

    def get_db(self):
        """ Return the Chroma database, creating or loading it as needed.
        - If the db was already fetched in this process (cached on `self.db`), reuse it.
        - Else, if it was already built on disk (or there is no data to ingest), load it
          from `persist_directory`.
        - Otherwise split the documents at `paths_to_data` into chunks and build a new
          persisted database from them.
        :return: The Chroma database.
        :rtype: Chroma
        """
        db_created = self.flow_state["db_created"]
        if hasattr(self, 'db'):
            # already instantiated in this process — reuse the cached handle
            db = self.db
        elif db_created or len(self.flow_config["paths_to_data"]) == 0:
            # load from disk
            db = Chroma(
                persist_directory=self.flow_config["persist_directory"],
                embedding_function=self.get_embeddings_model()
            )
        else:
            # create db and save to disk
            full_docs = []
            text_splitter = CharacterTextSplitter(
                chunk_size=self.flow_config["chunk_size"],
                chunk_overlap=self.flow_config["chunk_overlap"],
                separator=self.flow_config["separator"]
            )
            for path in self.flow_config["paths_to_data"]:
                loader = TextLoader(path)
                documents = loader.load()
                docs = text_splitter.split_documents(documents)
                full_docs.extend(docs)
            db = Chroma.from_documents(
                full_docs,
                self.get_embeddings_model(),
                persist_directory=self.flow_config["persist_directory"]
            )
        self.flow_state["db_created"] = True
        return db

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. It runs the ChromaDBFlow. It either writes or reads memories from the database.
        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        """
        self.db = self.get_db()
        input_data = input_message.data
        response = {}
        operation = input_data["operation"]
        if operation not in ["write", "read"]:
            raise ValueError(f"Operation '{operation}' not supported")
        content = input_data["content"]
        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                response["retrieved"] = [[""]]
            else:
                query = content
                query_result = self.db.similarity_search(query, **self.flow_config["similarity_search_kwargs"])
                response["retrieved"] = [doc.page_content for doc in query_result]
        elif operation == "write":
            if content != "":
                if not isinstance(content, list):
                    content = [content]
                documents = content
                # embeddings are only needed when inserting documents, so build the
                # model lazily here instead of on every call (avoids a needless
                # backend key fetch on the read path)
                embeddings = self.get_embeddings_model()
                self.db._collection.add(
                    ids=[str(uuid.uuid4()) for _ in range(len(documents))],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents
                )
            response["retrieved"] = ""
        reply = self.package_output_message(
            input_message=input_message,
            response=response
        )
        self.send_message(reply)
|