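# app.py — Chainlit chat app that recommends IMDB movies with a small RAG pipeline:
# FAISS retrieval over chunked plot summaries, answered by GPT-4o through LangChain.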
import os
import chainlit as cl
from datasets import load_dataset
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnableConfig, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
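
# Download the IMDB dataset from the Hugging Face Hub on first run and persist it as a local CSV.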
imdb_csv_file = "./imdb.csv"
if not os.path.exists(imdb_csv_file):
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    dataset["train"].to_csv(imdb_csv_file)
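
# Load each CSV row as a LangChain Document.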
loader = CSVLoader(file_path=imdb_csv_file)
data = loader.load()
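
# Split the documents into 1000-character chunks with 100 characters of overlap,
# so adjacent chunks share context and retrieval does not cut passages off cold.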
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
)
chunked_documents = text_splitter.split_documents(data)
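
# Embed the chunks with OpenAI's text-embedding-3-small, caching the vectors in
# ./cache/ so documents are not re-embedded (and re-billed) on every restart.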
openai_api_key = os.getenv("openai_api_key")
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=openai_api_key)
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(embedding_model, store, namespace=embedding_model.model)
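
# Build the FAISS index once and save it locally; later runs reload the saved index.
# allow_dangerous_deserialization is acceptable here because the index file is
# produced locally by this same app, not downloaded from an untrusted source.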
vector_file = "local_vector"
if not os.path.exists(vector_file):
    vector_store = FAISS.from_documents(chunked_documents, cached_embedder)
    vector_store.save_local(vector_file)
else:
    vector_store = FAISS.load_local(vector_file, cached_embedder, allow_dangerous_deserialization=True)
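
# At the start of each chat session, assemble the retrieval chain
# (retriever -> prompt -> model -> parser) and stash it in the user session.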
@cl.on_chat_start
async def on_chat_start():
    prompt_template = ChatPromptTemplate.from_template(
        "You are a movie recommendation system. For the given {query}, "
        "recommend movies based on the following context: {content}."
    )
    retriever = vector_store.as_retriever()
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=openai_api_key, streaming=True)
    parser = StrOutputParser()
    runnable_chain = (
        {"query": RunnablePassthrough(), "content": retriever}
        | prompt_template
        | chat_model
        | parser
    )
    cl.user_session.set("chain", runnable_chain)
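
# On each user message, run the chain and stream the answer back token by token.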
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    response = cl.Message(content="")
    async for token in chain.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await response.stream_token(token)
    # Finalize the streamed message so it renders as complete in the UI.
    await response.send()