import gradio as gr import pandas as pd import json import os def strip_timestamp(name): """Remove the timestamp portion from the model name.""" parts = name.split('-') return '-'.join(parts[1:]) if len(parts) > 1 else name # Static grouping mapping for the 10 general submissions. GROUPS = [ { "mwoz": "20250214_193236-o1", "tau_airline": "20250215_115156-tau-o1-airline", "tau_retail": "20250215_121147-tau-o1-retail" }, { "mwoz": "20250131_012338-llama405", "tau_airline": "20250204_144222-tau-llama-405b-airline", "tau_retail": "20250205_033820-tau-llama405b-retail" }, { "mwoz": "20250130_140218-4o", "tau_airline": "20250131_152503-tau-4o-airline", "tau_retail": "20250131_152422-tau-4o-retail" }, { "mwoz": "20250130_183030-claude", "tau_airline": "20250205_030422-tau-sonnet-airline", "tau_retail": "20250131_152807-tau-sonnet-retail" }, { "mwoz": "20250131_012449-llama70", "tau_airline": "20250208_024344-tau-llama70b-airline", "tau_retail": "20250208_030407-tau-llama70b-retail" }, { "mwoz": "20250131_013711-qwen72b", "tau_airline": "20250202_112945-qwen72b-airline", "tau_retail": "20250202_140527-qwen72b-retail" }, { "mwoz": "20250130_184905-mistrallarge", "tau_airline": "20250205_024823-tau-mistrallarge-airline", "tau_retail": "20250205_044403-tau-mistrallarge-retail" }, { "mwoz": "20250131_010143-o1mini", "tau_airline": "20250214_180731-tau-o1-mini-airline", "tau_retail": "20250214_142736-tau-o1-mini-retail" }, { "mwoz": "20250130_140439-4omini", "tau_airline": "20250131_152226-tau-4o-mini-airline", "tau_retail": "20250131_152338-tau-4o-mini-retail" }, { "mwoz": "20250130_145202-gpt35", "tau_airline": "20250131_152708-tau-gpt35-airline", "tau_retail": "20250131_152610-tau-gpt35-retail" } ] def load_mwoz_results(): """Load mwoz results from data/mwoz_leaderboard_results.json.""" path = os.path.join("data", "mwoz_leaderboard_results.json") if not os.path.exists(path): return [] with open(path, "r") as f: return json.load(f) def load_tau_results(): """Load tau results from data/tau_leaderboard_results.json.""" path = os.path.join("data", "tau_leaderboard_results.json") if not os.path.exists(path): return [] with open(path, "r") as f: return json.load(f) def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state): """ Create the aggregated leaderboard DataFrame. Aggregates metrics based on the selected variants and sorts the DataFrame using sort_state. """ # Ensure at least one variant is active. if not (selected_mwoz or selected_tau_airline or selected_tau_retail): selected_mwoz = True mwoz_data = load_mwoz_results() tau_data = load_tau_results() mwoz_lookup = {entry["model_name"]: entry for entry in mwoz_data} tau_lookup = {entry["model_name"]: entry for entry in tau_data} aggregated = [] for group in GROUPS: metrics = {"avg_conv_consistency": 0, "avg_backend_consistency": 0, "avg_policy_completeness": 0} count = 0 title_parts = [] judge_model = "" if selected_mwoz: key = group["mwoz"] if key in mwoz_lookup: record = mwoz_lookup[key] metrics["avg_conv_consistency"] += record.get("avg_conv_consistency", 0) metrics["avg_backend_consistency"] += record.get("avg_backend_consistency", 0) metrics["avg_policy_completeness"] += record.get("avg_policy_completeness", 0) count += 1 title_parts.append(strip_timestamp(key)) judge_model = record.get("judge_model", "") if selected_tau_airline: key = group["tau_airline"] if key in tau_lookup: record = tau_lookup[key] metrics["avg_conv_consistency"] += record.get("avg_conv_consistency", 0) metrics["avg_backend_consistency"] += record.get("avg_backend_consistency", 0) metrics["avg_policy_completeness"] += record.get("avg_policy_completeness", 0) count += 1 title_parts.append(strip_timestamp(key)) judge_model = record.get("judge_model", "") if selected_tau_retail: key = group["tau_retail"] if key in tau_lookup: record = tau_lookup[key] metrics["avg_conv_consistency"] += record.get("avg_conv_consistency", 0) metrics["avg_backend_consistency"] += record.get("avg_backend_consistency", 0) metrics["avg_policy_completeness"] += record.get("avg_policy_completeness", 0) count += 1 title_parts.append(strip_timestamp(key)) judge_model = record.get("judge_model", "") if count > 0: avg_conv = metrics["avg_conv_consistency"] / count avg_backend = metrics["avg_backend_consistency"] / count avg_policy = metrics["avg_policy_completeness"] / count overall_avg = (avg_conv + avg_backend + avg_policy) / 3 else: avg_conv = avg_backend = avg_policy = overall_avg = 0 aggregated.append({ "Model": " / ".join(title_parts), "Average Score": round(overall_avg, 4), "Conversation Consistency": round(avg_conv, 4), "Backend Consistency": round(avg_backend, 4), "Policy Completeness": round(avg_policy, 4), "Judge Model": judge_model }) df = pd.DataFrame(aggregated) # Sort if a valid column is provided. allowed_sort_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"] sort_by = sort_state.get("sort_by") if sort_state else None ascending = sort_state.get("ascending") if sort_state else True if sort_by in allowed_sort_cols: df = df.sort_values(sort_by, ascending=ascending) return df def update_sort_state(current_state, clicked_column): """ Update the sort state based on the clicked column. If the same column is clicked, toggle the sort order; otherwise, switch to the new column with ascending order. """ if current_state is None: current_state = {"sort_by": clicked_column, "ascending": True} else: if current_state.get("sort_by") == clicked_column: current_state["ascending"] = not current_state.get("ascending", True) else: current_state["sort_by"] = clicked_column current_state["ascending"] = True return current_state def sort_by_avg(sort_state): return update_sort_state(sort_state, "Average Score") def sort_by_conv(sort_state): return update_sort_state(sort_state, "Conversation Consistency") def sort_by_backend(sort_state): return update_sort_state(sort_state, "Backend Consistency") def sort_by_policy(sort_state): return update_sort_state(sort_state, "Policy Completeness") def get_color_for_value(value, min_val, max_val): """ Compute a color for a given value based on its normalized position. Interpolates from red (lowest) to yellow (mid) to green (highest). """ if max_val == min_val: norm = 0.5 else: norm = (value - min_val) / (max_val - min_val) if norm < 0.5: ratio = norm / 0.5 r = 255 g = int(255 * ratio) b = 0 else: ratio = (norm - 0.5) / 0.5 r = int(255 * (1 - ratio)) g = 255 b = 0 return f"#{r:02X}{g:02X}{b:02X}" def generate_html_table(df): """ Generate an HTML table from the DataFrame. For each numeric column, apply a text color based on its relative value. """ numeric_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"] col_min = {} col_max = {} for col in numeric_cols: col_min[col] = df[col].min() if not df.empty else 0 col_max[col] = df[col].max() if not df.empty else 0 html = "
{col} | " html += "|
---|---|
{cell_value} | " else: html += f"{cell_value} | " html += "