gfhayworth committed on
Commit
393331d
·
1 Parent(s): 37d88e4

Update greg_funcs.py

Browse files
Files changed (1) hide show
  1. greg_funcs.py +103 -9
greg_funcs.py CHANGED
@@ -1,15 +1,21 @@
1
-
2
-
3
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
4
  from torch import tensor as torch_tensor
5
  from datasets import load_dataset
6
 
7
  from langchain.llms import OpenAI
8
  from langchain.docstore.document import Document
 
9
  from langchain.chains.question_answering import load_qa_chain
10
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
11
- from langchain.prompts import PromptTemplate
 
 
 
 
 
12
 
 
 
13
 
14
  """# import models"""
15
 
@@ -62,9 +68,11 @@ def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
62
  prediction_text.append(result)
63
  return prediction_text
64
 
 
 
65
  template = """You are a friendly AI assistant for the insurance company Humana. Given the following extracted parts of a long document and a question, create a succinct final answer.
66
  If you don't know the answer, just say that you don't know. Don't try to make up an answer.
67
- If the question is not about Humana, politely inform them that you are tuned to only answer questions about Humana.
68
  QUESTION: {question}
69
  =========
70
  {context}
@@ -74,18 +82,104 @@ PROMPT = PromptTemplate(template=template, input_variables=["context", "question
74
 
75
  chain_qa = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)
76
 
 
 
 
 
 
 
 
 
 
77
 
78
  def get_llm_response(message):
79
- mydocs = get_text_fmt(message)
80
- response = chain_qa.run(input_documents=mydocs, question=message)
81
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def chat(message, history):
84
  history = history or []
85
  message = message.lower()
86
 
87
- response = get_llm_response(message)
88
- history.append((message, response))
89
  return history, history
90
 
91
 
 
 
 
1
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
2
  from torch import tensor as torch_tensor
3
  from datasets import load_dataset
4
 
5
  from langchain.llms import OpenAI
6
  from langchain.docstore.document import Document
7
+ from langchain.prompts import PromptTemplate
8
  from langchain.chains.question_answering import load_qa_chain
9
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
10
+ from langchain import LLMMathChain, SQLDatabase, SQLDatabaseChain, LLMChain
11
+ from langchain.agents import initialize_agent, Tool
12
+
13
+ import sqlite3
14
+ #import pandas as pd
15
+ import json
16
 
17
# database
# Local sqlite connection to the member DB.
# NOTE(review): `cxn` looks unused in the visible code — the SQL querying
# further down goes through SQLDatabase.from_uri on the same file; confirm
# before removing.
cxn = sqlite3.connect('./data/mbr.db')
19
 
20
  """# import models"""
21
 
 
68
  prediction_text.append(result)
69
  return prediction_text
70
 
71
+ """# LLM based qa functions"""
72
+
73
  template = """You are a friendly AI assistant for the insurance company Humana. Given the following extracted parts of a long document and a question, create a succinct final answer.
74
  If you don't know the answer, just say that you don't know. Don't try to make up an answer.
75
+ If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana.
76
  QUESTION: {question}
77
  =========
78
  {context}
 
82
 
83
  chain_qa = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)
84
 
85
def get_text_fmt(qry, passages=mypassages, doc_embedding=mycorpus_embeddings):
    """Retrieve the top-5 matching passages for *qry* as langchain Documents.

    Each semantic-search hit's ``corpus_id`` is used both to look up the
    passage text and as the Document's ``source`` metadata, so answers can be
    traced back to the passage they came from.
    """
    hits = search(qry, passages=passages, doc_embedding=doc_embedding, top_n=5)
    return [
        Document(
            page_content=passages[hit['corpus_id']],
            metadata={"source": hit['corpus_id']},
        )
        for hit in hits
    ]
94
 
95
def get_llm_response(message):
    """Answer *message* with the document-QA chain.

    Pulls the best-matching passages via get_text_fmt and stuffs them into
    chain_qa as context; returns the chain's text answer.
    """
    context_docs = get_text_fmt(message)
    return chain_qa.run(input_documents=context_docs, question=message)
99
+
100
+ # for x in xmpl_list:
101
+ # print(32*'=')
102
+ # print(x)
103
+ # print(32*'=')
104
+ # r = get_llm_response(x)
105
+ # print(r)
106
+
"""# Database query"""

