|
import os |
|
import chainlit as cl |
|
from datasets import load_dataset |
|
from langchain.embeddings import CacheBackedEmbeddings |
|
from langchain.storage import LocalFileStore |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_core.runnables import RunnableConfig |
|
from langchain_core.runnables.passthrough import RunnablePassthrough |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_community.document_loaders.csv_loader import CSVLoader |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
imdb_csv_file = "./imdb.csv"

# Download the IMDB movies dataset once and materialize its train split as a
# local CSV so subsequent runs skip the (slow) Hugging Face download.
if not os.path.exists(imdb_csv_file):
    dataset = load_dataset('ShubhamChoksi/IMDB_Movies')
    # Write to the same path checked above (previously the filename was
    # duplicated as a hard-coded literal, and the dataset was aliased to an
    # unused `dataset_dict` variable).
    dataset["train"].to_csv(imdb_csv_file)

# Load the CSV: each row becomes one LangChain Document.
loader = CSVLoader(file_path=imdb_csv_file)
data = loader.load()
|
|
|
|
|
# Split each document into ~1000-character chunks; the 100-character overlap
# preserves context that would otherwise be cut at a chunk boundary.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunked_documents = text_splitter.split_documents(data)
|
|
|
# Resolve the OpenAI API key: prefer the lowercase variable this project has
# used so far, falling back to the conventional OPENAI_API_KEY name so the
# script also works in standard environments.
openai_api_key = os.getenv("openai_api_key") or os.getenv("OPENAI_API_KEY")

embedding_model = OpenAIEmbeddings(
    model="text-embedding-3-small", openai_api_key=openai_api_key
)

# Cache embeddings on disk so re-runs over unchanged chunks don't re-call the
# (billed) OpenAI embeddings endpoint. Keys are namespaced by model name so a
# later model switch never reuses stale vectors.
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)

vector_file = "local_vector"

# Build the FAISS index once and persist it; later runs load it from disk.
# allow_dangerous_deserialization is required by FAISS.load_local for its
# pickle-backed metadata — acceptable here because we only ever load the
# artifact this same script wrote.
if not os.path.exists(vector_file):
    vector_store = FAISS.from_documents(chunked_documents, cached_embedder)
    vector_store.save_local(vector_file)
else:
    vector_store = FAISS.load_local(
        vector_file, cached_embedder, allow_dangerous_deserialization=True
    )
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Assemble the RAG chain for this chat session and stash it in the
    Chainlit user session so each incoming message can reuse it."""
    prompt_template = ChatPromptTemplate.from_template(
        "You are a movie recommendation system, for a given {query} find recommendations from {content}."
    )

    # Fan the raw user query into both prompt slots: {query} passes through
    # unchanged while {content} is filled by FAISS retrieval results.
    inputs = {"query": RunnablePassthrough(), "content": vector_store.as_retriever()}
    llm = ChatOpenAI(
        model="gpt-4o", temperature=0.2, openai_api_key=openai_api_key, streaming=True
    )

    # inputs -> filled prompt -> chat model -> plain-text answer
    runnable_chain = inputs | prompt_template | llm | StrOutputParser()
    cl.user_session.set("chain", runnable_chain)
|
|
|
|
|
@cl.on_message
async def main(message: cl.Message):
    """Stream the RAG chain's answer for an incoming chat message.

    Tokens are forwarded to the UI as they arrive from the LLM, then the
    message is finalized once the stream completes.
    """
    chain = cl.user_session.get("chain")
    # Renamed from `user_input`: this message holds the *response* we stream.
    response = cl.Message(content="")

    async for token in chain.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await response.stream_token(token)

    # Fix: finalize the streamed message. Without send() the message is never
    # marked complete in the Chainlit UI after streaming ends.
    await response.send()