mohit-raghavendra committed on
Commit
3060e5b
·
verified ·
1 Parent(s): e753af9

Upload 25 files

Browse files
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ from llm.qa_agent import QnAAgent
5
+ from llm.calculator_agent import CalculatorAgent
6
+ from llm.orchestrator import Orchestrator
7
+
8
+
9
+
10
if __name__ == "__main__":
    orchestrator = Orchestrator()
    qna_agent = QnAAgent()
    calculator_agent = CalculatorAgent()

    def get_answer(question: str) -> list[str]:
        """Route *question* to the matching agent and return [answer, wiki_page].

        The orchestrator LLM decides which agent handles the question; the
        Gradio UI shows the answer and (for QnA) the Wikipedia page used.
        """
        api_name, parameters = orchestrator.get_API_call(question)

        print(f"Using the {api_name} Agent")
        print(api_name, parameters)

        # Defaults keep both names bound even for an unrecognized API name
        # (the original raised NameError for anything but "QnA").
        answer, wiki_page = "", ""
        if api_name == "QnA":
            answer, wiki_page = qna_agent.get_answer(parameters)
        elif api_name == "calculator":
            # Orchestrator emits "operation, op1, op2" (see its prompt examples).
            operand, op1, op2 = parameters.split(",")
            # Textbox output expects a string, not a float.
            answer = str(calculator_agent.calculate(operand, op1, op2))

        print(answer)
        return [answer, wiki_page]

    demo = gr.Interface(
        fn=get_answer,
        inputs=gr.Textbox(
            placeholder="Enter your question...[Who won the Cricket World Cup in 2023?]"
        ),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Wikipedia Page")],
        title="Real time Question Answering",
    )

    demo.launch()
experiments/__init__.py ADDED
File without changes
experiments/__pycache__/qa_agent.cpython-312.pyc ADDED
Binary file (1.82 kB). View file
 
llm/__init__.py ADDED
File without changes
llm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (140 Bytes). View file
 
llm/__pycache__/calculator_agent.cpython-312.pyc ADDED
Binary file (962 Bytes). View file
 
llm/__pycache__/gemini_client.cpython-312.pyc ADDED
Binary file (2.19 kB). View file
 
llm/__pycache__/orchestrator.cpython-312.pyc ADDED
Binary file (1.69 kB). View file
 
llm/__pycache__/qa_agent.cpython-312.pyc ADDED
Binary file (1.9 kB). View file
 
llm/__pycache__/qa_tool.cpython-312.pyc ADDED
Binary file (1.55 kB). View file
 
llm/__pycache__/wiki_agent.cpython-312.pyc ADDED
Binary file (2.28 kB). View file
 
llm/__pycache__/wiki_tool.cpython-312.pyc ADDED
Binary file (2.24 kB). View file
 
llm/calculator_agent.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class CalculatorAgent:
2
+ def calculate(self, operation: str, x: str, y: str) -> float:
3
+ operation = operation.lower().strip()
4
+
5
+ x = float(x)
6
+ y = float(y)
7
+ if operation == "add":
8
+ return x + y
9
+ elif operation == "subtract":
10
+ return x - y
11
+ elif operation == "multiply":
12
+ return x * y
13
+ elif operation == "divide":
14
+ if y == 0:
15
+ return "Cannot divide by zero"
16
+ return x / y
17
+ else:
18
+ return "Unknown operation"
llm/gemini_client.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import google.generativeai as genai
4
+
5
+
6
class GeminiClient:
    """Thin wrapper around the google-generativeai "gemini-pro" model.

    An optional system message is prepended to every prompt (this API
    version has no separate system-message slot).
    """

    def __init__(self, system_message=None):
        # Instructions prepended to every prompt; may be None (no prefix).
        self._system_message = system_message
        self._connect_client()

    def _connect_client(self):
        """Configure the genai SDK from GOOGLE_PALM_KEY and build the model.

        Raises:
            Exception: if the GOOGLE_PALM_KEY environment variable is unset.
        """
        api_key = os.getenv("GOOGLE_PALM_KEY")
        if not api_key:
            raise Exception("Please set your Google MakerSuite API key")

        genai.configure(api_key=api_key)

        # Relaxed blocking thresholds — presumably so trivia questions about
        # news/sports events are not rejected by the safety filters (TODO
        # confirm this was the intent).
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_ONLY_HIGH",
            },
        ]

        # Generation defaults shared by every call.
        defaults = {
            "temperature": 0.7,
            "top_k": 40,
            "top_p": 0.95,
            "max_output_tokens": 1024,
        }

        self._model = genai.GenerativeModel(
            model_name="gemini-pro",
            generation_config=defaults,
            safety_settings=safety_settings,
        )

    def generate_text(self, prompt: str) -> str:
        """Return the model's text completion for *prompt*, or "" on failure."""
        # Guard against system_message=None (the constructor default): the
        # original unconditional concatenation raised TypeError in that case.
        prefix = self._system_message or ""
        full_prompt = prefix + prompt
        try:
            response = self._model.generate_content(full_prompt).text
        except Exception as e:
            # Best-effort: log and return "" so callers can keep going
            # (agents treat "" as "no answer").
            print(f"Error: {e}")
            response = ""
        return response
