import json
import os
import warnings

import gradio as gr
import pandas as pd
from gradio.themes.utils import colors

# Load external CSS from "styles.css" (a path relative to the current working
# directory). UTF-8 is tried first; latin-1 can decode any byte sequence, so it
# is a safe fallback. A missing stylesheet is not fatal: the app still runs with
# the theme's default styling, and a warning records the problem.
try:
    with open("styles.css", "r", encoding="utf-8") as f:
        custom_css = f.read()
except UnicodeDecodeError:
    # Try with a different encoding if utf-8 fails
    with open("styles.css", "r", encoding="latin-1") as f:
        custom_css = f.read()
except FileNotFoundError:
    warnings.warn("styles.css not found; continuing without custom CSS")
    custom_css = ""

# More specific selectors plus !important so these rules win the cascade over
# Gradio's own block styles.
# NOTE(review): these dark (#27272A) rules target the *classes* .checkbox-panel
# and .search-panel, while the CSS appended further below gives the *ids*
# #checkbox-panel / #search-panel a light (#F0F0F0) background. On an element
# carrying both, the id rule wins (higher specificity), so the dark colour never
# shows there. Confirm which colour is intended and drop the other rule.
additional_css = """
.gradio-container .checkbox-panel,
div.gradio-container [class*="block"] .checkbox-panel {
    background-color: #27272A !important;
}

.gradio-container .search-panel,
div.gradio-container [class*="block"] .search-panel {
    background-color: #27272A !important;
}
"""
custom_css += additional_css


class CustomTheme(gr.themes.Base):
    """Gradio theme with a red primary hue, gray secondary/neutral hues and large
    text. Blocks have no border and no shadow. No background colours are set
    here; panel backgrounds come from the custom CSS."""

    def __init__(self):
        super().__init__(
            primary_hue=colors.red,
            secondary_hue=colors.gray,
            neutral_hue=colors.gray,
            text_size=gr.themes.sizes.text_lg,
        )
        # Don't set any global background colors
        self.block_border_width = "0px"
        self.block_shadow = "none"


# Panel-specific overrides, scoped by element id so they don't leak into the
# rest of the page.
custom_css += """
/* Only override specific panels by ID */
#checkbox-panel, #search-panel {
    background-color: #F0F0F0 !important;
}

/* Only affect immediate children of these specific panels */
#checkbox-panel > *, #search-panel > * {
    background-color: transparent !important;
}

/* Target checkbox inputs specifically */
#checkbox-panel input[type="checkbox"], #search-panel input[type="text"] {
    background-color: transparent !important;
}
"""


def strip_timestamp(name):
    """Remove the timestamp portion from the model name.

    Everything up to and including the first "-" is dropped, e.g.
    "20250214_193236-o1" -> "o1" and
    "20250215_115156-tau-o1-airline" -> "tau-o1-airline".
    A name without any "-" is returned unchanged.
    """
    _, sep, rest = name.partition("-")
    return rest if sep else name


# Static grouping mapping for the 10 general submissions.
# Each entry maps a result source ("mwoz", "tau_airline", "tau_retail") to the
# ``model_name`` of one submission; the three names in an entry refer to the
# same base model (e.g. o1 / tau-o1-airline / tau-o1-retail).
GROUPS = [
    {
        "mwoz": "20250214_193236-o1",
        "tau_airline": "20250215_115156-tau-o1-airline",
        "tau_retail": "20250215_121147-tau-o1-retail",
    },
    {
        "mwoz": "20250131_012338-llama405",
        "tau_airline": "20250204_144222-tau-llama-405b-airline",
        "tau_retail": "20250205_033820-tau-llama405b-retail",
    },
    {
        "mwoz": "20250130_140218-4o",
        "tau_airline": "20250131_152503-tau-4o-airline",
        "tau_retail": "20250131_152422-tau-4o-retail",
    },
    {
        "mwoz": "20250130_183030-claude",
        "tau_airline": "20250205_030422-tau-sonnet-airline",
        "tau_retail": "20250131_152807-tau-sonnet-retail",
    },
    {
        "mwoz": "20250131_012449-llama70",
        "tau_airline": "20250208_024344-tau-llama70b-airline",
        "tau_retail": "20250208_030407-tau-llama70b-retail",
    },
    {
        "mwoz": "20250131_013711-qwen72b",
        "tau_airline": "20250202_112945-qwen72b-airline",
        "tau_retail": "20250202_140527-qwen72b-retail",
    },
    {
        "mwoz": "20250130_184905-mistrallarge",
        "tau_airline": "20250205_024823-tau-mistrallarge-airline",
        "tau_retail": "20250205_044403-tau-mistrallarge-retail",
    },
    {
        "mwoz": "20250131_010143-o1mini",
        "tau_airline": "20250214_180731-tau-o1-mini-airline",
        "tau_retail": "20250214_142736-tau-o1-mini-retail",
    },
    {
        "mwoz": "20250130_140439-4omini",
        "tau_airline": "20250131_152226-tau-4o-mini-airline",
        "tau_retail": "20250131_152338-tau-4o-mini-retail",
    },
    {
        "mwoz": "20250130_145202-gpt35",
        "tau_airline": "20250131_152708-tau-gpt35-airline",
        "tau_retail": "20250131_152610-tau-gpt35-retail",
    },
]

