File size: 8,067 Bytes
bfb6e70
 
 
 
46d0705
bfb6e70
 
 
5a12fca
0b084fa
bfb6e70
46d0705
bfb6e70
e5a24c6
 
 
 
 
 
 
 
 
 
 
 
 
 
bfb6e70
c415e05
 
 
 
 
 
 
 
0b084fa
c415e05
 
66127a4
 
 
 
 
fc0b4d8
66127a4
 
 
c415e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46d0705
bfb6e70
e5a24c6
46d0705
 
e5a24c6
 
 
 
46d0705
 
c415e05
 
 
 
 
 
 
46d0705
 
 
 
 
 
 
 
 
c415e05
 
 
 
 
 
 
46d0705
 
 
 
 
 
bfb6e70
46d0705
 
 
e5a24c6
 
 
 
 
 
 
 
 
 
 
 
 
c415e05
e5a24c6
 
 
 
ef61dcf
e5a24c6
 
 
 
 
 
 
 
 
 
fc0b4d8
 
e5a24c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c415e05
5a12fca
c415e05
 
5a12fca
 
c415e05
e5a24c6
 
 
5a12fca
e5a24c6
 
 
bfb6e70
 
 
 
 
 
 
e5a24c6
bfb6e70
 
 
 
 
b922a7a
 
e5a24c6
 
 
bfb6e70
 
 
 
 
 
e5a24c6
bfb6e70
 
 
 
e5a24c6
bfb6e70
 
6244aa1
5a12fca
 
 
d15fe5b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Standard library
import os
import uuid
from copy import deepcopy
from typing import Dict, List, Any

# Third-party
import hydra
from chromadb import Client as ChromaClient
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# Local
from aiflows.base_flows import AtomicFlow
from aiflows.messages import FlowMessage

class ChromaDBFlow(AtomicFlow):
    r""" A flow that uses the ChromaDB model to write and read memories stored in a database

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "chroma_db"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
    Default: "ChromaDB is a document store that uses vector embeddings to store and retrieve documents."
    - `backend` (Dict[str, Any]): The configuration of the backend which is used to fetch api keys. Default: LiteLLMBackend with the
    default parameters of LiteLLMBackend (see aiflows.backends.LiteLLMBackend). Except for the following parameter whose default value is overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of api infos. Default: No default value, this parameter is required.
        - `model_name` (str): The name of the model. Default: "". In the current implementation, this parameter is not used.
    - `similarity_search_kwargs` (Dict[str, Any]): The parameters to pass to the similarity search method of the ChromaDB. Default:
        - `k` (int): The number of documents to retrieve. Default: 2
        - `filter` (str): The filter to apply to the documents. Default: null
    - `paths_to_data` (List[str]): The paths to the data to store in the database at instantiation. Default: []
    - `chunk_size` (int): The size of the chunks to split the documents into. Default: 700
    - `separator` (str): The separator to use to split the documents. Default: "\n"
    - `chunk_overlap` (int): The overlap between the chunks. Default: 0
    - `persist_directory` (str): The directory to persist the database. Default: "./demo_db_dir"

    - Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)

    *Input Interface*:

    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If operation is "write", it must be a string or a list of strings. If operation is "read", it must be a string.

    *Output Interface*:

    - `retrieved` (str or List[str]): The retrieved content. If operation is "write", it is an empty string. If operation is "read", it is a string or a list of strings.

    :param backend: The backend of the flow (used to retrieve the API key)
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow.
    """

    def __init__(self, backend, **kwargs):
        super().__init__(**kwargs)
        self.backend = backend

    def set_up_flow_state(self):
        """ Initializes the flow state.

        `db_created` tracks whether the on-disk Chroma collection has already been
        built from `paths_to_data`, so subsequent calls load it instead of rebuilding.
        """
        super().set_up_flow_state()
        self.flow_state["db_created"] = False

    @classmethod
    def _set_up_backend(cls, config):
        """ This instantiates the backend of the flow from a configuration file.

        :param config: The configuration of the backend.
        :type config: Dict[str, Any]
        :return: The backend of the flow.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}
        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration file.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChromaDBFlow
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def get_embeddings_model(self):
        """ Builds the embeddings model using the API key fetched from the backend.

        Falls back to the OPENAI_API_KEY environment variable for any non-OpenAI
        backend (Azure support is not implemented yet).

        :return: The embeddings model.
        :rtype: OpenAIEmbeddings
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        else:
            # ToDo: Add support for Azure
            embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings

    def get_db(self):
        """ Returns the Chroma database handle.

        Resolution order:
        1. reuse the handle cached on the instance, if any;
        2. load the persisted collection from disk when it was already created
           (or when there is no source data to ingest);
        3. otherwise split the documents in `paths_to_data` into chunks, embed
           them, and create a new persisted collection.

        :return: The Chroma database.
        :rtype: Chroma
        """
        db_created = self.flow_state["db_created"]

        if hasattr(self, 'db'):
            # Already built/loaded during a previous run; reuse it.
            db = self.db

        elif db_created or len(self.flow_config["paths_to_data"]) == 0:
            # Load the existing (possibly empty) collection from disk.
            db = Chroma(
                persist_directory=self.flow_config["persist_directory"],
                embedding_function=self.get_embeddings_model()
            )
        else:
            # First run with source data: create the collection and persist it.
            # NOTE(review): the config key must be spelled "separator" — the
            # class docstring previously documented it as "seperator".
            full_docs = []
            text_splitter = CharacterTextSplitter(
                chunk_size=self.flow_config["chunk_size"],
                chunk_overlap=self.flow_config["chunk_overlap"],
                separator=self.flow_config["separator"]
            )

            for path in self.flow_config["paths_to_data"]:
                loader = TextLoader(path)
                documents = loader.load()
                docs = text_splitter.split_documents(documents)
                full_docs.extend(docs)

            db = Chroma.from_documents(
                full_docs,
                self.get_embeddings_model(),
                persist_directory=self.flow_config["persist_directory"]
            )

        self.flow_state["db_created"] = True
        return db

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. It runs the ChromaDBFlow. It either writes or reads memories from the database.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        :raises ValueError: If `operation` is not "write" or "read", or if a read
            query is not a string.
        """
        self.db = self.get_db()

        input_data = input_message.data

        response = {}

        operation = input_data["operation"]
        if operation not in ["write", "read"]:
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                # NOTE(review): empty query returns a doubly-nested list while a
                # non-empty query returns a flat list — preserved as-is since
                # downstream flows may rely on this shape; confirm before changing.
                response["retrieved"] = [[""]]
            else:
                query = content
                query_result = self.db.similarity_search(query, **self.flow_config["similarity_search_kwargs"])

                response["retrieved"] = [doc.page_content for doc in query_result]

        elif operation == "write":
            if content != "":
                if not isinstance(content, list):
                    content = [content]
                documents = content
                # Embeddings are only needed when writing; computing them here
                # avoids an unnecessary backend key fetch on the read path.
                embeddings = self.get_embeddings_model()
                self.db._collection.add(
                    ids=[str(uuid.uuid4()) for _ in range(len(documents))],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents
                )

            response["retrieved"] = ""

        reply = self.package_output_message(
            input_message=input_message,
            response=response
        )
        self.send_message(reply)