Spaces:
Runtime error
Runtime error
Upload 25 files
Browse files- __init__.py +0 -0
- app.py +43 -0
- experiments/__init__.py +0 -0
- experiments/__pycache__/qa_agent.cpython-312.pyc +0 -0
- llm/__init__.py +0 -0
- llm/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/__pycache__/calculator_agent.cpython-312.pyc +0 -0
- llm/__pycache__/gemini_client.cpython-312.pyc +0 -0
- llm/__pycache__/orchestrator.cpython-312.pyc +0 -0
- llm/__pycache__/qa_agent.cpython-312.pyc +0 -0
- llm/__pycache__/qa_tool.cpython-312.pyc +0 -0
- llm/__pycache__/wiki_agent.cpython-312.pyc +0 -0
- llm/__pycache__/wiki_tool.cpython-312.pyc +0 -0
- llm/calculator_agent.py +18 -0
- llm/gemini_client.py +51 -0
- llm/orchestrator.py +35 -0
- llm/qa_agent.py +29 -0
- llm/utils/__init__.py +0 -0
- llm/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/utils/__pycache__/wiki_client.cpython-312.pyc +0 -0
- llm/utils/wiki_client.py +20 -0
- llm/wiki_agent.py +46 -0
- requirements.txt +95 -0
- triviaQA.py +62 -0
- utils.py +0 -0
__init__.py
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from llm.qa_agent import QnAAgent
|
| 5 |
+
from llm.calculator_agent import CalculatorAgent
|
| 6 |
+
from llm.orchestrator import Orchestrator
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if __name__ == "__main__":
    orchestrator = Orchestrator()
    qna_agent = QnAAgent()
    calculator_agent = CalculatorAgent()

    def get_answer(question: str) -> tuple[str, str]:
        """Route *question* to the agent chosen by the orchestrator.

        Returns:
            (answer, wikipedia_page) — wikipedia_page is "" unless the
            QnA agent handled the question.
        """
        api_name, parameters = orchestrator.get_API_call(question)

        print(f"Using the {api_name} Agent")
        print(api_name, parameters)

        # Fix: the original left `answer`/`wiki_page` unbound (NameError)
        # whenever api_name was not "QnA"; also re-enables the calculator path.
        answer, wiki_page = "", ""
        if api_name == "QnA":
            answer, wiki_page = qna_agent.get_answer(parameters)
        elif api_name == "calculator":
            operand, op1, op2 = parameters.split(",")
            answer = str(calculator_agent.calculate(operand, op1, op2))

        print(answer)
        return answer, wiki_page

    demo = gr.Interface(
        fn=get_answer,
        inputs=gr.Textbox(placeholder="Enter your question...[Who won the Cricket World Cup in 2023?]"),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Wikipedia Page")],
        title="Real time Question Answering",
    )

    demo.launch()