# Per-record metric fields that are averaged across the selected sources.
_METRIC_KEYS = (
    "avg_conv_consistency",
    "avg_backend_consistency",
    "avg_policy_completeness",
)
# Columns the leaderboard may be sorted by; anything else falls back to the
# default sort (Average Score, descending).
_SORTABLE_COLUMNS = (
    "Average Score",
    "Conversation Consistency",
    "Backend Consistency",
    "Policy Completeness",
)
# Output column order of create_grouped_leaderboard (Rank first).
_LEADERBOARD_COLUMNS = [
    "Rank",
    "Model",
    "Average Score",
    "Conversation Consistency",
    "Backend Consistency",
    "Policy Completeness",
    "Judge Model",
]


def _load_results(filename):
    """Return the JSON document stored in ``data/<filename>``.

    A missing file yields ``[]`` so the leaderboard can render before any
    results exist. Malformed JSON is not swallowed: ``json.load`` raises.
    """
    path = os.path.join("data", filename)
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_mwoz_results():
    """Load MultiWOZ leaderboard records from data/mwoz_leaderboard_results.json ([] if absent)."""
    return _load_results("mwoz_leaderboard_results.json")


def load_tau_results():
    """Load tau-bench leaderboard records from data/tau_leaderboard_results.json ([] if absent)."""
    return _load_results("tau_leaderboard_results.json")


def _decode_sort_state(state):
    """Return *state* as a dict, decoding it first if it is a JSON string.

    Returns None when *state* is not a dict (or JSON that decodes to one).
    """
    if isinstance(state, str):
        try:
            state = json.loads(state)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
    return state if isinstance(state, dict) else None


def _resolve_sort(sort_state):
    """Return ``(column, ascending)`` for *sort_state*.

    Unparseable states, and states naming a column outside _SORTABLE_COLUMNS,
    fall back to ("Average Score", False).
    """
    state = _decode_sort_state(sort_state)
    if state is None:
        return "Average Score", False
    sort_by = state.get("sort_by", "Average Score")
    if sort_by not in _SORTABLE_COLUMNS:
        return "Average Score", False
    return sort_by, state.get("ascending", False)


