juancauma commited on
Commit
93126d2
·
1 Parent(s): 7be39b2

added the rank column

Browse files
Files changed (1) hide show
  1. app.py +249 -2
app.py CHANGED
@@ -1,3 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state):
2
  """
3
  Create the aggregated leaderboard DataFrame.
@@ -66,16 +146,183 @@ def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau
66
  })
67
 
68
  df = pd.DataFrame(aggregated)
69
- # Define allowed sorting columns.
70
  allowed_sort_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
71
  sort_by = sort_state.get("sort_by") if sort_state else None
72
  ascending = sort_state.get("ascending") if sort_state else True
73
  if sort_by in allowed_sort_cols:
74
  df = df.sort_values(sort_by, ascending=ascending)
75
-
76
  # Reset the index so the new ranking will reflect the sorted order.
77
  df.reset_index(drop=True, inplace=True)
78
  # Insert the Rank column as the first column, numbering from 1.
79
  df.insert(0, "Rank", range(1, len(df) + 1))
80
 
81
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import json
4
+ import os
5
+
6
+ def strip_timestamp(name):
7
+ """Remove the timestamp portion from the model name."""
8
+ parts = name.split('-')
9
+ return '-'.join(parts[1:]) if len(parts) > 1 else name
10
+
11
+ # Static grouping mapping for the 10 general submissions.
12
+ GROUPS = [
13
+ {
14
+ "mwoz": "20250214_193236-o1",
15
+ "tau_airline": "20250215_115156-tau-o1-airline",
16
+ "tau_retail": "20250215_121147-tau-o1-retail"
17
+ },
18
+ {
19
+ "mwoz": "20250131_012338-llama405",
20
+ "tau_airline": "20250204_144222-tau-llama-405b-airline",
21
+ "tau_retail": "20250205_033820-tau-llama405b-retail"
22
+ },
23
+ {
24
+ "mwoz": "20250130_140218-4o",
25
+ "tau_airline": "20250131_152503-tau-4o-airline",
26
+ "tau_retail": "20250131_152422-tau-4o-retail"
27
+ },
28
+ {
29
+ "mwoz": "20250130_183030-claude",
30
+ "tau_airline": "20250205_030422-tau-sonnet-airline",
31
+ "tau_retail": "20250131_152807-tau-sonnet-retail"
32
+ },
33
+ {
34
+ "mwoz": "20250131_012449-llama70",
35
+ "tau_airline": "20250208_024344-tau-llama70b-airline",
36
+ "tau_retail": "20250208_030407-tau-llama70b-retail"
37
+ },
38
+ {
39
+ "mwoz": "20250131_013711-qwen72b",
40
+ "tau_airline": "20250202_112945-qwen72b-airline",
41
+ "tau_retail": "20250202_140527-qwen72b-retail"
42
+ },
43
+ {
44
+ "mwoz": "20250130_184905-mistrallarge",
45
+ "tau_airline": "20250205_024823-tau-mistrallarge-airline",
46
+ "tau_retail": "20250205_044403-tau-mistrallarge-retail"
47
+ },
48
+ {
49
+ "mwoz": "20250131_010143-o1mini",
50
+ "tau_airline": "20250214_180731-tau-o1-mini-airline",
51
+ "tau_retail": "20250214_142736-tau-o1-mini-retail"
52
+ },
53
+ {
54
+ "mwoz": "20250130_140439-4omini",
55
+ "tau_airline": "20250131_152226-tau-4o-mini-airline",
56
+ "tau_retail": "20250131_152338-tau-4o-mini-retail"
57
+ },
58
+ {
59
+ "mwoz": "20250130_145202-gpt35",
60
+ "tau_airline": "20250131_152708-tau-gpt35-airline",
61
+ "tau_retail": "20250131_152610-tau-gpt35-retail"
62
+ }
63
+ ]
64
+
65
+ def load_mwoz_results():
66
+ """Load mwoz results from data/mwoz_leaderboard_results.json."""
67
+ path = os.path.join("data", "mwoz_leaderboard_results.json")
68
+ if not os.path.exists(path):
69
+ return []
70
+ with open(path, "r") as f:
71
+ return json.load(f)
72
+
73
+ def load_tau_results():
74
+ """Load tau results from data/tau_leaderboard_results.json."""
75
+ path = os.path.join("data", "tau_leaderboard_results.json")
76
+ if not os.path.exists(path):
77
+ return []
78
+ with open(path, "r") as f:
79
+ return json.load(f)
80
+
81
  def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state):
82
  """
