File size: 3,078 Bytes
c226341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from datasets import load_dataset
import pandas as pd
import os
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
import chainlit as cl

# --- Data preparation & vector index ---------------------------------------
# Download the IMDB dataset and cache it locally as a CSV. The download and
# the train-split -> DataFrame conversion happen only when the CSV is missing,
# so reruns are fast and work offline.
# (Fixes two bugs: previously `dataset` was referenced unconditionally and
# raised NameError when the CSV already existed, and the existence check used
# "./imdb.csv" while the data was written to "imdb_movies.csv".)
csv_path = "./imdb.csv"
if not os.path.exists(csv_path):
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    # Assumes the dataset exposes a 'train' split — TODO confirm against the hub card.
    pd.DataFrame(dataset["train"]).to_csv(csv_path, index=False)

# Load the cached CSV into LangChain documents (one document per row).
loader = CSVLoader(csv_path)
data = loader.load()

# Split into 1000-character chunks with 100-character overlap so chunk
# boundaries don't lose sentence context.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
documents = text_splitter.split_documents(data)

# Embedding model wrapped in a local file cache so repeated runs don't re-pay
# the OpenAI embedding cost for unchanged chunks.
# (Fix: the cache-backed `embedder` was previously created but never used —
# FAISS was given the raw `embedding_model`, bypassing the cache.)
embedding_model = OpenAIEmbeddings()
store = LocalFileStore("./cache/")
embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)

# Reload a previously persisted FAISS index when possible; otherwise embed all
# documents and persist the freshly built index. Saving only happens on the
# build path — re-saving an index we just loaded is pointless disk churn.
faiss_index_path = "faiss_index"
try:
    vector_store = FAISS.load_local(
        faiss_index_path, embedder, allow_dangerous_deserialization=True
    )
except Exception:
    vector_store = FAISS.from_documents(documents, embedder)
    vector_store.save_local(faiss_index_path)

#### APP ###
@cl.on_chat_start
async def on_chat_start():
    """Assemble the movie-recommendation RAG chain for a new chat session.

    Builds retriever -> prompt -> streaming LLM -> string parser and stores
    the composed runnable in the Chainlit user session under
    "runnable_chain" for `on_message` to pick up.
    """
    template = """
    You are an AI assistant that loves to suggest movies for people to watch. You have access to the following context:

    {documents}

    Answer the following question. If you cannot answer at all just say "Sorry! I couldn't come up with any movies for this question!"

    Question: {question}
    """

    # Chain pieces: retrieved documents fill {documents}, the raw user
    # question passes through to {question}.
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)
    retriever = vector_store.as_retriever()

    chain = (
        {"documents": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    cl.user_session.set("runnable_chain", chain)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming user message by streaming tokens from the RAG chain."""
    chain = cl.user_session.get("runnable_chain")

    # Start with an empty message and append tokens as they arrive so the UI
    # renders the answer incrementally.
    reply = cl.Message(content="")
    run_config = RunnableConfig(callbacks=[cl.LangchainCallbackHandler()])

    async for token in chain.astream(message.content, config=run_config):
        await reply.stream_token(token)

    await reply.send()