def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query=""):
    """Build the grouped leaderboard as a DataFrame.

    Each GROUPS entry becomes one row whose metrics are the mean over the
    selected sources that have a result for that entry. If no source is
    selected, MultiWOZ is used.

    Args:
        selected_mwoz, selected_tau_airline, selected_tau_retail: whether to
            include each source (truthiness is used).
        sort_state: dict like {"sort_by": <column>, "ascending": <bool>}, its
            JSON encoding, or anything else for the default sort
            (Average Score, descending).
        search_query: case-insensitive substring filter on the Model name.

    Returns:
        DataFrame with the columns in _LEADERBOARD_COLUMNS. Rank is 1 for the
        highest Average Score (method="min", so ties share a rank) and is
        computed after the search filter. "Judge Model" comes from the last
        selected source that had a result. Groups with no result in any
        selected source are omitted. If nothing remains, an empty frame with
        those columns is returned.
    """
    if not (selected_mwoz or selected_tau_airline or selected_tau_retail):
        selected_mwoz = True

    mwoz_lookup = {entry["model_name"]: entry for entry in load_mwoz_results()}
    tau_lookup = {entry["model_name"]: entry for entry in load_tau_results()}

    # Fixed source order: it determines the order of the name parts and which
    # record's judge_model is reported (the last one found).
    sources = (
        (selected_mwoz, "mwoz", mwoz_lookup),
        (selected_tau_airline, "tau_airline", tau_lookup),
        (selected_tau_retail, "tau_retail", tau_lookup),
    )

    rows = []
    for group in GROUPS:
        records = []
        title_parts = []
        for enabled, source, lookup in sources:
            key = group[source]
            if enabled and key in lookup:
                records.append(lookup[key])
                title_parts.append(strip_timestamp(key))

        if not records:
            # No selected source has a result for this group; emitting it would
            # add a blank-named, all-zero row to the leaderboard.
            continue

        model_name = " / ".join(title_parts)
        if search_query and search_query.lower() not in model_name.lower():
            continue

        count = len(records)
        avg_conv, avg_backend, avg_policy = (
            sum(record.get(metric, 0) for record in records) / count
            for metric in _METRIC_KEYS
        )
        overall_avg = (avg_conv + avg_backend + avg_policy) / 3

        rows.append({
            "Model": model_name,
            "Average Score": round(overall_avg, 4),
            "Conversation Consistency": round(avg_conv, 4),
            "Backend Consistency": round(avg_backend, 4),
            "Policy Completeness": round(avg_policy, 4),
            "Judge Model": records[-1].get("judge_model", ""),
        })

    if not rows:
        return pd.DataFrame(columns=_LEADERBOARD_COLUMNS)

    df = pd.DataFrame(rows)
    df["Rank"] = df["Average Score"].rank(ascending=False, method="min").astype(int)

    sort_by, ascending = _resolve_sort(sort_state)
    # Stable sort keeps the GROUPS order among ties.
    df = df.sort_values(sort_by, ascending=ascending, kind="mergesort")

    # Move Rank to the front.
    return df[["Rank"] + [col for col in df.columns if col != "Rank"]]


def update_sort_state(current_state, clicked_column):
    """
    Update the sorting state based on the clicked column.

    *current_state* may be a dict, a JSON string encoding one, or anything
    else (treated as "no previous state"). Clicking the column that is already
    the sort column toggles the direction; any other case sorts by
    *clicked_column* descending.

    Returns:
        {"sort_by": clicked_column, "ascending": <bool>}
    """
    state = _decode_sort_state(current_state)
    toggling = (
        state is not None
        and "sort_by" in state
        and state["sort_by"] == clicked_column
    )
    ascending = (not state.get("ascending", False)) if toggling else False
    return {"sort_by": clicked_column, "ascending": ascending}


def sort_by_avg(sort_state):
    """Sort state after clicking the "Average Score" header."""
    return update_sort_state(sort_state, "Average Score")


def sort_by_conv(sort_state):
    """Sort state after clicking the "Conversation Consistency" header."""
    return update_sort_state(sort_state, "Conversation Consistency")


def sort_by_backend(sort_state):
    """Sort state after clicking the "Backend Consistency" header."""
    return update_sort_state(sort_state, "Backend Consistency")


def sort_by_policy(sort_state):
    """Sort state after clicking the "Policy Completeness" header."""
    return update_sort_state(sort_state, "Policy Completeness")


def get_color_for_value(value, min_val, max_val):
    """Map *value* within [min_val, max_val] to a red-to-green hex colour.

    The low end is "#B40000" (dark red), the midpoint "#8C8C00" and the high
    end "#008C00" (dark green). Values outside the range are clamped to the
    nearest end. If min_val == max_val, the midpoint colour is returned.
    """
    if max_val == min_val:
        norm = 0.5
    else:
        norm = (value - min_val) / (max_val - min_val)
    # Without clamping, out-of-range values give negative channel values and
    # an invalid colour string such as "#-118...".
    norm = min(max(norm, 0.0), 1.0)

    if norm < 0.5:
        # Darker red for lower values
        red, green = 180, int(140 * (norm / 0.5))
    else:
        # Darker green for higher values
        red, green = int(140 * (1 - (norm - 0.5) / 0.5)), 140
    return f"#{red:02X}{green:02X}00"


# NOTE(review): generate_html_table is corrupted in this copy of the file. Its
# HTML string literals were stripped, leaving only the unparseable fragments
# that follow, so its body cannot be reconstructed here. The surviving first
# line of the definition was:
#     def generate_html_table(df): if df.empty: return "
# Restore the full function (and the UI code after it) from version control.
{col} | " html += "|
---|---|
{cell_value} | " else: html += f"{cell_value} | " html += "
Sorted by: {sort_col} {sort_dir}
Error in sorting. Using default sort: Average Score (descending)
Use the checkboxes below to select which variants to include. At least one variant must be active.