Spaces:
Runtime error
Runtime error
mohit-raghavendra
committed on
Upload 25 files
Browse files- __init__.py +0 -0
- app.py +43 -0
- experiments/__init__.py +0 -0
- experiments/__pycache__/qa_agent.cpython-312.pyc +0 -0
- llm/__init__.py +0 -0
- llm/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/__pycache__/calculator_agent.cpython-312.pyc +0 -0
- llm/__pycache__/gemini_client.cpython-312.pyc +0 -0
- llm/__pycache__/orchestrator.cpython-312.pyc +0 -0
- llm/__pycache__/qa_agent.cpython-312.pyc +0 -0
- llm/__pycache__/qa_tool.cpython-312.pyc +0 -0
- llm/__pycache__/wiki_agent.cpython-312.pyc +0 -0
- llm/__pycache__/wiki_tool.cpython-312.pyc +0 -0
- llm/calculator_agent.py +18 -0
- llm/gemini_client.py +51 -0
- llm/orchestrator.py +35 -0
- llm/qa_agent.py +29 -0
- llm/utils/__init__.py +0 -0
- llm/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- llm/utils/__pycache__/wiki_client.cpython-312.pyc +0 -0
- llm/utils/wiki_client.py +20 -0
- llm/wiki_agent.py +46 -0
- requirements.txt +95 -0
- triviaQA.py +62 -0
- utils.py +0 -0
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr

from llm.qa_agent import QnAAgent
from llm.calculator_agent import CalculatorAgent
from llm.orchestrator import Orchestrator


if __name__ == "__main__":
    orchestrator = Orchestrator()
    qna_agent = QnAAgent()
    calculator_agent = CalculatorAgent()

    def get_answer(question: str) -> list[str]:
        """Route *question* to the right agent and return [answer, wiki_page].

        The orchestrator decides between the QnA agent (Wikipedia-grounded)
        and the calculator agent; wiki_page is "" for non-QnA routes.
        """
        api_name, parameters = orchestrator.get_API_call(question)

        print(f"Using the {api_name} Agent")
        print(api_name, parameters)
        if api_name == "QnA":
            answer, wiki_page = qna_agent.get_answer(parameters)
        elif api_name == "calculator":
            # Orchestrator encodes calculator calls as "operation, x, y".
            operand, op1, op2 = (p.strip() for p in parameters.split(","))
            # str() so the Gradio Textbox output type is consistent.
            answer, wiki_page = str(calculator_agent.calculate(operand, op1, op2)), ""
        else:
            # Unknown route: fail soft instead of raising NameError below.
            answer, wiki_page = f"Unsupported agent: {api_name}", ""

        print(answer)
        return [answer, wiki_page]

    demo = gr.Interface(
        fn=get_answer,
        inputs=gr.Textbox(placeholder="Enter your question...[Who won the Cricket World Cup in 2023?]"),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Wikipedia Page")],
        title="Real time Question Answering",
    )

    demo.launch()