llm/orchestrator.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llm.gemini_client import GeminiClient
2
+
3
# Routing prompt for the orchestrator LLM. The model must answer in the
# exact "name$parameters" format that Orchestrator.get_API_call parses.
# (Fixed typos "can knows"/"Respnse" and made the calculator signature match
# its own example, which lists the operation first.)
SYSTEM_MESSAGE = """You are an orchestrator that knows what various tools
or agents can do and which is the right one to pick. Given a question, your
job is just to pick the right agent to use and the rest will be taken care of.

For now you can use a calculator agent that can help you do basic arithmetic
calculations. You can also use a question answering agent that can answer
questions about various topics.

The API's are:
calculator[operation, operand 1, operand 2]
QnA[question]

Here are some examples:

Example 1:
Question: What is 2 + 2?
Response: calculator$add, 2, 2

Example 2: Who designed the Eiffel Tower?
Response: QnA$Who designed the Eiffel Tower?

### Question:
"""
26
+
27
+
28
class Orchestrator:
    """Asks an LLM which agent/API should handle a given question."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)

    def get_API_call(self, query: str) -> tuple[str, str]:
        """Resolve *query* to an agent call.

        Returns:
            (api_name, parameters) parsed from the model's "name$params"
            reply, e.g. ("QnA", "Who designed the Eiffel Tower?").

        Raises:
            ValueError: if the reply lacks the "$" separator (e.g. the
                client errored and returned "").
        """
        api_call = self._client.generate_text(query)
        # partition() splits on the FIRST "$" only, so a "$" inside the
        # parameters survives (split("$") raised ValueError on unpacking).
        api_name, sep, parameters = api_call.partition("$")
        if not sep:
            raise ValueError(f"Malformed API call from model: {api_call!r}")
        return api_name, parameters
llm/qa_agent.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llm.gemini_client import GeminiClient
2
+ from llm.wiki_agent import WikiSearchAgent
3
+
4
# System prompt for QnAAgent: tells the model to prefer the supplied
# Wikipedia extract over (possibly stale) internal knowledge and to answer
# tersely rather than verbosely.
SYSTEM_MESSAGE = """You are a Question Answering tool that can answer various
trivia questions. However, you might be asked questions that is beyond your
knowledge or recent events that you might not be trained on
(beyond training cutoff). So, if there is Wikipedia page entry provided,
use that to answer the question. Just return the answer, don't make a
verbose response."""
10
+
11
+
12
class QnAAgent:
    """Trivia question-answering agent, optionally grounded on Wikipedia."""

    def __init__(self):
        self._client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_tool = WikiSearchAgent()

    @staticmethod
    def _format_prompt(query: str, wiki_page: str) -> str:
        """Combine the question and the Wikipedia extract into one prompt.

        Written as a single literal: the original used a backslash line
        continuation inside the f-string, which leaked source indentation
        whitespace into the prompt.
        """
        return f"\n###Question:{query} \n###Wikipedia Page:{wiki_page}"

    def get_answer(self, query: str, use_context: bool = True) -> tuple[str, str]:
        """Answer *query*; returns (answer, wiki_page_text_used).

        When *use_context* is True the query is first resolved to a
        Wikipedia summary which is embedded in the prompt; otherwise the
        returned wiki_page is "".
        """
        if use_context:
            wiki_page = self._wiki_tool.get_wikipedia_entry(query)
            prompt = self._format_prompt(query, wiki_page)
        else:
            wiki_page = ""
            prompt = query
        return self._client.generate_text(prompt), wiki_page
llm/utils/__init__.py ADDED
File without changes
llm/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes). View file
 
llm/utils/__pycache__/wiki_client.cpython-312.pyc ADDED
Binary file (1.17 kB). View file
 
llm/utils/wiki_client.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wikipediaapi
2
+
3
+
4
class WikiClient:
    """Minimal wrapper over wikipediaapi for fetching page objects."""

    def __init__(self):
        # The Wikipedia API asks clients to identify themselves with a
        # descriptive user agent; plain-text extracts are requested.
        self.wiki = wikipediaapi.Wikipedia(
            user_agent="WikiAgent/0.0 [email protected]",
            language="en",
            extract_format=wikipediaapi.ExtractFormat.WIKI,
        )

    def get_pages(self, query):
        """Return the wikipediaapi page object matching *query*."""
        return self.wiki.page(query)
15
+
16
+
17
if __name__ == "__main__":
    # Manual smoke test: fetch a well-known page and show its summary.
    demo_client = WikiClient()
    demo_page = demo_client.get_pages("Cricket World Cup")
    print(demo_page.summary)
llm/wiki_agent.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llm.gemini_client import GeminiClient
2
+
3
+ from llm.utils.wiki_client import WikiClient
4
+
5
# System prompt for WikiSearchAgent: instructs the LLM to reply with a short,
# entity-centric Wikipedia search query (not the answer itself), so the page
# lookup succeeds even when the model does not know the answer.
SYSTEM_MESSAGE = """you have access to a wikipedia summarizer that can return a summary for a topic. \
Your job is to act as a question answering tool. Whenever you are asked about a question related to knowledge, \
instead of using your internal knowledge (which can be faulty or out of date), \
format a Wikipedia search query string that can help answer the question. \

