# VectorStoreFlowModule / ChromaDBFlow.py
# Last commit: bfb6e70 ("First commit.") by martinjosifoski
# File size: 2.34 kB
import os
from typing import Dict, List, Any
import uuid
from langchain.embeddings import OpenAIEmbeddings
from chromadb import Client as ChromaClient
from flows.base_flows import AtomicFlow
class ChromaDBFlow(AtomicFlow):
    """Atomic flow that stores documents in, and retrieves them from, a Chroma collection.

    The collection is obtained with ``get_or_create_collection`` using the name in
    ``flow_config["name"]``. Documents and queries are embedded with LangChain's
    ``OpenAIEmbeddings`` before being handed to Chroma.

    Keys read from ``flow_config``: ``name``, ``input_keys``, ``output_keys`` and
    ``n_results``.
    """

    def __init__(self, **kwargs):
        """Initialise the base flow, then open the Chroma client and collection."""
        super().__init__(**kwargs)
        # Client is created with the library's default settings.
        self.client = ChromaClient()
        self.collection = self.client.get_or_create_collection(name=self.flow_config["name"])

    def get_input_keys(self) -> List[str]:
        """Return the input keys declared in the flow configuration."""
        return self.flow_config["input_keys"]

    def get_output_keys(self) -> List[str]:
        """Return the output keys declared in the flow configuration."""
        return self.flow_config["output_keys"]

    def _build_embeddings(self) -> OpenAIEmbeddings:
        """Create the embedding client from the ``api_information`` held in flow state."""
        api_information = self._get_from_state("api_information")
        if api_information.backend_used == "openai":
            return OpenAIEmbeddings(openai_api_key=api_information.api_key)
        # ToDo: Add support for Azure
        # Any other backend currently falls back to the key from the environment.
        return OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a ``"read"`` or ``"write"`` operation against the collection.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or ``"write"``) and
                ``"content"``. For a read, ``content`` is the query string. For a
                write, it is a single document or a list of documents.

        Returns:
            A dict with a single ``"retrieved"`` key:
            * read  -> the ``"documents"`` entry of the Chroma query result, or
              ``[[""]]`` when the query is the empty string;
            * write -> the empty string.

        Raises:
            ValueError: If the operation is unsupported, or if ``content`` is not a
                string during a read.
        """
        # Validate the request before touching flow state or building the embedding
        # client, so invalid or no-op calls do not depend on API credentials.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                # Nothing to search for: return an empty placeholder result.
                return {"retrieved": [[""]]}

            embeddings = self._build_embeddings()
            query_result = self.collection.query(
                query_embeddings=embeddings.embed_query(content),
                n_results=self.flow_config["n_results"],
            )
            return {"retrieved": list(query_result["documents"])}

        # operation == "write"
        if content != "":
            documents = content if isinstance(content, list) else [content]
            # An empty list is treated like the empty string: nothing to embed or
            # store, so skip the embedding call and the collection insert.
            if documents:
                embeddings = self._build_embeddings()
                self.collection.add(
                    ids=[str(uuid.uuid4()) for _ in documents],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents,
                )
        return {"retrieved": ""}