import chainlit as cl from datasets import load_dataset from langchain_community.document_loaders import CSVLoader from langchain_community.vectorstores.chroma import Chroma from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_openai import OpenAIEmbeddings #from langchain.embeddings import CacheBackedEmbeddings #from langchain.storage import LocalFileStore #from langchain_community.vectorstores import FAISS #from langchain_core.runnables.base import RunnableSequence from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig from langchain.callbacks.base import BaseCallbackHandler from langchain.indexes import SQLRecordManager, index # def setup_data(): # dataset = load_dataset("ShubhamChoksi/IMDB_Movies") # dataset_dict = dataset # dataset_dict["train"].to_csv("imdb.csv") # loader = CSVLoader(file_path="imdb.csv") # data = loader.load() # text_splitter = RecursiveCharacterTextSplitter( # chunk_size=1000, # chunk_overlap=100 # ) # chunked_documents = text_splitter.split_documents(data) # embedding_model = OpenAIEmbeddings() # store = LocalFileStore("./cache/") # embedder = CacheBackedEmbeddings.from_bytes_store(embedding_model, store, namespace=embedding_model.model) # vector_store = FAISS.from_documents(chunked_documents, embedder) # vector_store.save_local("faiss_index") # return vector_store def setup_data(): dataset = load_dataset("ShubhamChoksi/IMDB_Movies") dataset_dict = dataset dataset_dict["train"].to_csv("imdb.csv") loader = CSVLoader(file_path="imdb.csv") data = loader.load() text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=100 ) docs = text_splitter.split_documents(data) # chunked documents embeddings_model = OpenAIEmbeddings() doc_search = Chroma.from_documents(docs, embeddings_model, persist_directory="cache") namespace = "chromadb/my_documents" record_manager = SQLRecordManager( namespace, db_url="sqlite:///record_manager_cache.sql" ) record_manager.create_schema() index_result = index( docs, record_manager, doc_search, cleanup="incremental", source_id_key="source", ) print(f"Indexing stats: {index_result}") return doc_search doc_search = setup_data() model = ChatOpenAI(model_name="gpt-4o", temperature=0, streaming=True) @cl.on_chat_start async def on_chat_start(): template = """Answer the question based only on the following context: {context} Question: {question} """ prompt = ChatPromptTemplate.from_template(template) def format_docs(docs): return "\n\n".join([d.page_content for d in docs]) retriever = doc_search.as_retriever() runnable = ( {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | model | StrOutputParser() ) cl.user_session.set("runnable", runnable) @cl.on_message async def on_message(message: cl.Message): runnable = cl.user_session.get("runnable") # type: Runnable msg = cl.Message(content="") class PostMessageHandler(BaseCallbackHandler): """ Callback handler for handling the retriever and LLM processes. Used to post the sources of the retrieved documents as a Chainlit element. """ def __init__(self, msg: cl.Message): BaseCallbackHandler.__init__(self) self.msg = msg self.sources = set() # To store unique pairs def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs): for d in documents: source_page_pair = (d.metadata['source'], d.metadata['page']) self.sources.add(source_page_pair) # Add unique pairs to the set def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs): if len(self.sources): sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources]) self.msg.elements.append( cl.Text(name="Sources", content=sources_text, display="inline") ) async with cl.Step(type="run", name="QA Assistant"): async for chunk in runnable.astream( message.content, config=RunnableConfig(callbacks=[ cl.LangchainCallbackHandler(), PostMessageHandler(msg) ]), ): await msg.stream_token(chunk) await msg.send()