83
  Create the aggregated leaderboard DataFrame.
 
146
  })
147
 
148
  df = pd.DataFrame(aggregated)
149
+ # Sort if a valid column is provided.
150
  allowed_sort_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
151
  sort_by = sort_state.get("sort_by") if sort_state else None
152
  ascending = sort_state.get("ascending") if sort_state else True
153
  if sort_by in allowed_sort_cols:
154
  df = df.sort_values(sort_by, ascending=ascending)
155
+
156
  # Reset the index so the new ranking will reflect the sorted order.
157
  df.reset_index(drop=True, inplace=True)
158
  # Insert the Rank column as the first column, numbering from 1.
159
  df.insert(0, "Rank", range(1, len(df) + 1))
160
 
161
  return df
162
+
163
+ def update_sort_state(current_state, clicked_column):
164
+ """
165
+ Update the sort state based on the clicked column.
166
+ If the same column is clicked, toggle the sort order;
167
+ otherwise, switch to the new column with ascending order.
168
+ """
169
+ if current_state is None:
170
+ current_state = {"sort_by": clicked_column, "ascending": True}
171
+ else:
172
+ if current_state.get("sort_by") == clicked_column:
173
+ current_state["ascending"] = not current_state.get("ascending", True)
174
+ else:
175
+ current_state["sort_by"] = clicked_column
176
+ current_state["ascending"] = True
177
+ return current_state
178
+
179
+ def sort_by_avg(sort_state):
180
+ return update_sort_state(sort_state, "Average Score")
181
+
182
+ def sort_by_conv(sort_state):
183
+ return update_sort_state(sort_state, "Conversation Consistency")
184
+
185
+ def sort_by_backend(sort_state):
186
+ return update_sort_state(sort_state, "Backend Consistency")
187
+
188
+ def sort_by_policy(sort_state):
189
+ return update_sort_state(sort_state, "Policy Completeness")
190
+
191
+ def get_color_for_value(value, min_val, max_val):
192
+ """
193
+ Compute a color for a given value based on its normalized position.
194
+ Interpolates from red (lowest) to yellow (mid) to green (highest).
195
+ """
196
+ if max_val == min_val:
197
+ norm = 0.5
198
+ else:
199
+ norm = (value - min_val) / (max_val - min_val)
200
+ if norm < 0.5:
201
+ ratio = norm / 0.5
202
+ r = 255
203
+ g = int(255 * ratio)
204
+ b = 0
205
+ else:
206
+ ratio = (norm - 0.5) / 0.5
207
+ r = int(255 * (1 - ratio))
208
+ g = 255
209
+ b = 0
210
+ return f"#{r:02X}{g:02X}{b:02X}"
211
+
212
+ def generate_html_table(df):
213
+ """
214
+ Generate an HTML table from the DataFrame.
215
+ For each numeric column, apply a text color based on its relative value.
216
+ """
217
+ numeric_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
218
+ col_min = {}
219
+ col_max = {}
220
+ for col in numeric_cols:
221
+ col_min[col] = df[col].min() if not df.empty else 0
222
+ col_max[col] = df[col].max() if not df.empty else 0
223
+
224
+ html = "<table border='1' style='border-collapse: collapse; text-align: center; width: 100%;'>"
225
+ # Header row
226
+ html += "<tr>"
227
+ for col in df.columns:
228
+ html += f"<th style='padding: 8px;'>{col}</th>"
229
+ html += "</tr>"
230
+
231
+ # Data rows
232
+ for _, row in df.iterrows():
233
+ html += "<tr>"
234
+ for col in df.columns:
235
+ cell_value = row[col]
236
+ if col in numeric_cols:
237
+ color = get_color_for_value(cell_value, col_min[col], col_max[col])
238
+ # Now applying the color to the text (color property) instead of background.
239
+ html += f"<td style='padding: 8px; color: {color};'>{cell_value}</td>"
240
+ else:
241
+ html += f"<td style='padding: 8px;'>{cell_value}</td>"
242
+ html += "</tr>"
243
+ html += "</table>"
244
+ return html
245
+
246
+ def update_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state):
247
+ """
248
+ Update the leaderboard by creating the aggregated DataFrame and converting it to HTML.
249
+ """
250
+ df = create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state)
251
+ html_table = generate_html_table(df)
252
+ return html_table
253
+
254
+ with gr.Blocks(title="TD-EVAL Leaderboard") as demo:
255
+ gr.Markdown("# 🏆 TD-EVAL Model Evaluation Leaderboard")
256
+ gr.Markdown("""
257
+ This leaderboard displays aggregated model performance across multiple evaluation metrics.
258
+
259
+ **Variants:**
260
+ - **mwoz:** Baseline variant.
261
+ - **tau-airline:** Airline specialty variant.
262
+ - **tau-retail:** Retail specialty variant.
263
+
264
+ Use the checkboxes below to select which variants to include. At least one variant must be active.
265
+ """)
266
+
267
+ with gr.Row():
268
+ cb_mwoz = gr.Checkbox(label="mwoz", value=True)
269
+ cb_tau_airline = gr.Checkbox(label="tau-airline", value=True)
270
+ cb_tau_retail = gr.Checkbox(label="tau-retail", value=True)
271
+
272
+ gr.Markdown("### Sort by (click a button to toggle ascending/descending):")
273
+ with gr.Row():
274
+ btn_avg = gr.Button("Average Score")
275
+ btn_conv = gr.Button("Conversation Consistency")
276
+ btn_backend = gr.Button("Backend Consistency")
277
+ btn_policy = gr.Button("Policy Completeness")
278
+
279
+ # Initialize sort state: default sort by Average Score descending.
280
+ sort_state = gr.State({"sort_by": "Average Score", "ascending": False})
281
+
282
+ leaderboard_display = gr.HTML(label="Aggregated Model Rankings")
283
+
284
+ refresh_btn = gr.Button("🔄 Refresh Leaderboard")
285
+
286
+ # Sort button events.
287
+ btn_avg.click(fn=sort_by_avg, inputs=[sort_state], outputs=[sort_state]).then(
288
+ fn=update_leaderboard,
289
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
290
+ outputs=leaderboard_display
291
+ )
292
+ btn_conv.click(fn=sort_by_conv, inputs=[sort_state], outputs=[sort_state]).then(
293
+ fn=update_leaderboard,
294
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
295
+ outputs=leaderboard_display
296
+ )
297
+ btn_backend.click(fn=sort_by_backend, inputs=[sort_state], outputs=[sort_state]).then(
298
+ fn=update_leaderboard,
299
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
300
+ outputs=leaderboard_display
301
+ )
302
+ btn_policy.click(fn=sort_by_policy, inputs=[sort_state], outputs=[sort_state]).then(
303
+ fn=update_leaderboard,
304
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
305
+ outputs=leaderboard_display
306
+ )
307
+
308
+ # Refresh button event.
309
+ refresh_btn.click(
310
+ fn=update_leaderboard,
311
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
312
+ outputs=leaderboard_display
313
+ )
314
+
315
+ # Update leaderboard immediately when any checkbox changes.
316
+ cb_mwoz.change(fn=update_leaderboard, inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state], outputs=leaderboard_display)
317
+ cb_tau_airline.change(fn=update_leaderboard, inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state], outputs=leaderboard_display)
318
+ cb_tau_retail.change(fn=update_leaderboard, inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state], outputs=leaderboard_display)
319
+
320
+ # Load initial leaderboard on app start.
321
+ demo.load(
322
+ fn=update_leaderboard,
323
+ inputs=[cb_mwoz, cb_tau_airline, cb_tau_retail, sort_state],
324
+ outputs=leaderboard_display
325
+ )
326
+
327
+ if __name__ == "__main__":
328
+ demo.launch()