Remember Wikipedia Entries are usually about a simple entity or event, so keep the \
query short, and about the entity being asked about. Also, don't use your knowledge \
to ask about the answer. Instead form queries about the entity in the question. This \
will help you get the right wikipedia entries for questions when you dont know the answer

### Example 1:

Question: Who won the ICC Cricket World Cup?
Correct Response: Cricket World Cup
Incorrect response: Australia

### Example 2:

Question: Who directed the classic 30s western Stagecoach?
Response: Stagecoach
Incorrect response: John Ford


Below is the question. Return the wikipedia search query you would use \n

### Question:
"""
32
+
33
+
34
class WikiSearchAgent:
    """Turns a question into a Wikipedia search query and fetches the summary."""

    def __init__(self):
        self._llm_client = GeminiClient(system_message=SYSTEM_MESSAGE)
        self._wiki_client = WikiClient()

    def get_wikipedia_entry(self, prompt: str) -> str:
        """Return the Wikipedia summary most relevant to *prompt*, or "".

        The LLM first rewrites the question into a short entity-style search
        query (see SYSTEM_MESSAGE); the summary of the matching page is
        returned, or "" when the lookup fails.
        """
        wiki_search_query = self._llm_client.generate_text(prompt)
        wikipedia_page = self._wiki_client.get_pages(wiki_search_query)
        try:
            return wikipedia_page.summary
        except Exception:
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. The summary lookup can raise
            # (e.g. missing page) because wikipediaapi fetches lazily —
            # presumably why the guard exists; TODO confirm exact exception.
            return ""
requirements.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.1
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==4.2.0
7
+ attrs==23.1.0
8
+ beautifulsoup4==4.12.2
9
+ cachetools==5.3.2
10
+ certifi==2023.11.17
11
+ charset-normalizer==3.3.2
12
+ click==8.1.7
13
+ colorama==0.4.6
14
+ contourpy==1.2.0
15
+ cycler==0.12.1
16
+ datasets==2.15.0
17
+ dill==0.3.7
18
+ fastapi==0.109.0
19
+ ffmpy==0.3.1
20
+ filelock==3.13.1
21
+ fonttools==4.47.2
22
+ frozenlist==1.4.1
23
+ fsspec==2023.10.0
24
+ google-ai-generativelanguage==0.4.0
25
+ google-api-core==2.15.0
26
+ google-auth==2.25.2
27
+ google-generativeai==0.3.2
28
+ googleapis-common-protos==1.62.0
29
+ gradio==4.14.0
30
+ gradio_client==0.8.0
31
+ grpcio==1.60.0
32
+ grpcio-status==1.60.0
33
+ h11==0.14.0
34
+ httpcore==1.0.2
35
+ httpx==0.26.0
36
+ huggingface-hub==0.20.1
37
+ idna==3.6
38
+ importlib-resources==6.1.1
39
+ install==1.3.5
40
+ Jinja2==3.1.3
41
+ jsonschema==4.21.0
42
+ jsonschema-specifications==2023.12.1
43
+ kiwisolver==1.4.5
44
+ markdown-it-py==3.0.0
45
+ MarkupSafe==2.1.3
46
+ matplotlib==3.8.2
47
+ mdurl==0.1.2
48
+ multidict==6.0.4
49
+ multiprocess==0.70.15
50
+ numpy==1.26.2
51
+ orjson==3.9.10
52
+ packaging==23.2
53
+ pandas==2.1.4
54
+ pillow==10.2.0
55
+ proto-plus==1.23.0
56
+ protobuf==4.25.1
57
+ pyarrow==14.0.2
58
+ pyarrow-hotfix==0.6
59
+ pyasn1==0.5.1
60
+ pyasn1-modules==0.3.0
61
+ pydantic==2.5.3
62
+ pydantic_core==2.14.6
63
+ pydub==0.25.1
64
+ Pygments==2.17.2
65
+ pyparsing==3.1.1
66
+ python-dateutil==2.8.2
67
+ python-multipart==0.0.6
68
+ pytz==2023.3.post1
69
+ PyYAML==6.0.1
70
+ referencing==0.32.1
71
+ requests==2.31.0
72
+ rich==13.7.0
73
+ rpds-py==0.17.1
74
+ rsa==4.9
75
+ semantic-version==2.10.0
76
+ setuptools==68.2.2
77
+ shellingham==1.5.4
78
+ six==1.16.0
79
+ sniffio==1.3.0
80
+ soupsieve==2.5
81
+ starlette==0.35.1
82
+ tomlkit==0.12.0
83
+ toolz==0.12.0
84
+ tqdm==4.66.1
85
+ typer==0.9.0
86
+ typing_extensions==4.9.0
87
+ tzdata==2023.3
88
+ urllib3==2.1.0
89
+ uvicorn==0.26.0
90
+ websockets==11.0.3
91
+ wheel==0.41.2
92
+ wikipedia==1.4.0
93
+ Wikipedia-API==0.6.0
94
+ xxhash==3.4.1
95
+ yarl==1.9.4
triviaQA.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+
3
+ from llm.qa_agent import QnAAgent
4
+
5
# Evaluation split; append e.g. "[:5%]" to the split string to run on a
# small subset instead of the full set.
validation_dataset = datasets.load_dataset("trivia_qa", "rc", split="test")

# Characters removed from answers/predictions before exact-match comparison.
PUNCTUATION_SET_TO_EXCLUDE = {"‘", "’", "´", "`", ".", ",", "-", '"'}

qna_agent = QnAAgent()
12
+
13
+
14
def get_sub_answers(answers, begin=0, end=None):
    """Return each multi-word answer with its words sliced to [begin:end].

    Single-word answers are dropped entirely.
    """
    multi_word = (ans for ans in answers if len(ans.split(" ")) > 1)
    return [" ".join(ans.split(" ")[begin:end]) for ans in multi_word]
16
+
17
+
18
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalize answers into a set of lowercase, punctuation-free aliases.

    With make_sub_answers=True, each multi-word answer also contributes the
    variant without its first word and without its last word, so a prediction
    still matches when the gold answer carries a prefix like "the" or "a".
    """
    if make_sub_answers:
        given_answers = (
            given_answers
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )

    aliases = set()
    for raw in given_answers:
        alias = raw.replace("_", " ").lower()
        # Map excluded punctuation to spaces, then collapse runs of spaces.
        alias = "".join(
            " " if ch in PUNCTUATION_SET_TO_EXCLUDE else ch for ch in alias
        )
        aliases.add(" ".join(alias.split()).strip())
    return aliases
