import chainlit as cl
from dotenv import dotenv_values
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Load the OpenAI API key from key.env (expected next to this script)
my_secrets = dotenv_values("key.env")

# Load data.csv as plain text (TextLoader treats the whole file as one Document)
loader = TextLoader('data.csv')
documents = loader.load()

# Split into ~1,000-character chunks with 100 characters of overlap
text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        length_function=len,
        is_separator_regex=False,
)

docs = text_splitter.split_documents(documents)
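
# Optional sanity check (illustrative; safe to delete): confirm the split
# produced the expected chunk sizes before paying for embeddings.
# for d in docs[:3]:
#     print(len(d.page_content), repr(d.page_content[:60]))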

# Create embeddings and build the in-memory FAISS index
underlying_embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=my_secrets["OPEN_API_KEY"])
db = FAISS.from_documents(docs, underlying_embeddings)
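
# Optional (sketch): wrap the embedder in a local byte-store cache so repeated
# runs don't re-embed unchanged chunks. The cache path is illustrative.
#
# from langchain.embeddings import CacheBackedEmbeddings
# from langchain.storage import LocalFileStore
#
# store = LocalFileStore("./embedding_cache/")
# cached_embedder = CacheBackedEmbeddings.from_bytes_store(
#     underlying_embeddings, store, namespace=underlying_embeddings.model
# )
# db = FAISS.from_documents(docs, cached_embedder)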

# Expose the index as a retriever returning the top 10 chunks per query
retriever = db.as_retriever(
        search_kwargs={"k": 10}
)
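
# Alternative retriever configuration (sketch): MMR re-ranks results for
# diversity, which can help when the CSV contains many near-duplicate rows.
# The parameter values here are illustrative, not tuned.
#
# retriever = db.as_retriever(
#     search_type="mmr",
#     search_kwargs={"k": 10, "fetch_k": 30},
# )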


@cl.on_chat_start
def start():

  # Build the prompt template; it restricts answers to the retrieved (RAG) context only
  template = """
      You're a helpful AI assistent tasked to answer the user's questions about movies.
      You can only make conversations based on the provided context about movies. If a response cannot be formed strictly using the context, politely say you don’t have knowledge about that topic under new line character 'ANSWER:' tag which is prefixed with new line character.

      Remember, you must return both an answer under 'ANSWER:' tag which is prefixed with new line character and citations in line separated format of answer and bulleted list of citiations under 'CITATIONS:' tag. A citation consists of a VERBATIM quote that \
      justifies the answer and the ID of the quoted article.  Return a citation for every quote across all articles \
      that justify the answer. Add a new line character after all citations. Use the following format for your final output:

      new line character
      ANSWER:

      CITATIONS:
      new line character

      CONTEXT:
      {context}

      QUESTION: {question}

      YOUR ANSWER:
  """

  prompt = ChatPromptTemplate.from_messages([("system", template)])

  llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, api_key=my_secrets["OPEN_API_KEY"])

  # Define the chain
  inputs = RunnableMap({
      'context': lambda x: retriever.get_relevant_documents(x['question']),
      'question': lambda x: x['question']
  })
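
  # Note (sketch): the retrieved Documents are passed to the prompt as a raw
  # Python list, which the template simply stringifies, and the chunks carry
  # no explicit IDs even though the prompt asks for article IDs in CITATIONS.
  # A common refinement is to format the context yourself, e.g.:
  #
  # def format_docs(found):
  #     return "\n\n".join(f"[{i}] {d.page_content}" for i, d in enumerate(found))
  #
  # inputs = RunnableMap({
  #     'context': lambda x: format_docs(retriever.get_relevant_documents(x['question'])),
  #     'question': lambda x: x['question'],
  # })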

  # Compose the chain: retrieve context -> fill prompt -> call the LLM -> parse to string
  runnable_chain = (
    inputs |
    prompt |
    llm |
    StrOutputParser()
  )
  cl.user_session.set("runnable_chain", runnable_chain)
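
  # Quick offline check (sketch): the chain can also be exercised outside
  # Chainlit, e.g. from a REPL, before wiring up the UI. The question below
  # is illustrative.
  # print(runnable_chain.invoke({"question": "a question about your movie data"}))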


@cl.on_message
async def on_message(message: cl.Message):
    runnable_chain = cl.user_session.get("runnable_chain")
    msg = message.content

    # Use the async entry point so the Chainlit event loop isn't blocked
    # while the LLM call is in flight.
    result = await runnable_chain.ainvoke({"question": msg})

    await cl.Message(content=result).send()
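
# Streaming variant of the handler (sketch; assumes your LangChain version
# supports Runnable.astream and your Chainlit version supports stream_token):
#
# @cl.on_message
# async def on_message(message: cl.Message):
#     runnable_chain = cl.user_session.get("runnable_chain")
#     out = cl.Message(content="")
#     await out.send()
#     async for token in runnable_chain.astream({"question": message.content}):
#         await out.stream_token(token)
#     await out.update()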