"""Chainlit RAG application that suggests movies from an IMDB dataset.

On import this module:
  1. Downloads the IMDB movie dataset (once) and caches it as a local CSV.
  2. Splits the rows into overlapping text chunks.
  3. Builds -- or reloads from disk -- a FAISS vector index over the chunks,
     using a file-backed embedding cache so re-runs do not re-embed.

At chat time it wires a retrieval-augmented LCEL chain
(retriever -> prompt -> gpt-4o -> string parser) and streams the answer
back through Chainlit.
"""

import os

import pandas as pd
from datasets import load_dataset
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
import chainlit as cl

# --- Data preparation -------------------------------------------------------

# BUG FIX: the original checked for "./imdb.csv" but wrote "imdb_movies.csv",
# so `csv_file_path` was undefined (NameError at CSVLoader) whenever
# "./imdb.csv" existed, and the guard never actually prevented a re-download.
# One path is now used for both the existence check and the export.
csv_file_path = "imdb_movies.csv"

if not os.path.exists(csv_file_path):
    # Download once; the 'train' split holds the movie rows.
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    df = pd.DataFrame(dataset["train"])
    df.to_csv(csv_file_path, index=False)

loader = CSVLoader(csv_file_path)
data = loader.load()

# 1000-char chunks with 100-char overlap: large enough to keep most movie
# descriptions intact, with overlap preserving context across boundaries.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
documents = text_splitter.split_documents(data)

# --- Embeddings & vector store ----------------------------------------------

embedding_model = OpenAIEmbeddings()

# File-backed cache so repeated runs do not re-embed (and re-bill for) the
# same chunks; namespaced by model name so switching embedding models
# invalidates the cache cleanly.
store = LocalFileStore("./cache/")
embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)

faiss_index_path = "faiss_index"
try:
    # allow_dangerous_deserialization is required because FAISS indexes are
    # pickled; acceptable here since we only load an index this app wrote.
    vector_store = FAISS.load_local(
        faiss_index_path, embedder, allow_dangerous_deserialization=True
    )
except Exception:
    # No (or unreadable) index on disk: build it and persist for next run.
    # BUG FIX: embed via the cache-backed `embedder` — the original created
    # the cache but then embedded with the raw model, defeating it.
    vector_store = FAISS.from_documents(documents, embedder)
    vector_store.save_local(faiss_index_path)

# --- Chainlit app ------------------------------------------------------------

@cl.on_chat_start
async def on_chat_start():
    """Build the session's RAG chain and stash it in the user session."""
    prompt_text = """ You are an AI assistant that loves to suggest 
movies for people to watch. You have access to the following context: {documents} Answer the following question. If you cannot answer at all just say "Sorry! I couldn't come up with any movies for this question!" Question: {question} """
    prompt_template = ChatPromptTemplate.from_template(prompt_text)

    retriever = vector_store.as_retriever()
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)
    parser = StrOutputParser()

    # LCEL graph: retrieve documents and pass the raw question through in
    # parallel, fill the prompt, stream model tokens, parse to plain strings.
    runnable_chain = (
        {"documents": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | chat_model
        | parser
    )
    cl.user_session.set("runnable_chain", runnable_chain)


@cl.on_message
async def on_message(message: cl.Message):
    """Stream the chain's answer for an incoming user message."""
    runnable = cl.user_session.get("runnable_chain")
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="Movie Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)

    await msg.send()