35
+
36
+
37
def evaluate(example):
    """Score one TriviaQA row; mutates and returns *example*.

    Adds keys:
      output / output_context: agent answers without / with Wikipedia context
      targets: gold answer aliases
      match / match_context: exact-match booleans for each answer
    """
    # QnAAgent.get_answer returns (answer, wiki_page); keep only the answer.
    # (The original stored the whole tuple, so expand_to_aliases later
    # called str methods on a tuple and crashed.)
    answer_without_context, _ = qna_agent.get_answer(
        example["question"], use_context=False
    )
    answer_with_context, _ = qna_agent.get_answer(
        example["question"], use_context=True
    )

    example["output"] = answer_without_context
    example["output_context"] = answer_with_context

    example["targets"] = example["answer"]["aliases"]
    answers = expand_to_aliases(example["targets"], make_sub_answers=True)

    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # Any shared alias counts as an exact match.
    example["match"] = len(answers & predictions) > 0
    example["match_context"] = len(answers & predictions_with_context) > 0

    return example
56
+
57
+
58
results = validation_dataset.map(evaluate)

# Exact-match accuracy, as a percentage, with and without Wikipedia context.
em_plain = 100 * sum(results["match"]) / len(results)
em_context = 100 * sum(results["match_context"]) / len(results)
print("Exact Match (EM) without context: {:.2f}".format(em_plain))
print("Exact Match (EM) with context: {:.2f}".format(em_context))
62
+
utils.py ADDED
File without changes