# Member database (same sqlite file opened as `cxn` above) wrapped for
# langchain's SQL querying.
db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")

# Shared LLM instance; temperature=0 for deterministic completions.
llm = OpenAI(temperature=0)
# default model
# model_name: str = "text-davinci-003"
# instruction fine-tuned, sometimes referred to as GPT-3.5

# Chain that turns a natural-language question into SQL against `db` and
# summarizes the result; verbose for demo visibility.
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
117
+
118
def db_qry(qry, mbr_id=456):
    """Run a natural-language query against the member database.

    The member id is prepended to the question so the SQL chain scopes its
    answer to that member.

    Parameters
    ----------
    qry : str
        Natural-language question about the member's records.
    mbr_id : int, optional
        Member id to scope the query to. Defaults to 456, the previously
        hard-coded demo member, so existing callers are unchanged.

    Returns
    -------
    str
        The SQL chain's answer text.
    """
    return db_chain.run(query=f'my mbr_id is {mbr_id} ;{qry}')
121
+
122
+ #db_qry('how many footcare visits have I had?')
123
+
"""## Math
- default version
"""

# Math chain backing the agent's "Calculator" tool below.
llm_math_chain = LLMMathChain(llm=llm, verbose=True)

#llm_math_chain.run('what is the square root of 49?')
131
+
"""# Greeting"""

# Prompt for greetings / thanks / small talk; deflects off-topic questions
# back to Humana benefits.
# Fix: birth date typo "20203" -> "2023" in the persona text.
template = """You are a friendly AI assistant for the insurance company Humana.
Your name is Bruce and you were created on February 13, 2023.
Offer polite greetings and brief small talk.
Respond to thanks with, 'Glad to help.'
If the question is not about Humana, politely guide the user to ask questions about Humana insurance benefits.
QUESTION: {question}
=========
FINAL ANSWER:"""
greet_prompt = PromptTemplate(template=template, input_variables=["question"])

# Chain backing the agent's "Greeting" tool.
greet_llm = LLMChain(prompt=greet_prompt, llm=llm, verbose=True)
145
+
"""# MRKL Chain"""

# Tools exposed to the zero-shot agent. Each description is read by the LLM
# to decide which tool a question should be routed to, so wording matters.
# Fix: "details such their" -> "details such as their" in the Member DB
# description (grammar error in text the agent reasons over).
tools = [
    Tool(
        name="Benefit",
        func=get_llm_response,
        description="useful for when you need to answer questions about plan benefits, premiums and payments. You should ask targeted questions"
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math"
    ),
    Tool(
        name="Member DB",
        func=db_qry,
        description="useful for when you need to answer questions about member details such as their accumulated use of services. Input should be in the form of a question containing full context"
    ),
    Tool(
        name="Greeting",
        func=greet_llm.run,
        description="useful for when you need to respond to greetings, thanks and make small talk"
    ),
]

# Zero-shot ReAct agent over the tools; capped at 5 reasoning iterations and
# asked to generate a best-effort final answer if the cap is hit.
mrkl = initialize_agent(tools, llm, agent="zero-shot-react-description",
                        verbose=True, return_intermediate_steps=True,
                        max_iterations=5, early_stopping_method="generate")
172
+
173
def mrkl_rspnd(qry):
    """Route *qry* through the MRKL agent.

    Returns the agent's full response dict (the final answer is under
    'output'; intermediate steps are included because the agent was built
    with return_intermediate_steps=True).
    """
    return mrkl({"input": str(qry)})
176
 
177
def chat(message, history):
    """Chat UI handler: answer *message* via the MRKL agent.

    Lower-cases the incoming message, appends the (question, answer) pair to
    *history*, and returns the history twice (display state and persisted
    state, per the gradio chatbot convention).
    """
    history = [] if not history else history
    message = message.lower()

    answer = mrkl_rspnd(message)['output']
    history.append((message, answer))
    return history, history
184
 
185