File size: 2,564 Bytes
65ff7bb
 
51eae94
 
 
1f0895f
6741d3b
1f0895f
 
 
 
51eae94
1f0895f
 
51eae94
65ff7bb
 
 
 
 
 
 
51eae94
b2de00c
51eae94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65ff7bb
8c9c81b
65ff7bb
 
 
e1101db
51eae94
65ff7bb
 
d9bb31c
 
 
 
8858192
d9bb31c
 
 
 
 
 
8858192
d9bb31c
65ff7bb
 
 
 
 
 
 
 
65acde3
7d401ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import chainlit as cl
from datasets import load_dataset
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI


# Local CSV cache of the IMDB dataset; downloaded once, reused afterwards.
imdb_csv_file = "./imdb.csv"

if not os.path.exists(imdb_csv_file):
    # First run: pull the dataset from the Hugging Face Hub and persist the
    # train split to the same path we test for above (was a hard-coded
    # 'imdb.csv' literal, drifting from imdb_csv_file).
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    dataset["train"].to_csv(imdb_csv_file)

loader = CSVLoader(file_path=imdb_csv_file)
data = loader.load()


# Split each CSV row into ~1000-char chunks with 100-char overlap so each
# chunk fits comfortably in the embedding model's context.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
)
chunked_documents = text_splitter.split_documents(data)

# NOTE(review): key is read from the lowercase env var "openai_api_key" —
# confirm this matches the deployment environment (OPENAI_API_KEY is the
# conventional name).
openai_api_key = os.getenv("openai_api_key")
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=openai_api_key)

# Cache embedding vectors on disk so re-runs don't re-call the paid API for
# text that was already embedded.
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(embedding_model, store, namespace=embedding_model.model)

# Persisted FAISS index: build it once from the chunked documents, then
# reload from disk on subsequent runs.
vector_file = "local_vector"

if not os.path.exists(vector_file):
    vector_store = FAISS.from_documents(chunked_documents, cached_embedder)
    vector_store.save_local(vector_file)
else:
    vector_store = FAISS.load_local(vector_file, cached_embedder, allow_dangerous_deserialization=True)

@cl.on_chat_start
async def on_chat_start():
    """Assemble the retrieval chain for a new chat and store it in the session.

    The chain feeds the raw user query straight through as {query} while the
    FAISS retriever supplies matching movie chunks as {content}; the prompted
    LLM output is parsed down to a plain string.
    """
    prompt = ChatPromptTemplate.from_template(
        "You are a movie recommendation system, for a given {query} find recommendations from {content}."
    )

    # Streaming enabled so tokens can be relayed to the UI as they arrive.
    llm = ChatOpenAI(
        model="gpt-4o",
        temperature=0.2,
        openai_api_key=openai_api_key,
        streaming=True,
    )

    chain = (
        {"query": RunnablePassthrough(), "content": vector_store.as_retriever()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Per-session storage: the message handler fetches it under this key.
    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message by streaming the chain's output.

    Fetches the chain built in ``on_chat_start`` from the session, streams
    each generated token into a Chainlit message, and finalizes the message
    once the stream completes.
    """
    chain = cl.user_session.get("chain")
    # Outgoing response message (was misleadingly named `user_input`).
    response = cl.Message(content="")

    async for token in chain.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await response.stream_token(token)

    # Bug fix: without send(), the streamed message is never committed to the
    # UI — Chainlit requires an explicit send() after streaming tokens.
    await response.send()