experiments/__init__.py
ADDED
|
File without changes
|
experiments/__pycache__/qa_agent.cpython-312.pyc
ADDED
|
Binary file (1.82 kB). View file
|
|
|
llm/__init__.py
ADDED
|
File without changes
|
llm/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (140 Bytes). View file
|
|
|
llm/__pycache__/calculator_agent.cpython-312.pyc
ADDED
|
Binary file (962 Bytes). View file
|
|
|
llm/__pycache__/gemini_client.cpython-312.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
llm/__pycache__/orchestrator.cpython-312.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
llm/__pycache__/qa_agent.cpython-312.pyc
ADDED
|
Binary file (1.9 kB). View file
|
|
|
llm/__pycache__/qa_tool.cpython-312.pyc
ADDED
|
Binary file (1.55 kB). View file
|
|
|
llm/__pycache__/wiki_agent.cpython-312.pyc
ADDED
|
Binary file (2.28 kB). View file
|
|
|
llm/__pycache__/wiki_tool.cpython-312.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
llm/calculator_agent.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class CalculatorAgent:
|
| 2 |
+
def calculate(self, operation: str, x: str, y: str) -> float:
|
| 3 |
+
operation = operation.lower().strip()
|
| 4 |
+
|
| 5 |
+
x = float(x)
|
| 6 |
+
y = float(y)
|
| 7 |
+
if operation == "add":
|
| 8 |
+
return x + y
|
| 9 |
+
elif operation == "subtract":
|
| 10 |
+
return x - y
|
| 11 |
+
elif operation == "multiply":
|
| 12 |
+
return x * y
|
| 13 |
+
elif operation == "divide":
|
| 14 |
+
if y == 0:
|
| 15 |
+
return "Cannot divide by zero"
|
| 16 |
+
return x / y
|
| 17 |
+
else:
|
| 18 |
+
return "Unknown operation"
|
llm/gemini_client.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import google.generativeai as genai
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GeminiClient:
    """Minimal wrapper around the google.generativeai gemini-pro model."""

    def __init__(self, system_message=None):
        # Prepended to every prompt; may be None (treated as empty).
        self._system_message = system_message
        self._connect_client()

    def _connect_client(self):
        """Configure the genai SDK and build the gemini-pro model handle.

        Raises:
            Exception: if the GOOGLE_PALM_KEY environment variable is unset.
        """
        api_key = os.getenv("GOOGLE_PALM_KEY")
        if not api_key:
            raise Exception("Please set your Google MakerSuite API key")

        genai.configure(api_key=api_key)

        # Relax default blocking to BLOCK_ONLY_HIGH so benign trivia
        # answers are not filtered out.
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
        ]

        defaults = {
            "temperature": 0.7,
            "top_k": 40,
            "top_p": 0.95,
            "max_output_tokens": 1024,
        }

        self._model = genai.GenerativeModel(
            model_name="gemini-pro",
            generation_config=defaults,
            safety_settings=safety_settings,
        )

    def generate_text(self, prompt: str) -> str:
        """Generate a completion for *prompt*; returns "" on any API error.

        Fix: the original crashed with TypeError when system_message was
        None, which is the documented default of __init__.
        """
        full_prompt = (self._system_message or "") + prompt
        try:
            response = self._model.generate_content(full_prompt).text
        except Exception as e:
            # Best-effort: log and return empty rather than propagate.
            print(f"Error: {e}")
            response = ""
        return response
|
llm/orchestrator.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llm.gemini_client import GeminiClient
|
| 2 |
+
|
| 3 |
+
SYSTEM_MESSAGE = """You are an orchestrator that knows what various tools
or agents can do and which is the right one to pick. Given a question, your
job is just to pick the right agent to use and the rest will be taken care of.

For now you can use a calculator agent that can help you do basic arithmetic
calculations. You can also use a question answering agent that can answer
questions about various topics.

The API's are:
calculator[operand 1, operand 2, operation]
QnA[question]

Here are some examples:

Example 1:
Question: What is 2 + 2?
Response: calculator$add, 2, 2

Example 2: Who designed the Eiffel Tower?
Response: QnA$Who designed the Eiffel Tower?

### Question:
"""