|
42 |
+
|
43 |
+
|
experiments/__init__.py
ADDED
File without changes
|
experiments/__pycache__/qa_agent.cpython-312.pyc
ADDED
Binary file (1.82 kB). View file
|
|
llm/__init__.py
ADDED
File without changes
|
llm/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (140 Bytes). View file
|
|
llm/__pycache__/calculator_agent.cpython-312.pyc
ADDED
Binary file (962 Bytes). View file
|
|
llm/__pycache__/gemini_client.cpython-312.pyc
ADDED
Binary file (2.19 kB). View file
|
|
llm/__pycache__/orchestrator.cpython-312.pyc
ADDED
Binary file (1.69 kB). View file
|
|
llm/__pycache__/qa_agent.cpython-312.pyc
ADDED
Binary file (1.9 kB). View file
|
|
llm/__pycache__/qa_tool.cpython-312.pyc
ADDED
Binary file (1.55 kB). View file
|
|
llm/__pycache__/wiki_agent.cpython-312.pyc
ADDED
Binary file (2.28 kB). View file
|
|
llm/__pycache__/wiki_tool.cpython-312.pyc
ADDED
Binary file (2.24 kB). View file
|
|
llm/calculator_agent.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class CalculatorAgent:
|
2 |
+
def calculate(self, operation: str, x: str, y: str) -> float:
|
3 |
+
operation = operation.lower().strip()
|
4 |
+
|
5 |
+
x = float(x)
|
6 |
+
y = float(y)
|
7 |
+
if operation == "add":
|
8 |
+
return x + y
|
9 |
+
elif operation == "subtract":
|
10 |
+
return x - y
|
11 |
+
elif operation == "multiply":
|
12 |
+
return x * y
|
13 |
+
elif operation == "divide":
|
14 |
+
if y == 0:
|
15 |
+
return "Cannot divide by zero"
|
16 |
+
return x / y
|
17 |
+
else:
|
18 |
+
return "Unknown operation"
|
llm/gemini_client.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import google.generativeai as genai
|
4 |
+
|
5 |
+
|
6 |
+
class GeminiClient:
    """Thin wrapper around the Gemini "gemini-pro" generative model.

    Reads the API key from the GOOGLE_PALM_KEY environment variable and
    prepends an optional system message to every prompt.
    """

    def __init__(self, system_message=None):
        # Optional instruction text prepended to every prompt; may be None.
        self._system_message = system_message
        self._connect_client()

    def _connect_client(self):
        """Configure the genai SDK and build the model object.

        Raises:
            Exception: if GOOGLE_PALM_KEY is not set in the environment.
        """
        if not os.getenv("GOOGLE_PALM_KEY"):
            raise Exception("Please set your Google MakerSuite API key")

        api_key = os.getenv("GOOGLE_PALM_KEY")
        genai.configure(api_key=api_key)

        # Relax default safety filters to only block high-probability harms.
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
        ]

        defaults = {
            "temperature": 0.7,
            "top_k": 40,
            "top_p": 0.95,
            "max_output_tokens": 1024,
        }

        self._model = genai.GenerativeModel(
            model_name="gemini-pro",
            generation_config=defaults,
            safety_settings=safety_settings,
        )

    def generate_text(self, prompt: str) -> str:
        """Generate a completion for *prompt*; returns "" on any API error."""
        # system_message defaults to None; treat that as an empty prefix
        # (the original `None + prompt` raised TypeError).
        full_prompt = (self._system_message or "") + prompt
        try:
            response = self._model.generate_content(full_prompt).text
        except Exception as e:
            # Best-effort: log and degrade to an empty answer.
            print(f"Error: {e}")
            response = ""
        return response
