# app.py — Chainlit RAG demo: FAISS retriever + OpenAI chat model (LCEL chain).
import asyncio
import logging
import os

import chainlit as cl
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# API key read from the environment; a missing key is not checked here and
# would surface as an auth error on the first OpenAI call.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Embedding model used to query the saved FAISS index — presumably the same
# model the index was built with; verify against the index-building script.
underlying_embeddings = OpenAIEmbeddings(api_key=openai_api_key)
@cl.on_chat_start
async def on_chat_start():
    """Build the RAG chain once per chat session.

    Loads the pre-built FAISS index from disk, wires
    retriever -> prompt -> chat model -> string parser into an LCEL chain,
    and stores the chain in the user session under the key "runnable"
    so that ``on_message`` can stream answers from it.
    """
    print("Embeddings already done, use the saved index")
    # allow_dangerous_deserialization is required because FAISS.load_local
    # unpickles data; acceptable only because "faiss_index" is our own file.
    vector_store = FAISS.load_local(
        "faiss_index", underlying_embeddings, allow_dangerous_deserialization=True
    )
    # Retriever that supplies the documents for the {context} slot.
    retriever = vector_store.as_retriever()

    # Prompt that combines the retrieved documents with the user's question.
    prompt_template = ChatPromptTemplate.from_template(
        "Answer the {question} based on the following {context}."
    )

    chat_model = ChatOpenAI(
        model="gpt-3.5-turbo", temperature=0, api_key=openai_api_key
    )

    # The dict step runs in parallel: the retriever fills {context} from the
    # query, while RunnablePassthrough forwards the raw query as {question}.
    runnable_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | chat_model
        | StrOutputParser()
    )
    # BUG FIX: the original stored an undefined name `runnable` (NameError);
    # the chain is bound to `runnable_chain`.
    cl.user_session.set("runnable", runnable_chain)
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the RAG chain's answer for an incoming chat message.

    Retrieves the chain saved by ``on_chat_start`` from the user session,
    streams its tokens into a Chainlit message inside a "QA Assistant"
    step, and sends the completed message.
    """
    # BUG FIX: the original called `logger.info(...)` on an undefined name;
    # use a module-qualified stdlib logger instead.
    logging.getLogger(__name__).info('Starting application')
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")
    async with cl.Step(type="run", name="QA Assistant"):
        await msg.stream_token("OAI says: ")
        # BUG FIX: the original streamed from an undefined name `runn`;
        # stream from the chain fetched above.
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)
    await msg.send()