# QueFilme / app.py
from datasets import load_dataset
import pandas as pd
import os
from langchain_community.document_loaders import CSVLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
import chainlit as cl
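
# Both OpenAIEmbeddings and ChatOpenAI read the OPENAI_API_KEY environment
# variable; set it before starting the app.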
csv_path = "./imdb.csv"
if not os.path.exists(csv_path):
dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
# Convert to pandas DataFrame (assuming 'train' split)
imdb_train = dataset['train']
df = pd.DataFrame(imdb_train)
# Save to a CSV file
csv_file_path = 'imdb_movies.csv'
df.to_csv(csv_file_path, index=False)
loader = CSVLoader(csv_file_path)
data = loader.load()
# Initialize the text splitter for 1000-char chunks with 100-char overlap
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
documents = text_splitter.split_documents(data)
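# Optional sanity check (illustrative): inspect how many chunks were produced.
# print(f"Split {len(data)} rows into {len(documents)} chunks")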
# Initialize the OpenAIEmbeddings object
# embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
embedding_model = OpenAIEmbeddings()

# Cache computed embeddings on disk so re-runs skip repeated API calls.
store = LocalFileStore("./cache/")
embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)
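# Illustrative (hypothetical) usage: the second call below would be served
# from ./cache/ rather than hitting the OpenAI API again. The namespace ties
# cache entries to the model name, so switching models invalidates the cache.
# _v1 = embedder.embed_documents(["The Godfather"])
# _v2 = embedder.embed_documents(["The Godfather"])  # cache hit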
# Load the FAISS index from disk if present; otherwise build and persist it.
faiss_index_path = "faiss_index"
try:
    vector_store = FAISS.load_local(
        faiss_index_path, embedder, allow_dangerous_deserialization=True
    )
except Exception:
    # No usable index yet: embed the chunks (through the cache) and save
    # the index under the same path used by load_local above.
    vector_store = FAISS.from_documents(documents, embedder)
    vector_store.save_local(faiss_index_path)
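# Optional sanity check (illustrative): query the index directly before
# wiring it into the chat chain.
# for doc in vector_store.similarity_search("feel-good sports movies", k=3):
#     print(doc.page_content[:120])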
#### APP ####
@cl.on_chat_start
async def on_chat_start():
    # Define the prompt template
    prompt_text = """
You are an AI assistant that loves to suggest movies for people to watch. You have access to the following context:

{documents}

Answer the following question. If you cannot answer at all, just say "Sorry! I couldn't come up with any movies for this question!"

Question: {question}
"""
    prompt_template = ChatPromptTemplate.from_template(prompt_text)

    retriever = vector_store.as_retriever()
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)
    parser = StrOutputParser()

    # LCEL chain: the dict fans the incoming question out to the retriever
    # (filling {documents}) and to a passthrough (filling {question}).
    runnable_chain = (
        {"documents": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | chat_model
        | parser
    )
    cl.user_session.set("runnable_chain", runnable_chain)
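    # Optional: greet the user when a new chat starts (standard Chainlit API).
    # await cl.Message(content="Ask me for a movie recommendation!").send()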
@cl.on_message
async def on_message(message: cl.Message):
    runnable = cl.user_session.get("runnable_chain")
    msg = cl.Message(content="")

    # class PostMessageHandler(BaseCallbackHandler):
    #     """
    #     Callback handler for handling the retriever and LLM processes.
    #     Used to post the sources of the retrieved documents as a Chainlit element.
    #     """
    #
    #     def __init__(self, msg: cl.Message):
    #         BaseCallbackHandler.__init__(self)
    #         self.msg = msg
    #         self.sources = set()  # To store unique pairs
    #
    #     def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
    #         for d in documents:
    #             source_page_pair = (d.metadata['source'], d.metadata['page'])
    #             self.sources.add(source_page_pair)  # Add unique pairs to the set
    #
    #     def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
    #         if len(self.sources):
    #             sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
    #             self.msg.elements.append(
    #                 cl.Text(name="Sources", content=sources_text, display="inline")
    #             )

    # Stream the chain's output token by token into the Chainlit message.
    async with cl.Step(type="run", name="Movie Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                # PostMessageHandler(msg)
            ]),
        ):
            await msg.stream_token(chunk)

    await msg.send()
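
# Launch locally with: chainlit run app.py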