"""Chainlit chat app that recommends movies via RAG over the IMDB dataset.

On first run it downloads the IMDB movies dataset, chunks it, embeds the
chunks (with a disk-backed embedding cache), and persists a FAISS index.
Each chat session builds a retrieval chain; each message streams the
model's answer back to the UI.
"""

import os

import chainlit as cl
from datasets import load_dataset
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# --- Dataset preparation: download once, then reuse the local CSV ---------
imdb_csv_file = "./imdb.csv"
if not os.path.exists(imdb_csv_file):
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    # Persist the training split to the same path we load from below
    # (the original wrote the literal 'imdb.csv' via a redundant alias).
    dataset["train"].to_csv(imdb_csv_file)

loader = CSVLoader(file_path=imdb_csv_file)
data = loader.load()

# Chunk the documents so each embedded unit fits comfortably in context;
# the 100-char overlap preserves continuity across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunked_documents = text_splitter.split_documents(data)

# NOTE(review): reads the lowercase env var "openai_api_key" — kept for
# backward compatibility, but OPENAI_API_KEY is the usual convention.
openai_api_key = os.getenv("openai_api_key")
embedding_model = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_api_key=openai_api_key,
)

# Cache embeddings on disk so re-indexing does not re-call the OpenAI API.
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)

# --- Vector store: build on first run, load from disk afterwards ---------
vector_file = "local_vector"
if not os.path.exists(vector_file):
    vector_store = FAISS.from_documents(chunked_documents, cached_embedder)
    vector_store.save_local(vector_file)
else:
    # allow_dangerous_deserialization is required because FAISS indexes are
    # pickled; acceptable here only because we wrote this file ourselves.
    vector_store = FAISS.load_local(
        vector_file, cached_embedder, allow_dangerous_deserialization=True
    )


@cl.on_chat_start
async def on_chat_start():
    """Build the retrieval chain for a new chat session and store it there."""
    prompt_template = ChatPromptTemplate.from_template(
        "You are a movie recommendation system, for a given {query} find recommendations from {content}."
    )
    retriever = vector_store.as_retriever()
    chat_model = ChatOpenAI(
        model="gpt-4o",
        temperature=0.2,
        openai_api_key=openai_api_key,
        streaming=True,
    )
    parser = StrOutputParser()
    # The raw user question feeds both the {query} slot (passthrough) and the
    # retriever, whose matched documents fill the {content} slot.
    runnable_chain = (
        {"query": RunnablePassthrough(), "content": retriever}
        | prompt_template
        | chat_model
        | parser
    )
    cl.user_session.set("chain", runnable_chain)


@cl.on_message
async def main(message):
    """Stream the chain's answer for an incoming message, token by token."""
    chain = cl.user_session.get("chain")
    response = cl.Message(content="")
    async for token in chain.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await response.stream_token(token)
    # Fix: finalize the streamed message so it is persisted in the chat UI
    # (the original never called send() after streaming).
    await response.send()