|
llm/orchestrator.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llm.gemini_client import GeminiClient
|
2 |
+
|
3 |
+
# Routing prompt: the model must answer with "<agent>$<parameters>".
SYSTEM_MESSAGE = """You are an orchestrator that knows what various tools
or agents can do and which is the right one to pick. Given a question, your
job is just to pick the right agent to use and the rest will be taken care of.

For now you can use a calculator agent that can help you do basic arithmetic
calculations. You can also use a question answering agent that can answer
questions about various topics.

The API's are:
calculator[operand 1, operand 2, operation]
QnA[question]

Here are some examples:

Example 1:
Question: What is 2 + 2?
Response: calculator$add, 2, 2

Example 2: Who designed the Eiffel Tower?
Response: QnA$Who designed the Eiffel Tower?

### Question:
"""
|
26 |
+
|
27 |
+
|
28 |
+
class Orchestrator:
    """Routes a user query to an agent via a single Gemini call."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)

    @staticmethod
    def _parse_api_call(api_call: str) -> tuple[str, str]:
        """Split a model response of the form "name$parameters".

        maxsplit=1 keeps parameters intact even if they contain "$"
        (the original unbounded split raised ValueError), and the agent
        name is stripped of whitespace/newlines the model may emit.
        """
        api_name, parameters = api_call.split("$", 1)
        return api_name.strip(), parameters

    def get_API_call(self, query: str) -> tuple[str, str]:
        """Ask the LLM which API to call; returns (api_name, parameters)."""
        api_call = self._client.generate_text(query)
        return self._parse_api_call(api_call)
|
llm/qa_agent.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llm.gemini_client import GeminiClient
|
2 |
+
from llm.wiki_agent import WikiSearchAgent
|
3 |
+
|
4 |
+
# QnA prompt: prefer the supplied Wikipedia context over internal knowledge.
SYSTEM_MESSAGE = """You are a Question Answering tool that can answer various
trivia questions. However, you might be asked questions that are beyond your
knowledge or recent events that you might not be trained on
(beyond training cutoff). So, if there is a Wikipedia page entry provided,
use that to answer the question. Just return the answer, don't make a
verbose response."""
|
10 |
+
|
11 |
+
|
12 |
+
class QnAAgent:
    """Answers trivia questions with Gemini, optionally grounded on Wikipedia."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_tool = WikiSearchAgent()

    @staticmethod
    def _format_prompt(query: str, wiki_page: str) -> str:
        """Combine the question and its Wikipedia context into one prompt."""
        return f"\n###Question:{query} \n###Wikipedia Page:{wiki_page}"

    def get_answer(self, query: str, use_context: bool = True) -> tuple[str, str]:
        """Answer *query*; returns (answer, wiki_page_text).

        When use_context is True, a Wikipedia summary is fetched via the
        wiki tool and appended to the prompt; otherwise wiki_page is "".
        """
        if use_context:
            wiki_page = self._wiki_tool.get_wikipedia_entry(query)
            prompt = self._format_prompt(query, wiki_page)
        else:
            wiki_page = ""
            prompt = query
        return self._client.generate_text(prompt), wiki_page
|
llm/utils/__init__.py
ADDED
File without changes
|
llm/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (146 Bytes). View file
|
|
llm/utils/__pycache__/wiki_client.cpython-312.pyc
ADDED
Binary file (1.17 kB). View file
|
|
llm/utils/wiki_client.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wikipediaapi
|
2 |
+
|
3 |
+
|
4 |
+
class WikiClient:
    """Minimal wrapper around the wikipedia-api client."""

    def __init__(self):
        # English Wikipedia, plain-text extracts; the API requires a user agent.
        self.wiki = wikipediaapi.Wikipedia(
            user_agent="WikiAgent/0.0 [email protected]",
            language="en",
            extract_format=wikipediaapi.ExtractFormat.WIKI,
        )

    def get_pages(self, query):
        """Return the wikipediaapi page object for *query*."""
        return self.wiki.page(query)
|
15 |
+
|
16 |
+
|
17 |
+
if __name__ == "__main__":
    # Manual smoke test: fetch a page and print its summary (hits the network).
    client = WikiClient()
    pages = client.get_pages("Cricket World Cup")
    print(pages.summary)
|
llm/wiki_agent.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llm.gemini_client import GeminiClient
|
2 |
+
|
3 |
+
from llm.utils.wiki_client import WikiClient
|
4 |
+
|
5 |
+
# Search prompt: the model must reply with a Wikipedia query, not the answer.
SYSTEM_MESSAGE = """You have access to a Wikipedia summarizer that can return a summary for a topic. \
Your job is to act as a question answering tool. Whenever you are asked about a question related to knowledge, \
instead of using your internal knowledge (which can be faulty or out of date), \
format a Wikipedia search query string that can help answer the question. \

Remember Wikipedia Entries are usually about a simple entity or event, so keep the \
query short, and about the entity being asked about. Also, don't use your knowledge \
to ask about the answer. Instead form queries about the entity in the question. This \
will help you get the right Wikipedia entries for questions when you don't know the answer.

### Example 1:

Question: Who won the ICC Cricket World Cup?
Correct Response: Cricket World Cup
Incorrect Response: Australia

### Example 2:

Question: Who directed the classic 30s western Stagecoach?
Correct Response: Stagecoach
Incorrect Response: John Ford


Below is the question. Return the Wikipedia search query you would use \n

### Question:
"""
|
32 |
+
|
33 |
+
|
34 |
+
class WikiSearchAgent:
    """Turns a question into a Wikipedia search query and fetches the summary."""

    def __init__(self):
        self._llm_client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_client = WikiClient()

    def get_wikipedia_entry(self, prompt: str) -> str:
        """Return the Wikipedia summary for the LLM-chosen query, or "" on failure.

        The LLM produces a short search query (see SYSTEM_MESSAGE); failures
        to fetch/summarize the page degrade to an empty context string.
        """
        wiki_search_query = self._llm_client.generate_text(prompt)
        wikipedia_page = self._wiki_client.get_pages(wiki_search_query)
        try:
            return wikipedia_page.summary
        except Exception:
            # Narrowed from a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit); missing page or API error
            # means "no context available".
            return ""
