File size: 2,485 Bytes
57191f4
0e78abf
0e0157a
57191f4
 
0e0157a
57191f4
 
 
 
0e0157a
 
29239cb
9243247
 
 
0e0157a
 
57191f4
 
 
 
 
 
0e0157a
 
 
 
 
 
 
 
 
 
 
 
b6d5233
0e0157a
 
 
 
 
 
 
 
 
 
 
 
57191f4
0e0157a
57191f4
 
 
a85bd19
 
0e0157a
 
 
 
 
 
 
 
 
 
 
 
47f2c23
 
 
a85bd19
0e0157a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import asyncio
import logging
import os

import chainlit as cl
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

# API key is read from the environment; os.getenv returns None if unset,
# so a missing OPENAI_API_KEY surfaces later as an auth error, not here.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Embedding model used to query the pre-built FAISS index loaded in
# on_chat_start; must match the model the index was originally built with.
underlying_embeddings = OpenAIEmbeddings(api_key=openai_api_key)

@cl.on_chat_start
async def on_chat_start():
    """Build the RAG chain when a chat session starts.

    Loads the pre-computed FAISS index from disk, wires it into a
    retriever -> prompt -> chat model -> parser LCEL chain, and stores
    the chain in the Chainlit user session under the key "runnable"
    for on_message to consume.
    """
    print("Embeddings already done, use the saved index")
    # allow_dangerous_deserialization is required for locally-pickled FAISS
    # indexes; safe here because "faiss_index" is produced by this project,
    # not untrusted input.
    vector_store = FAISS.load_local(
        "faiss_index", underlying_embeddings, allow_dangerous_deserialization=True
    )

    # Prompt that merges the retrieved documents ({context}) with the
    # user's query ({question}).
    prompt_template = ChatPromptTemplate.from_template(
        "Answer the {question} based on the following {context}."
    )

    # Retriever over the loaded vector store.
    retriever = vector_store.as_retriever()

    # Deterministic (temperature=0) chat model for answering.
    chat_model = ChatOpenAI(
        model="gpt-3.5-turbo", temperature=0, api_key=openai_api_key
    )

    # Parser that turns the model's message into a plain string.
    parser = StrOutputParser()

    # LCEL chain: the dict step fans the incoming question out to the
    # retriever (as "context") and passes it through unchanged (as
    # "question") before formatting the prompt.
    runnable_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | chat_model
        | parser  # reuse the parser built above instead of a second instance
    )

    # BUG FIX: original stored undefined name `runnable`, raising NameError
    # at session start; the chain is bound as `runnable_chain`.
    cl.user_session.set("runnable", runnable_chain)


@cl.on_message
async def on_message(message: cl.Message):
    """Handle an incoming chat message.

    Fetches the RAG chain stored in the user session by on_chat_start
    and streams its answer back token-by-token inside a Chainlit step.
    """
    # BUG FIX: original referenced an undefined `logger`; use the module
    # logger from the stdlib logging package instead.
    logging.getLogger(__name__).info("Starting application")

    # Chain built in on_chat_start; presumably always present once a chat
    # has started — no fallback if the session key is missing.
    runnable = cl.user_session.get("runnable")  # type: Runnable

    # Empty message we progressively fill via stream_token.
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):

        await msg.stream_token("OAI says: ")

        # BUG FIX: original iterated `runn.astream(...)` — `runn` is
        # undefined; the variable is `runnable`.
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)

    await msg.send()