# NOTE: file-viewer header residue ("raw / history blame / 4.8 kB") removed.
# try Chromadb version (commit 780f5aa)
import chainlit as cl
from datasets import load_dataset
from langchain_community.document_loaders import CSVLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
#from langchain.embeddings import CacheBackedEmbeddings
#from langchain.storage import LocalFileStore
#from langchain_community.vectorstores import FAISS
#from langchain_core.runnables.base import RunnableSequence
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
# def setup_data():
# dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
# dataset_dict = dataset
# dataset_dict["train"].to_csv("imdb.csv")
# loader = CSVLoader(file_path="imdb.csv")
# data = loader.load()
# text_splitter = RecursiveCharacterTextSplitter(
# chunk_size=1000,
# chunk_overlap=100
# )
# chunked_documents = text_splitter.split_documents(data)
# embedding_model = OpenAIEmbeddings()
# store = LocalFileStore("./cache/")
# embedder = CacheBackedEmbeddings.from_bytes_store(embedding_model, store, namespace=embedding_model.model)
# vector_store = FAISS.from_documents(chunked_documents, embedder)
# vector_store.save_local("faiss_index")
# return vector_store
def setup_data(csv_path: str = "imdb.csv", chunk_size: int = 1000, chunk_overlap: int = 100):
    """Build and incrementally index a Chroma vector store over the IMDB dataset.

    Downloads the ShubhamChoksi/IMDB_Movies dataset, exports the train split to
    CSV, splits each row into overlapping character chunks, embeds them with
    OpenAI, and indexes them into Chroma through a SQL-backed record manager so
    that repeated runs only re-embed new or changed documents.

    Args:
        csv_path: Path where the exported train split is written and re-read.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.

    Returns:
        The populated Chroma vector store.
    """
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    dataset["train"].to_csv(csv_path)

    loader = CSVLoader(file_path=csv_path)
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    docs = text_splitter.split_documents(data)

    embeddings_model = OpenAIEmbeddings()
    doc_search = Chroma.from_documents(docs, embeddings_model)

    # The record manager tracks what has already been indexed, so
    # cleanup="incremental" skips unchanged documents and removes stale ones
    # that share the same "source" key.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")
    return doc_search
# Build (or incrementally refresh) the vector store once at import time so
# every chat session shares the same retriever.
doc_search = setup_data()
# Streaming chat model; temperature=0 for deterministic, context-grounded answers.
model = ChatOpenAI(model_name="gpt-4o", temperature=0, streaming=True)
@cl.on_chat_start
async def on_chat_start():
    """Assemble the RAG chain for a new chat session and stash it in the session.

    The chain retrieves context from the shared Chroma store, formats it into
    the prompt, queries the model, and parses the reply to a plain string.
    """
    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def _join_docs(retrieved):
        # Concatenate the retrieved chunks into a single context string.
        return "\n\n".join(doc.page_content for doc in retrieved)

    retriever = doc_search.as_retriever()

    chain = (
        {"context": retriever | _join_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", chain)
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the RAG chain's answer for an incoming user message.

    Tokens are streamed into a Chainlit message as they arrive; once the LLM
    finishes, the unique document sources seen by the retriever are attached
    as an inline "Sources" text element.
    """
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = set()  # unique (source, page/row) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            for d in documents:
                # BUG FIX: documents loaded via CSVLoader carry a 'row' metadata
                # key, not 'page', so d.metadata['page'] raised KeyError here.
                # Fall back to 'row' (then 0) to support both document kinds.
                page = d.metadata.get("page", d.metadata.get("row", 0))
                self.sources.add((d.metadata["source"], page))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if self.sources:
                sources_text = "\n".join(
                    f"{source}#page={page}" for source, page in self.sources
                )
                self.msg.elements.append(
                    cl.Text(name="Sources", content=sources_text, display="inline")
                )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)
        await msg.send()