|
requirements.txt
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
aiohttp==3.9.1
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.2.0
|
5 |
+
annotated-types==0.6.0
|
6 |
+
anyio==4.2.0
|
7 |
+
attrs==23.1.0
|
8 |
+
beautifulsoup4==4.12.2
|
9 |
+
cachetools==5.3.2
|
10 |
+
certifi==2023.11.17
|
11 |
+
charset-normalizer==3.3.2
|
12 |
+
click==8.1.7
|
13 |
+
colorama==0.4.6
|
14 |
+
contourpy==1.2.0
|
15 |
+
cycler==0.12.1
|
16 |
+
datasets==2.15.0
|
17 |
+
dill==0.3.7
|
18 |
+
fastapi==0.109.0
|
19 |
+
ffmpy==0.3.1
|
20 |
+
filelock==3.13.1
|
21 |
+
fonttools==4.47.2
|
22 |
+
frozenlist==1.4.1
|
23 |
+
fsspec==2023.10.0
|
24 |
+
google-ai-generativelanguage==0.4.0
|
25 |
+
google-api-core==2.15.0
|
26 |
+
google-auth==2.25.2
|
27 |
+
google-generativeai==0.3.2
|
28 |
+
googleapis-common-protos==1.62.0
|
29 |
+
gradio==4.14.0
|
30 |
+
gradio_client==0.8.0
|
31 |
+
grpcio==1.60.0
|
32 |
+
grpcio-status==1.60.0
|
33 |
+
h11==0.14.0
|
34 |
+
httpcore==1.0.2
|
35 |
+
httpx==0.26.0
|
36 |
+
huggingface-hub==0.20.1
|
37 |
+
idna==3.6
|
38 |
+
importlib-resources==6.1.1
|
39 |
+
install==1.3.5
|
40 |
+
Jinja2==3.1.3
|
41 |
+
jsonschema==4.21.0
|
42 |
+
jsonschema-specifications==2023.12.1
|
43 |
+
kiwisolver==1.4.5
|
44 |
+
markdown-it-py==3.0.0
|
45 |
+
MarkupSafe==2.1.3
|
46 |
+
matplotlib==3.8.2
|
47 |
+
mdurl==0.1.2
|
48 |
+
multidict==6.0.4
|
49 |
+
multiprocess==0.70.15
|
50 |
+
numpy==1.26.2
|
51 |
+
orjson==3.9.10
|
52 |
+
packaging==23.2
|
53 |
+
pandas==2.1.4
|
54 |
+
pillow==10.2.0
|
55 |
+
proto-plus==1.23.0
|
56 |
+
protobuf==4.25.1
|
57 |
+
pyarrow==14.0.2
|
58 |
+
pyarrow-hotfix==0.6
|
59 |
+
pyasn1==0.5.1
|
60 |
+
pyasn1-modules==0.3.0
|
61 |
+
pydantic==2.5.3
|
62 |
+
pydantic_core==2.14.6
|
63 |
+
pydub==0.25.1
|
64 |
+
Pygments==2.17.2
|
65 |
+
pyparsing==3.1.1
|
66 |
+
python-dateutil==2.8.2
|
67 |
+
python-multipart==0.0.6
|
68 |
+
pytz==2023.3.post1
|
69 |
+
PyYAML==6.0.1
|
70 |
+
referencing==0.32.1
|
71 |
+
requests==2.31.0
|
72 |
+
rich==13.7.0
|
73 |
+
rpds-py==0.17.1
|
74 |
+
rsa==4.9
|
75 |
+
semantic-version==2.10.0
|
76 |
+
setuptools==68.2.2
|
77 |
+
shellingham==1.5.4
|
78 |
+
six==1.16.0
|
79 |
+
sniffio==1.3.0
|
80 |
+
soupsieve==2.5
|
81 |
+
starlette==0.35.1
|
82 |
+
tomlkit==0.12.0
|
83 |
+
toolz==0.12.0
|
84 |
+
tqdm==4.66.1
|
85 |
+
typer==0.9.0
|
86 |
+
typing_extensions==4.9.0
|
87 |
+
tzdata==2023.3
|
88 |
+
urllib3==2.1.0
|
89 |
+
uvicorn==0.26.0
|
90 |
+
websockets==11.0.3
|
91 |
+
wheel==0.41.2
|
92 |
+
wikipedia==1.4.0
|
93 |
+
Wikipedia-API==0.6.0
|
94 |
+
xxhash==3.4.1
|
95 |
+
yarl==1.9.4
|
triviaQA.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
|
3 |
+
from llm.qa_agent import QnAAgent
|
4 |
+
|
5 |
+
# TriviaQA reading-comprehension config; the whole test split is evaluated.
validation_dataset = datasets.load_dataset(
    "trivia_qa", "rc", split="test"
)  # use e.g. split="test[:5%]" for a quick partial run