class Orchestrator:
    """Routes a user query to the right agent via a Gemini routing prompt."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)

    def get_API_call(self, query: str) -> tuple[str, str]:
        """Return (api_name, parameters) parsed from the model's reply.

        The model is prompted to reply in the form ``<api>$<parameters>``.
        Fix: ``partition`` replaces ``split``, which raised ValueError when
        the model omitted '$'; both pieces are stripped because model output
        often carries stray whitespace/newlines. (Also fixed prompt typos:
        "can knows" and "Respnse", which misled the expected reply format.)
        """
        api_call = self._client.generate_text(query)
        api_name, _, parameters = api_call.partition("$")
        return api_name.strip(), parameters.strip()
|
llm/qa_agent.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llm.gemini_client import GeminiClient
|
| 2 |
+
from llm.wiki_agent import WikiSearchAgent
|
| 3 |
+
|
| 4 |
+
SYSTEM_MESSAGE = """You are a Question Answering tool that can answer various
|
| 5 |
+
trivia questions. However, you might be asked questions that is beyond your
|
| 6 |
+
knowledge or recent events that you might not be trained on
|
| 7 |
+
(beyond training cutoff). So, if there is Wikipedia page entry provided,
|
| 8 |
+
use that to answer the question. Just return the answer, don't make a
|
| 9 |
+
verbose response."""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class QnAAgent:
    """Answers trivia questions with Gemini, optionally grounded on a
    Wikipedia summary fetched by WikiSearchAgent."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_tool = WikiSearchAgent()

    @staticmethod
    def _format_prompt(query: str, wiki_page: str) -> str:
        """Combine the question and the Wikipedia summary into one prompt."""
        return f"\n###Question:{query} \n###Wikipedia Page:{wiki_page}"

    def get_answer(self, query: str, use_context: bool = True) -> tuple[str, str]:
        """Answer *query*, optionally with Wikipedia context.

        Fix: return annotation was ``[str, str]`` (a list literal, not a
        valid type) — corrected to ``tuple[str, str]``.

        Returns:
            (answer, wiki_page) — wiki_page is "" when use_context is False.
        """
        if use_context:
            wiki_page = self._wiki_tool.get_wikipedia_entry(query)
            prompt = self._format_prompt(query, wiki_page)
        else:
            wiki_page = ""
            prompt = query
        return self._client.generate_text(prompt), wiki_page
|
llm/utils/__init__.py
ADDED
|
File without changes
|
llm/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
llm/utils/__pycache__/wiki_client.cpython-312.pyc
ADDED
|
Binary file (1.17 kB). View file
|
|
|
llm/utils/wiki_client.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import wikipediaapi
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class WikiClient:
    """Thin wrapper around the wikipediaapi client (English Wikipedia,
    plain-text extracts)."""

    def __init__(self):
        self.wiki = wikipediaapi.Wikipedia(
            user_agent="WikiAgent/0.0 [email protected]",
            language="en",
            extract_format=wikipediaapi.ExtractFormat.WIKI,
        )

    def get_pages(self, query):
        """Return the wikipediaapi page object for *query*."""
        return self.wiki.page(query)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
    # Smoke test: fetch one page and show its summary.
    wiki_client = WikiClient()
    print(wiki_client.get_pages("Cricket World Cup").summary)
|
llm/wiki_agent.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llm.gemini_client import GeminiClient
|
| 2 |
+
|
| 3 |
+
from llm.utils.wiki_client import WikiClient
|
| 4 |
+
|
| 5 |
+
SYSTEM_MESSAGE = """you have access to a wikipedia summarizer that can return a summary for a topic. \
|
| 6 |
+
Your job is to act as a question answering tool. Whenever you are asked about a question related to knowledge, \
|
| 7 |
+
instead of using your internal knowledge (which can be faulty or out of date), \
|
| 8 |
+
format a Wikipedia search query string that can help answer the question. \
|
| 9 |
+
|
| 10 |
+
Remember Wikipedia Entries are usually about a simple entity or event, so keep the \
|
| 11 |
+
query short, and about the entity being asked about. Also, don't use your knowledge \
|
| 12 |
+
to ask about the answer. Instead form queries about the entity in the question. This \
|
| 13 |
+
will help you get the right wikipedia entries for questions when you dont know the answer
|
| 14 |
+
|
| 15 |
+
### Example 1:
|
| 16 |
+
|
| 17 |
+
Question: Who won the ICC Cricket World Cup?
|
| 18 |
+
Correct Response: Cricket World Cup
|
| 19 |
+
Incorrect response: Australia
|
| 20 |
+
|
| 21 |
+
### Example 2:
|
| 22 |
+
|
| 23 |
+
Question: Who directed the classic 30s western Stagecoach?
|
| 24 |
+
Response: Stagecoach
|
| 25 |
+
Incorrect response: John Ford
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Below is the question. Return the wikipedia search query you would use \n
|
| 29 |
+
|
| 30 |
+
### Question:
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class WikiSearchAgent:
    """Turns a question into a Wikipedia search query (via Gemini) and
    returns the matching page's summary."""

    def __init__(self):
        self._llm_client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_client = WikiClient()

    def get_wikipedia_entry(self, prompt: str) -> str:
        """Return the Wikipedia summary for the entity asked about in
        *prompt*, or "" when the page lookup fails.
        """
        wiki_search_query = self._llm_client.generate_text(prompt)
        wikipedia_page = self._wiki_client.get_pages(wiki_search_query)
        try:
            return wikipedia_page.summary
        except Exception:
            # Fix: narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            return ""
