# NOTE(review): the lines below were a pasted Hugging Face Spaces status banner,
# not code — kept as comments so the module parses.
# Spaces:
# Runtime error
# Runtime error
import chainlit as cl
from datasets import load_dataset
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import CacheBackedEmbeddings
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.base import RunnableSequence
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
def setup_data():
    """Build and persist a FAISS vector store over the IMDB movies dataset.

    Downloads the dataset, dumps its train split to ``imdb.csv``, splits the
    rows into overlapping chunks, embeds them with disk-cached OpenAI
    embeddings, and saves the resulting index to ``./faiss_index``.

    Returns:
        FAISS: the populated vector store.
    """
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    # Export the train split to CSV so CSVLoader can ingest it row-by-row.
    dataset["train"].to_csv("imdb.csv")

    loader = CSVLoader(file_path="imdb.csv")
    data = loader.load()

    # Overlap keeps a little shared context across neighbouring chunks.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    )
    chunked_documents = text_splitter.split_documents(data)

    embedding_model = OpenAIEmbeddings()
    # Cache embeddings on local disk so repeated runs don't re-call the API.
    store = LocalFileStore("./cache/")
    embedder = CacheBackedEmbeddings.from_bytes_store(
        embedding_model, store, namespace=embedding_model.model
    )

    vector_store = FAISS.from_documents(chunked_documents, embedder)
    vector_store.save_local("faiss_index")
    return vector_store
# Build the retrieval index once at import time; shared across chat sessions.
doc_search = setup_data()
# streaming=True so generated tokens can be forwarded to the UI as they arrive.
model = ChatOpenAI(model_name="gpt-4o", temperature=0, streaming=True)
@cl.on_chat_start  # required: without this decorator Chainlit never calls the handler
async def on_chat_start():
    """Assemble the RAG chain for this session and stash it in the user session."""
    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        # Concatenate retrieved chunks into one context string for the prompt.
        return "\n\n".join([d.page_content for d in docs])

    retriever = doc_search.as_retriever()

    # LCEL pipeline: retrieve -> format context -> prompt -> LLM -> plain string.
    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
@cl.on_message  # required: without this decorator Chainlit never routes messages here
async def on_message(message: cl.Message):
    """Stream an answer to the user's question and attach retrieved sources."""
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            BaseCallbackHandler.__init__(self)
            self.msg = msg
            self.sources = set()  # unique (source, page/row) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            for d in documents:
                # CSVLoader emits 'row' metadata, not 'page' — fall back instead
                # of raising KeyError and aborting the whole stream.
                source = d.metadata.get("source", "unknown")
                page = d.metadata.get("page", d.metadata.get("row", 0))
                self.sources.add((source, page))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if self.sources:
                sources_text = "\n".join(
                    f"{source}#page={page}" for source, page in self.sources
                )
                self.msg.elements.append(
                    cl.Text(name="Sources", content=sources_text, display="inline")
                )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg),
            ]),
        ):
            await msg.stream_token(chunk)
    await msg.send()