"""Chainlit RAG app: suggests movies by retrieving IMDB data from a FAISS index."""
from datasets import load_dataset
import pandas as pd
import os
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
import chainlit as cl
# --- Data preparation & vector-store bootstrap (runs once at import time) ---

# Single canonical path for the cached IMDB CSV.
# BUG FIX: the original checked "./imdb.csv" but saved "imdb_movies.csv",
# so the cache check never matched the written file, and `csv_file_path`
# was only defined inside the download branch — loading the CSV raised
# NameError whenever "./imdb.csv" actually existed.
csv_path = "imdb_movies.csv"
csv_file_path = csv_path  # kept for backward compatibility with the old name

if not os.path.exists(csv_file_path):
    # Download the dataset and cache it locally as CSV (single 'train' split).
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    df = pd.DataFrame(dataset["train"])
    df.to_csv(csv_file_path, index=False)

# Load every CSV row as one LangChain Document.
loader = CSVLoader(csv_file_path)
data = loader.load()

# Split into ~1000-char chunks with 100-char overlap so neighboring chunks
# share context across the split boundary.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
documents = text_splitter.split_documents(data)

# Embeddings backed by an on-disk cache so repeated runs don't re-bill the
# OpenAI embeddings endpoint for texts already embedded.
embedding_model = OpenAIEmbeddings()
store = LocalFileStore("./cache/")
embedder = CacheBackedEmbeddings.from_bytes_store(
    embedding_model, store, namespace=embedding_model.model
)

# Load the FAISS index from disk if present; otherwise build it once and
# persist it for subsequent runs.
faiss_index_path = "faiss_index"
try:
    vector_store = FAISS.load_local(
        faiss_index_path, embedding_model, allow_dangerous_deserialization=True
    )
except Exception:
    # BUG FIX: use the cache-backed `embedder` (the original created it but
    # never used it), so building the index also populates the local cache.
    vector_store = FAISS.from_documents(documents, embedder)
    # BUG FIX: persist to `faiss_index_path` instead of a duplicated literal.
    vector_store.save_local(faiss_index_path)
#### APP ###
@cl.on_chat_start
async def on_chat_start():
    """Assemble the retrieval-augmented chain for this chat session.

    Builds a prompt -> retriever -> LLM -> string-parser pipeline and stores
    it in the Chainlit user session under "runnable_chain" so the message
    handler can stream answers from it.
    """
    # Prompt contract: retrieved documents fill {documents}, the user's
    # message fills {question}.
    movie_prompt = ChatPromptTemplate.from_template(
        """
You are an AI assistant that loves to suggest movies for people to watch. You have access to the following context:
{documents}
Answer the following question. If you cannot answer at all just say "Sorry! I couldn't come up with any movies for this question!"
Question: {question}
"""
    )

    # Components of the chain, named individually for readability.
    doc_retriever = vector_store.as_retriever()
    llm = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)
    to_text = StrOutputParser()
    inputs = {"documents": doc_retriever, "question": RunnablePassthrough()}

    # Compose: fetch docs + pass question -> prompt -> LLM -> plain string.
    chain = inputs | movie_prompt | llm | to_text
    cl.user_session.set("runnable_chain", chain)
# @cl.on_message
# async def on_message(message: cl.Message):
# runnable_chain = cl.user_session.get("runnable_chain")
#
# msg = cl.Message(content="")
#
# async for chunk in runnable_chain.astream(
# message.content,
# config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
# ):
# await msg.stream_token(chunk)
#
# await msg.send()
@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming user message by streaming the RAG chain's output.

    Pulls the chain built in `on_chat_start` from the user session, streams
    each generated token into a Chainlit message so the UI updates live, and
    sends the completed message when the stream ends.

    CLEANUP: removed the large commented-out `PostMessageHandler` callback
    (and its dead entry in the callbacks list) — dead code that was never
    executed; restore it from version control if source attribution is
    needed later.
    """
    runnable = cl.user_session.get("runnable_chain")
    msg = cl.Message(content="")

    # Wrap the generation in a Chainlit step so the UI shows a named run.
    async with cl.Step(type="run", name="Movie Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)

    await msg.send()
|