Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import json | |
import os | |
from gradio.themes.utils import colors | |
# Load external CSS from "styles.css" (resolved relative to the working directory,
# like the data/ paths used below). A missing stylesheet must not stop the app
# from starting, so fall back to no custom CSS instead of crashing at import time.
try:
    with open("styles.css", "r", encoding="utf-8") as f:
        custom_css = f.read()
except UnicodeDecodeError:
    # Not valid UTF-8: latin-1 maps every byte to a character, so this re-read
    # cannot fail to decode.
    with open("styles.css", "r", encoding="latin-1") as f:
        custom_css = f.read()
except FileNotFoundError:
    print("Warning: styles.css not found; continuing without custom CSS.")
    custom_css = ""

# Add more specific selectors for Gradio and use !important to win the cascade.
# NOTE: the ID-based rules (#checkbox-panel / #search-panel, #F0F0F0) appended
# further down are more specific than these class-based ones and take precedence
# on the same rows.
additional_css = """
.gradio-container .checkbox-panel,
div.gradio-container [class*="block"] .checkbox-panel {
background-color: #27272A !important;
}
.gradio-container .search-panel,
div.gradio-container [class*="block"] .search-panel {
background-color: #27272A !important;
}
"""
custom_css += additional_css
class CustomTheme(gr.themes.Base):
    """Leaderboard theme: red primary hue, gray secondary/neutral hues, large text.

    Blocks are drawn flat (no border, no shadow). No global background colours
    are set here; panel backgrounds are left to the custom CSS.
    """

    def __init__(self):
        palette = dict(
            primary_hue=colors.red,
            secondary_hue=colors.gray,
            neutral_hue=colors.gray,
            text_size=gr.themes.sizes.text_lg,
        )
        super().__init__(**palette)
        # Flat look for every block.
        self.block_shadow = "none"
        self.block_border_width = "0px"
# Panel styling keyed on the element IDs given to the checkbox and search rows
# (elem_id="checkbox-panel" / "search-panel"). An ID selector is more specific
# than the class-based .checkbox-panel/.search-panel rules appended earlier, and
# both sets are !important, so these light #F0F0F0 backgrounds are the ones that
# take effect on those rows.
custom_css += """
/* Only override specific panels by ID */
#checkbox-panel,
#search-panel {
background-color: #F0F0F0 !important;
}
/* Only affect immediate children of these specific panels */
#checkbox-panel > *,
#search-panel > * {
background-color: transparent !important;
}
/* Target checkbox inputs specifically */
#checkbox-panel input[type="checkbox"],
#search-panel input[type="text"] {
background-color: transparent !important;
}
"""
def strip_timestamp(name):
    """Drop everything up to and including the first '-' of a run name.

    Run names look like "<YYYYMMDD_HHMMSS>-<label>", so this leaves just the
    label (e.g. "20250214_193236-o1" -> "o1"). A name without any '-' is
    returned unchanged.
    """
    _prefix, separator, remainder = name.partition("-")
    return remainder if separator else name