# Punctuation characters replaced by spaces during alias normalization below.
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))

# Single shared agent instance reused for every example.
qna_agent = QnAAgent()
|
12 |
+
|
13 |
+
|
14 |
+
def get_sub_answers(answers, begin=0, end=None):
    """Return each multi-word answer with its words sliced to [begin:end].

    Single-word answers are dropped; word order within each answer is kept.
    """
    subs = []
    for answer in answers:
        words = answer.split(" ")
        if len(words) > 1:
            subs.append(" ".join(words[begin:end]))
    return subs
|
16 |
+
|
17 |
+
|
18 |
+
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalize answers into a set of lowercase, punctuation-stripped aliases.

    With make_sub_answers=True, each multi-word answer is additionally
    accepted with its first or last word dropped (e.g. a leading "the"/"a").
    """
    if make_sub_answers:
        given_answers = (
            given_answers
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )
    aliases = set()
    for raw in given_answers:
        text = raw.replace("_", " ").lower()
        # Map excluded punctuation to spaces, then collapse runs of whitespace.
        cleaned = "".join(
            " " if ch in PUNCTUATION_SET_TO_EXCLUDE else ch for ch in text
        )
        aliases.add(" ".join(cleaned.split()).strip())
    return aliases
|
35 |
+
|
36 |
+
|
37 |
+
def evaluate(example):
    """Score one TriviaQA example with and without Wikipedia context.

    Adds output/output_context/targets/match/match_context keys to *example*
    and returns it (dataset.map-compatible).
    """
    # get_answer returns a (answer, wiki_page) tuple; the original stored the
    # whole tuple, which later crashed expand_to_aliases (.replace on a tuple).
    answer_without_context, _ = qna_agent.get_answer(
        example["question"], use_context=False
    )
    answer_with_context, _ = qna_agent.get_answer(example["question"], use_context=True)

    example["output"] = answer_without_context
    example["output_context"] = answer_with_context

    example["targets"] = example["answer"]["aliases"]
    answers = expand_to_aliases(example["targets"], make_sub_answers=True)

    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # If there is a common alias, count the prediction as an exact match.
    example["match"] = len(answers & predictions) > 0
    example["match_context"] = len(answers & predictions_with_context) > 0

    return example
|
56 |
+
|
57 |
+
|
58 |
+
# Run the evaluation over every example, then report exact-match accuracy
# for the no-context and with-context variants.
results = validation_dataset.map(evaluate)

print("Exact Match (EM) without context: {:.2f}".format(100 * sum(results['match'])/len(results)))
print("Exact Match (EM) with context: {:.2f}".format(100 * sum(results['match_context'])/len(results)))
|
62 |
+
|
utils.py
ADDED
File without changes
|