vonliechti committed on
Commit dd5fe55 · verified · 1 Parent(s): c8e3129

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +5 -2
  2. benchmarking.ipynb +404 -0
  3. semscore.py +167 -0
README.md CHANGED
@@ -63,6 +63,9 @@ TBD
 
  ## Acknowledgments
 
- * [MemGPT](https://github.com/cpacker/MemGPT)
+ * [Agents 2.0](https://github.com/huggingface/transformers/tree/main/src/transformers/agents)
+ * [SemScore: Automated Evaluation of Instruction-Tuned LLMs based on Semantic Textual Similarity](https://arxiv.org/abs/2401.17072)
+ * [SemScore](https://huggingface.co/blog/g-ronimo/semscore)
  * [Stanford SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)
- * [GPT-4](https://openai.com/gpt-4/)
+ * [Llama 3.1](https://github.com/meta-llama/Meta-Llama)
+ * [Gradio](https://www.gradio.app/)
benchmarking.ipynb ADDED
@@ -0,0 +1,404 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load SQuAD data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import json\n",
+ "import pandas as pd\n",
+ "\n",
+ "def display_text_df(df):\n",
+ "    display(df.style.set_properties(**{'white-space': 'pre-wrap'}).set_table_styles(\n",
+ "        [{'selector': 'th', 'props': [('text-align', 'left')]},\n",
+ "         {'selector': 'td', 'props': [('text-align', 'left')]}\n",
+ "        ]\n",
+ "    ).hide())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from data import get_data\n",
+ "data = get_data(download=False)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',\n",
+ " 'Saint Bernadette Soubirous')"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.question_answer_pairs[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style type=\"text/css\">\n",
+ "#T_fc111 th {\n",
+ " text-align: left;\n",
+ "}\n",
+ "#T_fc111 td {\n",
+ " text-align: left;\n",
+ "}\n",
+ "#T_fc111_row0_col0, #T_fc111_row0_col1, #T_fc111_row1_col0, #T_fc111_row1_col1, #T_fc111_row2_col0, #T_fc111_row2_col1, #T_fc111_row3_col0, #T_fc111_row3_col1, #T_fc111_row4_col0, #T_fc111_row4_col1, #T_fc111_row5_col0, #T_fc111_row5_col1, #T_fc111_row6_col0, #T_fc111_row6_col1, #T_fc111_row7_col0, #T_fc111_row7_col1, #T_fc111_row8_col0, #T_fc111_row8_col1, #T_fc111_row9_col0, #T_fc111_row9_col1 {\n",
+ " white-space: pre-wrap;\n",
+ "}\n",
+ "</style>\n",
+ "<table id=\"T_fc111\">\n",
+ " <thead>\n",
+ " <tr>\n",
+ " <th id=\"T_fc111_level0_col0\" class=\"col_heading level0 col0\" >Question</th>\n",
+ " <th id=\"T_fc111_level0_col1\" class=\"col_heading level0 col1\" >Answer</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row0_col0\" class=\"data row0 col0\" >What year was the Banská Akadémia founded?</td>\n",
+ " <td id=\"T_fc111_row0_col1\" class=\"data row0 col1\" >1735</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row1_col0\" class=\"data row1 col0\" >What is another speed that can also be reported by the camera?</td>\n",
+ " <td id=\"T_fc111_row1_col1\" class=\"data row1 col1\" >SOS-based speed</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row2_col0\" class=\"data row2 col0\" >Where were the use of advanced materials and techniques on display in Sumer?</td>\n",
+ " <td id=\"T_fc111_row2_col1\" class=\"data row2 col1\" >Sumerian temples and palaces</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row3_col0\" class=\"data row3 col0\" >Who is elected every even numbered year?</td>\n",
+ " <td id=\"T_fc111_row3_col1\" class=\"data row3 col1\" >mayor</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row4_col0\" class=\"data row4 col0\" >What was the purpose of top secret ICBM committee?</td>\n",
+ " <td id=\"T_fc111_row4_col1\" class=\"data row4 col1\" >decide on the feasibility of building an ICBM large enough to carry a thermonuclear weapon</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row5_col0\" class=\"data row5 col0\" >What conferences became a requirement after Vatican II?</td>\n",
+ " <td id=\"T_fc111_row5_col1\" class=\"data row5 col1\" >National Bishop Conferences</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row6_col0\" class=\"data row6 col0\" >Who does M fight with?</td>\n",
+ " <td id=\"T_fc111_row6_col1\" class=\"data row6 col1\" >C</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row7_col0\" class=\"data row7 col0\" >How many species of fungi have been found on Antarctica?</td>\n",
+ " <td id=\"T_fc111_row7_col1\" class=\"data row7 col1\" >1150</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row8_col0\" class=\"data row8 col0\" >After losing the battle of Guilford Courthouse, Cornawallis moved his troops where?</td>\n",
+ " <td id=\"T_fc111_row8_col1\" class=\"data row8 col1\" >Virginia coastline</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_fc111_row9_col0\" class=\"data row9 col0\" >What is the Olympic Torch made from?</td>\n",
+ " <td id=\"T_fc111_row9_col1\" class=\"data row9 col1\" >aluminum.</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n"
+ ],
+ "text/plain": [
+ "<pandas.io.formats.style.Styler at 0x3afc43c80>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "np.random.seed(42)\n",
+ "arr = np.array(data.question_answer_pairs)\n",
+ "n_samples = 10\n",
+ "indices = np.random.choice(len(arr), n_samples, replace=False)\n",
+ "random_sample = arr[indices]\n",
+ "# Display the questions and answers in the random sample as a dataframe\n",
+ "dfSample = pd.DataFrame(random_sample, columns=[\"Question\", \"Answer\"])\n",
+ "display_text_df(dfSample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create the agent to be evaluated"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from agent import get_agent\n",
+ "agent = get_agent()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Run the agent on the random sample of questions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4bce5a5c2449435dbd058ed938db2a91",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/10 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from tqdm.notebook import tqdm\n",
+ "import logging\n",
+ "\n",
+ "answers_ref, answers_pred = [], []\n",
+ "\n",
+ "# Suppress logging from the agent, which can be quite verbose\n",
+ "agent.logger.setLevel(logging.CRITICAL)\n",
+ "\n",
+ "# Run the agent on each sampled question, collecting reference and predicted answers\n",
+ "for question, answer in tqdm(random_sample):\n",
+ "    answers_ref.append(answer)\n",
+ "    final_answer = agent.run(question, stream=False, reset=True)\n",
+ "    answers_pred.append(final_answer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use semantic similarity to evaluate the agent's answers against the reference answers\n",
+ "\n",
+ "* One flaw of this approach is that it does not take into account the existence of multiple acceptable answers.\n",
+ "* It also does not benefit from having the context of the question."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from semscore import EmbeddingModelWrapper\n",
+ "from statistics import mean\n",
+ "\n",
+ "answers_ref = [str(answer) for answer in answers_ref]\n",
+ "answers_pred = [str(answer) for answer in answers_pred]\n",
+ "\n",
+ "em = EmbeddingModelWrapper()\n",
+ "similarities = em.get_similarities(\n",
+ "    em.get_embeddings(answers_pred),\n",
+ "    em.get_embeddings(answers_ref),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style type=\"text/css\">\n",
+ "#T_67704 th {\n",
+ " text-align: left;\n",
+ "}\n",
+ "#T_67704 td {\n",
+ " text-align: left;\n",
+ "}\n",
+ "#T_67704_row0_col0, #T_67704_row0_col1, #T_67704_row0_col2, #T_67704_row0_col3, #T_67704_row1_col0, #T_67704_row1_col1, #T_67704_row1_col2, #T_67704_row1_col3, #T_67704_row2_col0, #T_67704_row2_col1, #T_67704_row2_col2, #T_67704_row2_col3, #T_67704_row3_col0, #T_67704_row3_col1, #T_67704_row3_col2, #T_67704_row3_col3, #T_67704_row4_col0, #T_67704_row4_col1, #T_67704_row4_col2, #T_67704_row4_col3, #T_67704_row5_col0, #T_67704_row5_col1, #T_67704_row5_col2, #T_67704_row5_col3, #T_67704_row6_col0, #T_67704_row6_col1, #T_67704_row6_col2, #T_67704_row6_col3, #T_67704_row7_col0, #T_67704_row7_col1, #T_67704_row7_col2, #T_67704_row7_col3, #T_67704_row8_col0, #T_67704_row8_col1, #T_67704_row8_col2, #T_67704_row8_col3, #T_67704_row9_col0, #T_67704_row9_col1, #T_67704_row9_col2, #T_67704_row9_col3 {\n",
+ " white-space: pre-wrap;\n",
+ "}\n",
+ "</style>\n",
+ "<table id=\"T_67704\">\n",
+ " <thead>\n",
+ " <tr>\n",
+ " <th id=\"T_67704_level0_col0\" class=\"col_heading level0 col0\" >Question</th>\n",
+ " <th id=\"T_67704_level0_col1\" class=\"col_heading level0 col1\" >Reference Answer</th>\n",
+ " <th id=\"T_67704_level0_col2\" class=\"col_heading level0 col2\" >Predicted Answer</th>\n",
+ " <th id=\"T_67704_level0_col3\" class=\"col_heading level0 col3\" >Similarity</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row0_col0\" class=\"data row0 col0\" >What year was the Banská Akadémia founded?</td>\n",
+ " <td id=\"T_67704_row0_col1\" class=\"data row0 col1\" >1735</td>\n",
+ " <td id=\"T_67704_row0_col2\" class=\"data row0 col2\" >1735</td>\n",
+ " <td id=\"T_67704_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row1_col0\" class=\"data row1 col0\" >What is another speed that can also be reported by the camera?</td>\n",
+ " <td id=\"T_67704_row1_col1\" class=\"data row1 col1\" >SOS-based speed</td>\n",
+ " <td id=\"T_67704_row1_col2\" class=\"data row1 col2\" >Average speed</td>\n",
+ " <td id=\"T_67704_row1_col3\" class=\"data row1 col3\" >0.433297</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row2_col0\" class=\"data row2 col0\" >Where were the use of advanced materials and techniques on display in Sumer?</td>\n",
+ " <td id=\"T_67704_row2_col1\" class=\"data row2 col1\" >Sumerian temples and palaces</td>\n",
+ " <td id=\"T_67704_row2_col2\" class=\"data row2 col2\" >Based on the information provided, it appears that the Sumerians developed and displayed advanced materials and techniques such as metrology, writing, and astronomy throughout their city-states. The specific locations where these advanced materials and techniques were on display are not explicitly mentioned.\n",
+ "\n",
+ "However, considering the context of the question, I would argue that the city-states of Sumer itself is the most relevant answer. The city-states of Sumer were the hub of Sumerian civilization, culture, and innovation, and it was likely there that these advanced materials and techniques were developed, displayed, and showcased.\n",
+ "\n",
+ "Therefore, my final answer to the user request is:\n",
+ "\n",
+ "The city-states of Sumer</td>\n",
+ " <td id=\"T_67704_row2_col3\" class=\"data row2 col3\" >0.545807</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row3_col0\" class=\"data row3 col0\" >Who is elected every even numbered year?</td>\n",
+ " <td id=\"T_67704_row3_col1\" class=\"data row3 col1\" >mayor</td>\n",
+ " <td id=\"T_67704_row3_col2\" class=\"data row3 col2\" >mayor</td>\n",
+ " <td id=\"T_67704_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row4_col0\" class=\"data row4 col0\" >What was the purpose of top secret ICBM committee?</td>\n",
+ " <td id=\"T_67704_row4_col1\" class=\"data row4 col1\" >decide on the feasibility of building an ICBM large enough to carry a thermonuclear weapon</td>\n",
+ " <td id=\"T_67704_row4_col2\" class=\"data row4 col2\" >decide on the feasibility of building an ICBM large enough to carry a thermonuclear weapon</td>\n",
+ " <td id=\"T_67704_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row5_col0\" class=\"data row5 col0\" >What conferences became a requirement after Vatican II?</td>\n",
+ " <td id=\"T_67704_row5_col1\" class=\"data row5 col1\" >National Bishop Conferences</td>\n",
+ " <td id=\"T_67704_row5_col2\" class=\"data row5 col2\" >['National Bishop Conferences']</td>\n",
+ " <td id=\"T_67704_row5_col3\" class=\"data row5 col3\" >0.937632</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row6_col0\" class=\"data row6 col0\" >Who does M fight with?</td>\n",
+ " <td id=\"T_67704_row6_col1\" class=\"data row6 col1\" >C</td>\n",
+ " <td id=\"T_67704_row6_col2\" class=\"data row6 col2\" >C</td>\n",
+ " <td id=\"T_67704_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row7_col0\" class=\"data row7 col0\" >How many species of fungi have been found on Antarctica?</td>\n",
+ " <td id=\"T_67704_row7_col1\" class=\"data row7 col1\" >1150</td>\n",
+ " <td id=\"T_67704_row7_col2\" class=\"data row7 col2\" >Based on the output from the `squad_retriever` tool, I can see that there are two documents in the SQuAD dataset that answer the question \"How many species of fungi have been found on Antarctica?\".\n",
+ "\n",
+ "The first document states that about 1150 species of fungi have been recorded from Antarctica. The second document does not provide a different answer to this question.\n",
+ "\n",
+ "Therefore, my final answer is:\n",
+ "\n",
+ "There are approximately 1150 species of fungi that have been found on Antarctica.</td>\n",
+ " <td id=\"T_67704_row7_col3\" class=\"data row7 col3\" >-0.020657</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row8_col0\" class=\"data row8 col0\" >After losing the battle of Guilford Courthouse, Cornawallis moved his troops where?</td>\n",
+ " <td id=\"T_67704_row8_col1\" class=\"data row8 col1\" >Virginia coastline</td>\n",
+ " <td id=\"T_67704_row8_col2\" class=\"data row8 col2\" >The Virginia coastline</td>\n",
+ " <td id=\"T_67704_row8_col3\" class=\"data row8 col3\" >0.948570</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td id=\"T_67704_row9_col0\" class=\"data row9 col0\" >What is the Olympic Torch made from?</td>\n",
+ " <td id=\"T_67704_row9_col1\" class=\"data row9 col1\" >aluminum.</td>\n",
+ " <td id=\"T_67704_row9_col2\" class=\"data row9 col2\" >aluminum</td>\n",
+ " <td id=\"T_67704_row9_col3\" class=\"data row9 col3\" >0.973508</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n"
+ ],
+ "text/plain": [
+ "<pandas.io.formats.style.Styler at 0x3b0db7320>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean similarity: 0.78\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "questions = [question for question, _ in random_sample]\n",
+ "dfAnswers = pd.DataFrame(list(zip(questions, answers_ref, answers_pred)), columns=[\"Question\", \"Reference Answer\", \"Predicted Answer\"])\n",
+ "dfAnswers[\"Similarity\"] = similarities\n",
+ "display(dfAnswers.style.set_properties(**{'white-space': 'pre-wrap'}).set_table_styles(\n",
+ "    [{'selector': 'th', 'props': [('text-align', 'left')]},\n",
+ "     {'selector': 'td', 'props': [('text-align', 'left')]}\n",
+ "    ]\n",
+ ").hide())\n",
+ "print(f\"Mean similarity: {round(mean(similarities), 2)}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "aai520",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
semscore.py ADDED
@@ -0,0 +1,167 @@
+ from transformers import AutoTokenizer, AutoModel
+ from accelerate import Accelerator
+ from accelerate.utils import gather_object
+ from tqdm import tqdm
+ import torch
+ import gc
+ import torch.nn as nn
+
+ class EmbeddingModelWrapper:
+     DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
+
+     def __init__(self, model_path=None, bs=8, device="mps"):
+         # device defaults to Apple-silicon "mps"; pass "cuda" or "cpu" elsewhere
+         if model_path is None:
+             model_path = self.DEFAULT_MODEL
+         self.device = device
+         self.model, self.tokenizer = self.load_model(model_path)
+         self.bs = bs
+         self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+
+     def load_model(self, model_path):
+         model = AutoModel.from_pretrained(model_path).to(self.device)
+         model.eval()
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         return model, tokenizer
+
+     def emb_mean_pooling(self, model_output, attention_mask):
+         # Mean-pool token embeddings, ignoring padding via the attention mask
+         token_embeddings = model_output[0]
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+     def get_embeddings(self, sentences):
+         embeddings = torch.tensor([], device=self.device)
+
+         if self.bs is None:
+             batches = [sentences]
+         else:
+             batches = [sentences[i:i + self.bs] for i in range(0, len(sentences), self.bs)]
+
+         for batch in batches:
+             encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
+             with torch.no_grad():
+                 model_output = self.model(**encoded_input)
+             batch_embeddings = self.emb_mean_pooling(model_output, encoded_input['attention_mask'])
+
+             embeddings = torch.cat((embeddings, batch_embeddings), dim=0)
+
+         return embeddings
+
+     def get_similarities(self, x, y=None):
+         if y is None:
+             # Pairwise similarities within x, filled as a lower-triangular matrix
+             num_samples = x.shape[0]
+             similarities = [[0 for _ in range(num_samples)] for _ in range(num_samples)]
+             for row in tqdm(range(num_samples)):
+                 similarities[row][0:row + 1] = self.cos(x[row].repeat(row + 1, 1), x[0:row + 1]).tolist()
+             return similarities
+         else:
+             # Element-wise similarity between matching rows of x and y
+             return self.cos(x, y).tolist()
+
+ class ModelPredictionGenerator:
+     def __init__(self, model, tokenizer, eval_dataset, use_accelerate=False, bs=8, generation_config=None):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.bs = bs
+         self.eval_prompts = self.messages_to_prompts(eval_dataset)
+         self.use_accelerate = use_accelerate
+         self.accelerator = Accelerator()
+
+         assert tokenizer.eos_token_id is not None
+         assert tokenizer.chat_template is not None
+         if tokenizer.pad_token_id is None:
+             tokenizer.pad_token_id = tokenizer.eos_token_id
+
+         # llama-precise defaults
+         if generation_config is None:
+             self.generation_config = {
+                 "temperature": 0.7,
+                 "top_p": 0.1,
+                 "repetition_penalty": 1.18,
+                 "top_k": 40,
+                 "do_sample": True,
+                 "max_new_tokens": 100,
+                 "pad_token_id": tokenizer.pad_token_id
+             }
+         else:
+             self.generation_config = generation_config
+
+     def clear_cache(self):
+         if torch.backends.mps.is_available():
+             torch.mps.empty_cache()
+         gc.collect()
+
+     def messages_to_prompts(self, ds):
+         # Build one prompt per user turn; assumes each user message is followed by the assistant reply
+         prompts = []
+         for conversation in ds["messages"]:
+             for i, msg in enumerate(conversation):
+                 if msg["role"] == "user":
+                     prompts.append(
+                         dict(
+                             # prompt: format the conversation up to the current user message and add a generation prompt
+                             prompt=self.tokenizer.apply_chat_template(conversation[:i + 1], add_generation_prompt=True, tokenize=False),
+                             answer_ref=conversation[i + 1]["content"]
+                         )
+                     )
+         return prompts
+
+     def get_batches(self, dataset, batch_size):
+         return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
+
+     def tokenize_batch(self, batch):
+         pad_side = self.tokenizer.padding_side
+         self.tokenizer.padding_side = "left"  # left-pad for inference
+
+         prompts = [item["prompt"] for item in batch]
+         prompts_tok = self.tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding='longest',
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_length=True,
+             pad_to_multiple_of=8,
+             add_special_tokens=False
+         ).to(self.model.device)
+         self.tokenizer.padding_side = pad_side  # restore original padding side
+
+         return prompts_tok
+
+     def generate_batch(self, batch_tok):
+         with torch.no_grad():
+             outputs_tok = self.model.generate(
+                 input_ids=batch_tok["input_ids"],
+                 attention_mask=batch_tok["attention_mask"],
+                 **self.generation_config
+             ).to("cpu")
+         outputs = [
+             # strip padding, then cut the prompt tokens from the output
+             self.tokenizer.decode(
+                 outputs_tok[i][outputs_tok[i] != self.tokenizer.pad_token_id][batch_tok["length"][i]:],
+                 spaces_between_special_tokens=False,
+                 skip_special_tokens=True
+             ).strip()
+             for i in range(len(outputs_tok))]
+
+         return outputs
+
+     def run(self):
+         self.model.eval()
+         self.clear_cache()
+
+         if self.use_accelerate:
+             # Shard the prompt indices across processes
+             with self.accelerator.split_between_processes(list(range(len(self.eval_prompts)))) as eval_prompts_local_idcs:
+                 eval_prompts_local = [self.eval_prompts[i] for i in eval_prompts_local_idcs]
+         else:
+             eval_prompts_local = self.eval_prompts
+
+         for batch in tqdm(self.get_batches(eval_prompts_local, self.bs)):
+             batch_tok = self.tokenize_batch(batch)
+             answers = self.generate_batch(batch_tok)
+
+             for i in range(len(batch)):
+                 batch[i]["answer_pred"] = answers[i]
+                 batch[i]["GPU"] = self.accelerator.process_index
+
+         if self.use_accelerate:
+             return gather_object(eval_prompts_local)
+         else:
+             return eval_prompts_local
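
For reference, a minimal usage sketch of EmbeddingModelWrapper as committed above, mirroring what the benchmarking notebook does; the two answer lists are hypothetical placeholders, not data from this commit:

from statistics import mean
from semscore import EmbeddingModelWrapper

# Hypothetical reference and predicted answers (placeholders for illustration only)
answers_ref = ["1735", "Saint Bernadette Soubirous"]
answers_pred = ["1735", "Bernadette Soubirous"]

em = EmbeddingModelWrapper()  # defaults to the "mps" device; pass device="cpu" or "cuda" elsewhere
similarities = em.get_similarities(
    em.get_embeddings(answers_pred),
    em.get_embeddings(answers_ref),
)
print(f"Mean similarity: {round(mean(similarities), 2)}")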