# Static grouping of the 10 general submissions. Each row lists, in order, the
# run names for the mwoz, tau-airline and tau-retail variants that are combined
# into one leaderboard entry (judging by the names, one row per model). These
# strings must match the "model_name" values in the data/*.json result files.
GROUPS = [
    dict(zip(("mwoz", "tau_airline", "tau_retail"), runs))
    for runs in (
        ("20250214_193236-o1", "20250215_115156-tau-o1-airline", "20250215_121147-tau-o1-retail"),
        ("20250131_012338-llama405", "20250204_144222-tau-llama-405b-airline", "20250205_033820-tau-llama405b-retail"),
        ("20250130_140218-4o", "20250131_152503-tau-4o-airline", "20250131_152422-tau-4o-retail"),
        ("20250130_183030-claude", "20250205_030422-tau-sonnet-airline", "20250131_152807-tau-sonnet-retail"),
        ("20250131_012449-llama70", "20250208_024344-tau-llama70b-airline", "20250208_030407-tau-llama70b-retail"),
        ("20250131_013711-qwen72b", "20250202_112945-qwen72b-airline", "20250202_140527-qwen72b-retail"),
        ("20250130_184905-mistrallarge", "20250205_024823-tau-mistrallarge-airline", "20250205_044403-tau-mistrallarge-retail"),
        ("20250131_010143-o1mini", "20250214_180731-tau-o1-mini-airline", "20250214_142736-tau-o1-mini-retail"),
        ("20250130_140439-4omini", "20250131_152226-tau-4o-mini-airline", "20250131_152338-tau-4o-mini-retail"),
        ("20250130_145202-gpt35", "20250131_152708-tau-gpt35-airline", "20250131_152610-tau-gpt35-retail"),
    )
]
def load_mwoz_results():
    """Load the mwoz leaderboard records from data/mwoz_leaderboard_results.json.

    The path is resolved relative to the current working directory.

    Returns:
        The parsed JSON content (expected: a list of result dicts, each with a
        "model_name" key), or [] when the file does not exist.
    """
    path = os.path.join("data", "mwoz_leaderboard_results.json")
    if not os.path.exists(path):
        return []
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_tau_results():
    """Load the tau (airline + retail) leaderboard records from data/tau_leaderboard_results.json.

    The path is resolved relative to the current working directory.

    Returns:
        The parsed JSON content (expected: a list of result dicts, each with a
        "model_name" key), or [] when the file does not exist.
    """
    path = os.path.join("data", "tau_leaderboard_results.json")
    if not os.path.exists(path):
        return []
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Per-record metric keys that are averaged across the selected variants.
_METRIC_KEYS = ("avg_conv_consistency", "avg_backend_consistency", "avg_policy_completeness")


def _parse_sort_state(sort_state):
    """Coerce sort_state (a dict or its JSON string) to a dict; default to Average Score descending."""
    default_state = {"sort_by": "Average Score", "ascending": False}
    if isinstance(sort_state, str):
        try:
            sort_state = json.loads(sort_state)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return default_state
    return sort_state if isinstance(sort_state, dict) else default_state


def _aggregate_group(group, sources):
    """Build one leaderboard row for a GROUPS entry, or None if no selected variant has data.

    Args:
        group: A GROUPS entry mapping variant key -> run name.
        sources: (variant key, {model_name: record}) pairs for the selected
            variants, in display order.
    """
    totals = dict.fromkeys(_METRIC_KEYS, 0)
    title_parts = []
    judge_model = ""
    for variant, lookup in sources:
        run_name = group[variant]
        if run_name not in lookup:
            continue
        record = lookup[run_name]
        for metric in _METRIC_KEYS:
            # `or 0` also covers metrics stored as null in the JSON.
            totals[metric] += record.get(metric) or 0
        title_parts.append(strip_timestamp(run_name))
        # The judge model of the last variant with data is reported.
        judge_model = record.get("judge_model", "")

    if not title_parts:
        return None

    count = len(title_parts)
    avg_conv = totals["avg_conv_consistency"] / count
    avg_backend = totals["avg_backend_consistency"] / count
    avg_policy = totals["avg_policy_completeness"] / count
    overall_avg = (avg_conv + avg_backend + avg_policy) / 3
    return {
        "Model": " / ".join(title_parts),
        "Average Score": round(overall_avg, 4),
        "Conversation Consistency": round(avg_conv, 4),
        "Backend Consistency": round(avg_backend, 4),
        "Policy Completeness": round(avg_policy, 4),
        "Judge Model": judge_model,
    }