|
requirements.txt
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==23.2.1
|
| 2 |
+
aiohttp==3.9.1
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
altair==5.2.0
|
| 5 |
+
annotated-types==0.6.0
|
| 6 |
+
anyio==4.2.0
|
| 7 |
+
attrs==23.1.0
|
| 8 |
+
beautifulsoup4==4.12.2
|
| 9 |
+
cachetools==5.3.2
|
| 10 |
+
certifi==2023.11.17
|
| 11 |
+
charset-normalizer==3.3.2
|
| 12 |
+
click==8.1.7
|
| 13 |
+
colorama==0.4.6
|
| 14 |
+
contourpy==1.2.0
|
| 15 |
+
cycler==0.12.1
|
| 16 |
+
datasets==2.15.0
|
| 17 |
+
dill==0.3.7
|
| 18 |
+
fastapi==0.109.0
|
| 19 |
+
ffmpy==0.3.1
|
| 20 |
+
filelock==3.13.1
|
| 21 |
+
fonttools==4.47.2
|
| 22 |
+
frozenlist==1.4.1
|
| 23 |
+
fsspec==2023.10.0
|
| 24 |
+
google-ai-generativelanguage==0.4.0
|
| 25 |
+
google-api-core==2.15.0
|
| 26 |
+
google-auth==2.25.2
|
| 27 |
+
google-generativeai==0.3.2
|
| 28 |
+
googleapis-common-protos==1.62.0
|
| 29 |
+
gradio==4.14.0
|
| 30 |
+
gradio_client==0.8.0
|
| 31 |
+
grpcio==1.60.0
|
| 32 |
+
grpcio-status==1.60.0
|
| 33 |
+
h11==0.14.0
|
| 34 |
+
httpcore==1.0.2
|
| 35 |
+
httpx==0.26.0
|
| 36 |
+
huggingface-hub==0.20.1
|
| 37 |
+
idna==3.6
|
| 38 |
+
importlib-resources==6.1.1
|
| 39 |
+
install==1.3.5
|
| 40 |
+
Jinja2==3.1.3
|
| 41 |
+
jsonschema==4.21.0
|
| 42 |
+
jsonschema-specifications==2023.12.1
|
| 43 |
+
kiwisolver==1.4.5
|
| 44 |
+
markdown-it-py==3.0.0
|
| 45 |
+
MarkupSafe==2.1.3
|
| 46 |
+
matplotlib==3.8.2
|
| 47 |
+
mdurl==0.1.2
|
| 48 |
+
multidict==6.0.4
|
| 49 |
+
multiprocess==0.70.15
|
| 50 |
+
numpy==1.26.2
|
| 51 |
+
orjson==3.9.10
|
| 52 |
+
packaging==23.2
|
| 53 |
+
pandas==2.1.4
|
| 54 |
+
pillow==10.2.0
|
| 55 |
+
proto-plus==1.23.0
|
| 56 |
+
protobuf==4.25.1
|
| 57 |
+
pyarrow==14.0.2
|
| 58 |
+
pyarrow-hotfix==0.6
|
| 59 |
+
pyasn1==0.5.1
|
| 60 |
+
pyasn1-modules==0.3.0
|
| 61 |
+
pydantic==2.5.3
|
| 62 |
+
pydantic_core==2.14.6
|
| 63 |
+
pydub==0.25.1
|
| 64 |
+
Pygments==2.17.2
|
| 65 |
+
pyparsing==3.1.1
|
| 66 |
+
python-dateutil==2.8.2
|
| 67 |
+
python-multipart==0.0.6
|
| 68 |
+
pytz==2023.3.post1
|
| 69 |
+
PyYAML==6.0.1
|
| 70 |
+
referencing==0.32.1
|
| 71 |
+
requests==2.31.0
|
| 72 |
+
rich==13.7.0
|
| 73 |
+
rpds-py==0.17.1
|
| 74 |
+
rsa==4.9
|
| 75 |
+
semantic-version==2.10.0
|
| 76 |
+
setuptools==68.2.2
|
| 77 |
+
shellingham==1.5.4
|
| 78 |
+
six==1.16.0
|
| 79 |
+
sniffio==1.3.0
|
| 80 |
+
soupsieve==2.5
|
| 81 |
+
starlette==0.35.1
|
| 82 |
+
tomlkit==0.12.0
|
| 83 |
+
toolz==0.12.0
|
| 84 |
+
tqdm==4.66.1
|
| 85 |
+
typer==0.9.0
|
| 86 |
+
typing_extensions==4.9.0
|
| 87 |
+
tzdata==2023.3
|
| 88 |
+
urllib3==2.1.0
|
| 89 |
+
uvicorn==0.26.0
|
| 90 |
+
websockets==11.0.3
|
| 91 |
+
wheel==0.41.2
|
| 92 |
+
wikipedia==1.4.0
|
| 93 |
+
Wikipedia-API==0.6.0
|
| 94 |
+
xxhash==3.4.1
|
| 95 |
+
yarl==1.9.4
|
triviaQA.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
|
| 3 |
+
from llm.qa_agent import QnAAgent
|
| 4 |
+
|
| 5 |
+
validation_dataset = datasets.load_dataset(
|
| 6 |
+
"trivia_qa", "rc", split="test"
|
| 7 |
+
) # remove [:5%] to run on full validation set
|
| 8 |
+
|
| 9 |
+
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
|
| 10 |
+
|
| 11 |
+
qna_agent = QnAAgent()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_sub_answers(answers, begin=0, end=None):
    """For each multi-word answer, return it trimmed to words [begin:end];
    single-word answers are dropped."""
    subs = []
    for answer in answers:
        words = answer.split(" ")
        if len(words) > 1:
            subs.append(" ".join(words[begin:end]))
    return subs
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalize answers into a set of lower-cased, punctuation-free aliases."""
    if make_sub_answers:
        # Also accept multi-word answers with the first or last word dropped
        # (e.g. a leading "the" or "a" in the gold answer).
        given_answers = (
            given_answers
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )
    aliases = []
    for answer in given_answers:
        normalized = answer.replace("_", " ").lower()
        normalized = "".join(
            " " if ch in PUNCTUATION_SET_TO_EXCLUDE else ch for ch in normalized
        )
        aliases.append(" ".join(normalized.split()).strip())
    return set(aliases)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def evaluate(example):
    """Score one TriviaQA example with and without Wikipedia context.

    Adds output/output_context (model answers) and match/match_context
    (exact-match booleans against the alias set) to *example*.
    """
    # Fix: get_answer returns a (answer_text, wiki_page) tuple; the original
    # stored the whole tuple, so expand_to_aliases later crashed calling
    # .replace on a tuple. Keep only the answer text.
    answer_without_context, _ = qna_agent.get_answer(
        example["question"], use_context=False
    )
    answer_with_context, _ = qna_agent.get_answer(example["question"], use_context=True)

    example["output"] = answer_without_context
    example["output_context"] = answer_with_context

    example["targets"] = example["answer"]["aliases"]
    answers = expand_to_aliases(example["targets"], make_sub_answers=True)

    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # if there is a common element, it's a match
    example["match"] = len(answers & predictions) > 0
    example["match_context"] = len(answers & predictions_with_context) > 0

    return example
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
results = validation_dataset.map(evaluate)
|
| 59 |
+
|
| 60 |
+
print("Exact Match (EM) without context: {:.2f}".format(100 * sum(results['match'])/len(results)))
|
| 61 |
+
print("Exact Match (EM) with context: {:.2f}".format(100 * sum(results['match_context'])/len(results)))
|
| 62 |
+
|
utils.py
ADDED
|
File without changes
|