def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query=""):
    """Build the aggregated leaderboard DataFrame.

    For every GROUPS entry, the metrics of the runs for the selected variants
    are averaged, with each run weighted equally. "Average Score" is the mean of
    the three averaged metrics. Groups with no data for any selected variant
    are omitted rather than shown as blank all-zero rows.

    Args:
        selected_mwoz, selected_tau_airline, selected_tau_retail: Variant
            toggles. If all are falsy, mwoz is used.
        sort_state: {"sort_by": column, "ascending": bool} or its JSON string.
            Unparseable states and unknown columns fall back to
            "Average Score" descending.
        search_query: Case-insensitive substring filter on the combined model
            name; empty/None disables filtering.

    Returns:
        A DataFrame with "Rank" first, then Model, the four score columns
        (rounded to 4 decimals) and Judge Model. Rank is by Average Score
        (ties share the smallest rank), independent of the chosen sort column.
        Returns an empty DataFrame when no rows remain.
    """
    if not (selected_mwoz or selected_tau_airline or selected_tau_retail):
        selected_mwoz = True

    mwoz_lookup = {entry["model_name"]: entry for entry in load_mwoz_results()}
    tau_lookup = {entry["model_name"]: entry for entry in load_tau_results()}

    # Selected variants in display order; both tau variants share one results file.
    sources = []
    if selected_mwoz:
        sources.append(("mwoz", mwoz_lookup))
    if selected_tau_airline:
        sources.append(("tau_airline", tau_lookup))
    if selected_tau_retail:
        sources.append(("tau_retail", tau_lookup))

    query = search_query.lower() if search_query else ""
    aggregated = []
    for group in GROUPS:
        row = _aggregate_group(group, sources)
        if row is None:
            continue
        if query and query not in row["Model"].lower():
            continue
        aggregated.append(row)

    df = pd.DataFrame(aggregated)
    if df.empty:
        return df

    df["Rank"] = df["Average Score"].rank(ascending=False, method="min").astype(int)

    allowed_sort_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
    sort_state = _parse_sort_state(sort_state)
    sort_by = sort_state.get("sort_by", "Average Score")
    ascending = sort_state.get("ascending", False)
    if sort_by not in allowed_sort_cols:
        sort_by, ascending = "Average Score", False
    df = df.sort_values(sort_by, ascending=ascending)

    # Put Rank in the first column.
    return df[["Rank"] + [col for col in df.columns if col != "Rank"]]
def update_sort_state(current_state, clicked_column):
    """Return the sort state that results from clicking a column header.

    Clicking the column that is already the sort key flips its direction.
    Any other column, or an unusable current_state, gives that column in
    descending order.

    Args:
        current_state: Previous state as a dict, a JSON string of one, or
            anything else (treated as "no state").
        clicked_column: Name of the clicked column.

    Returns:
        {"sort_by": clicked_column, "ascending": bool}
    """
    previous = current_state
    if isinstance(previous, str):
        try:
            previous = json.loads(previous)
        except (json.JSONDecodeError, TypeError):
            previous = None

    clicked_again = (
        isinstance(previous, dict)
        and "sort_by" in previous
        and previous["sort_by"] == clicked_column
    )
    ascending = (not previous.get("ascending", False)) if clicked_again else False
    return {"sort_by": clicked_column, "ascending": ascending}
def sort_by_avg(sort_state):
    """Sort-state transition for a click on the "Average Score" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Average Score")


def sort_by_conv(sort_state):
    """Sort-state transition for a click on the "Conversation Consistency" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Conversation Consistency")


def sort_by_backend(sort_state):
    """Sort-state transition for a click on the "Backend Consistency" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Backend Consistency")


def sort_by_policy(sort_state):
    """Sort-state transition for a click on the "Policy Completeness" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Policy Completeness")
def get_color_for_value(value, min_val, max_val):
    """Map value on [min_val, max_val] to a dark red -> dark yellow -> dark green hex colour.

    min_val gives #B40000, the midpoint (and every value when
    min_val == max_val) gives #8C8C00, and max_val gives #008C00.

    Values outside the range are clamped to the nearest end. Unclamped, they
    would produce negative or >255 channels and an invalid colour string
    such as "#-118...".

    Returns:
        An uppercase "#RRGGBB" string.
    """
    if max_val == min_val:
        norm = 0.5
    else:
        norm = (value - min_val) / (max_val - min_val)
    norm = min(max(norm, 0.0), 1.0)

    if norm < 0.5:
        # Lower half: fixed dark red, green ramps up towards yellow.
        ratio = norm / 0.5
        r, g = 180, int(140 * ratio)
    else:
        # Upper half: red ramps down, leaving dark green.
        ratio = (norm - 0.5) / 0.5
        r, g = int(140 * (1 - ratio)), 140
    b = 0
    return f"#{r:02X}{g:02X}{b:02X}"
def generate_html_table(df):
    """Render a leaderboard DataFrame as a borderless HTML table.

    Each score column that is present is coloured per column, from dark red at
    that column's minimum to dark green at its maximum (via get_color_for_value),
    and shown in bold. All headers and cell values are HTML-escaped, so names
    containing <, > or & render literally instead of breaking the markup.

    Args:
        df: Leaderboard DataFrame, e.g. from create_grouped_leaderboard.

    Returns:
        An HTML string, or a "no results" message when df is empty.
    """
    from html import escape

    if df.empty:
        return "<div class='no-results'>No matching results found.</div>"

    numeric_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
    # Per-column (min, max) ranges used for colouring; absent columns are skipped.
    color_ranges = {col: (df[col].min(), df[col].max()) for col in numeric_cols if col in df.columns}

    # Build a simple HTML table without borders or JavaScript sorting.
    parts = ["<table style='border: none; border-collapse: collapse;'>", "<tr>"]
    parts.extend(f"<th style='padding:8px; border: none;'>{escape(str(col))}</th>" for col in df.columns)
    parts.append("</tr>")
    for _, row in df.iterrows():
        parts.append("<tr style='border: none;'>")
        for col in df.columns:
            cell_value = row[col]
            cell_text = escape(str(cell_value))
            if col in color_ranges:
                color = get_color_for_value(cell_value, *color_ranges[col])
                parts.append(
                    f"<td style='padding: 8px; border: none; color: {color}; font-weight: bold;'>{cell_text}</td>"
                )
            else:
                parts.append(f"<td style='padding: 8px; border: none;'>{cell_text}</td>")
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)
def update_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query=""):
    """Render the leaderboard HTML for the current variant selection, sort state and search.

    Args:
        selected_mwoz, selected_tau_airline, selected_tau_retail: Variant toggles.
        sort_state: {"sort_by": column, "ascending": bool} or its JSON string.
            Anything unusable falls back to "Average Score" descending.
        search_query: Case-insensitive model-name filter.

    Returns:
        HTML: a "Sorted by" caption followed by the table. If rendering raises,
        it retries with the default sort (keeping the search filter) under an
        error caption. If that also fails, it shows an error message instead of
        raising into the UI.
    """
    try:
        default_state = {"sort_by": "Average Score", "ascending": False}
        if isinstance(sort_state, str):
            try:
                sort_state = json.loads(sort_state)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                sort_state = default_state
        if not isinstance(sort_state, dict):
            sort_state = default_state

        df = create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query)
        html_table = generate_html_table(df)

        sort_col = sort_state.get("sort_by", "Average Score")
        sort_dir = "▲" if sort_state.get("ascending", False) else "▼"
        return f"""
        <div class="sort-info">
            <p>Sorted by: {sort_col} {sort_dir}</p>
        </div>
        {html_table}
        """
    except Exception as e:
        # UI boundary: report the failure and retry with the default sort,
        # preserving the user's search filter.
        print(f"Error in update_leaderboard: {str(e)}")
        try:
            df = create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail,
                                            {"sort_by": "Average Score", "ascending": False}, search_query)
            html_table = generate_html_table(df)
        except Exception as fallback_error:
            # E.g. unreadable results files: show a message rather than crash the handler.
            print(f"Fallback rendering failed in update_leaderboard: {fallback_error}")
            html_table = "<div class='no-results'>Unable to load leaderboard results.</div>"
        return f"""
        <div class="sort-info" style="color: #ff6b6b;">
            <p>Error in sorting. Using default sort: Average Score (descending)</p>
        </div>
        {html_table}
        """
# Create our custom theme instance
custom_theme = CustomTheme()

with gr.Blocks(css=custom_css, title="TD-EVAL Leaderboard", theme=custom_theme) as demo:
    gr.Markdown("# 🏆 TD-EVAL Model Evaluation Leaderboard")
    gr.HTML('<div class="subtitle">This leaderboard displays aggregated model performance across multiple evaluation metrics.</div>')

    # Client-side colour fix-ups for the checkbox and search panels.
    # NOTE(review): recent Gradio versions may not execute <script> tags placed
    # inside gr.HTML; if these fixes never apply, move the script to
    # gr.Blocks(head=...) -- confirm against the installed Gradio version.
    gr.HTML("""
    <script>
    // Re-apply panel background and label colours on top of Gradio's styles.
    function fixBackgrounds() {
        // Add (once) a style tag forcing block-info spans and search labels to black.
        if (!document.getElementById('td-eval-label-fixes')) {
            var styleEl = document.createElement('style');
            styleEl.id = 'td-eval-label-fixes';
            styleEl.textContent = `
                span[data-testid="block-info"] { color: #000000 !important; }
                .svelte-1gfkn6j { color: #000000 !important; }
                .search-panel label,
                .search-panel .label-wrap,
                .search-panel span,
                #search-panel span,
                div[id="search-panel"] span { color: #000000 !important; }
            `;
            document.head.appendChild(styleEl);
        }
        // Only fix specific panels by ID
        var checkboxPanel = document.getElementById('checkbox-panel');
        if (checkboxPanel) {
            checkboxPanel.style.backgroundColor = '#F0F0F0';
            var checkboxes = checkboxPanel.querySelectorAll('input[type="checkbox"]');
            checkboxes.forEach(function(checkbox) {
                checkbox.style.backgroundColor = 'transparent';
                // Orange checkmark via the native control. (Children appended to an
                // <input> are never rendered, and appearance:none would hide the
                // native check, leaving checked boxes looking empty.)
                checkbox.style.accentColor = '#c34700';
                // Make the associated label black
                var label = checkbox.nextElementSibling;
                if (label && label.tagName === 'LABEL') {
                    label.style.color = '#000000';
                }
                var parent = checkbox.parentElement;
                if (!parent) {
                    return;
                }
                parent.style.backgroundColor = 'transparent';
                // Any span/label in the container may hold the label text
                parent.querySelectorAll('span, label').forEach(function(el) {
                    el.style.color = '#000000';
                });
            });
        }
        var searchPanel = document.getElementById('search-panel');
        if (searchPanel) {
            searchPanel.style.backgroundColor = '#F0F0F0';
            // Only make search input and its direct container transparent
            var searchInput = searchPanel.querySelector('input[type="text"]');
            if (searchInput) {
                var inputParent = searchInput.parentElement;
                if (inputParent) inputParent.style.backgroundColor = 'transparent';
                searchInput.style.backgroundColor = '#FFFFFF';
                // Ensure the border is visible and matches text color
                searchInput.style.border = '2px solid #000000';
                searchInput.style.color = '#000000';
            }
            // Make sure the label is black
            var searchLabels = searchPanel.querySelectorAll('label, .label-wrap, .label-wrap span');
            searchLabels.forEach(function(label) {
                label.style.color = '#000000';
            });
            // Target the specific span element that contains the label text
            var blockInfoSpans = document.querySelectorAll('span[data-testid="block-info"]');
            blockInfoSpans.forEach(function(span) {
                span.style.color = '#000000';
            });
            // Also target elements with the svelte class
            var svelteElements = document.querySelectorAll('.svelte-1gfkn6j');
            svelteElements.forEach(function(element) {
                if (element.textContent.includes('Search models')) {
                    element.style.color = '#000000';
                }
            });
        }
    }
    // Run at 0.5 s, 1 s and 2 s after load to catch components rendered late
    setTimeout(fixBackgrounds, 500);
    setTimeout(fixBackgrounds, 1000);
    setTimeout(fixBackgrounds, 2000);
    </script>
    """)

    gr.HTML('''
    <div class="variants_container">
        <div class="variants_title">Variants:</div>
        <ul style="list-style: none; padding: 0; margin: 8px 0;">
            <li>mwoz: Baseline variant.</li>
            <li>tau-airline: Airline specialty variant.</li>
            <li>tau-retail: Retail specialty variant.</li>
        </ul>
        <p>Use the checkboxes below to select which variants to include. At least one variant must be active.</p>
    </div>
    ''')

    with gr.Row(elem_classes="checkbox-panel", elem_id="checkbox-panel"):
        cb_mwoz = gr.Checkbox(label="mwoz", value=True)
        cb_tau_airline = gr.Checkbox(label="tau-airline", value=True)
        cb_tau_retail = gr.Checkbox(label="tau-retail", value=True)

    with gr.Row(elem_classes="search-panel", elem_id="search-panel"):
        search_input = gr.Textbox(
            label="Search models",
            placeholder="Type to filter…",
            elem_classes="search-input",
            elem_id="search-input"
        )

    hidden_sort_state = gr.State(value={"sort_by": "Average Score", "ascending": False})

    # Sortable score columns in button order; must match create_grouped_leaderboard's column names.
    SORT_COLUMNS = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]

    # Add sorting buttons
    gr.Markdown("### Sort by:")
    with gr.Row():
        btn_avg = gr.Button("Average Score ▼")
        btn_conv = gr.Button("Conversation Consistency")
        btn_backend = gr.Button("Backend Consistency")
        btn_policy = gr.Button("Policy Completeness")

    leaderboard_display = gr.HTML(label="Aggregated Model Rankings")

    def toggle_sort(column, current_state):
        """Return the new sort state plus the four sort-button labels after clicking column.

        Clicking the active column flips its direction; any other column sorts
        descending. Only the active column's button shows the ▲/▼ arrow.
        """
        if isinstance(current_state, dict) and current_state.get("sort_by") == column:
            new_ascending = not current_state.get("ascending", False)
        else:
            new_ascending = False
        direction = "▲" if new_ascending else "▼"
        labels = [f"{name} {direction}" if name == column else name for name in SORT_COLUMNS]
        return ({"sort_by": column, "ascending": new_ascending}, *labels)

    def make_sort_handler(column):
        """Bind column so each click handler only needs the sort state as input."""
        def handler(current_state):
            return toggle_sort(column, current_state)
        return handler

    leaderboard_inputs = [cb_mwoz, cb_tau_airline, cb_tau_retail, hidden_sort_state, search_input]
    sort_outputs = [hidden_sort_state, btn_avg, btn_conv, btn_backend, btn_policy]

    # Each sort button updates the state and labels, then re-renders the table.
    for button, column in zip([btn_avg, btn_conv, btn_backend, btn_policy], SORT_COLUMNS):
        button.click(
            fn=make_sort_handler(column),
            inputs=[hidden_sort_state],
            outputs=sort_outputs
        ).then(
            fn=update_leaderboard,
            inputs=leaderboard_inputs,
            outputs=leaderboard_display
        )

    # Re-render whenever a variant checkbox or the search text changes.
    for component in (cb_mwoz, cb_tau_airline, cb_tau_retail, search_input):
        component.change(fn=update_leaderboard, inputs=leaderboard_inputs, outputs=leaderboard_display)

    demo.load(fn=update_leaderboard, inputs=leaderboard_inputs, outputs=leaderboard_display)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported.
    demo.launch()