# Page-header residue captured with the file (not code):
# juancauma's picture / changes to font color / a732253
import gradio as gr
import pandas as pd
import json
import os
from gradio.themes.utils import colors
# Read the external stylesheet "styles.css". UTF-8 is tried first; if the
# bytes are not valid UTF-8 the file is re-read as Latin-1.
for _css_encoding in ("utf-8", "latin-1"):
    try:
        with open("styles.css", "r", encoding=_css_encoding) as f:
            custom_css = f.read()
    except UnicodeDecodeError:
        continue
    break

# Extra rules with more specific Gradio selectors and !important so they
# take precedence in the cascade.
additional_css = """
.gradio-container .checkbox-panel,
div.gradio-container [class*="block"] .checkbox-panel {
background-color: #27272A !important;
}
.gradio-container .search-panel,
div.gradio-container [class*="block"] .search-panel {
background-color: #27272A !important;
}
"""
custom_css = custom_css + additional_css
# Custom theme used by the leaderboard UI.
class CustomTheme(gr.themes.Base):
    """Gradio base theme with red/gray hues, large text and flat blocks."""

    def __init__(self):
        hues_and_sizes = {
            "primary_hue": colors.red,
            "secondary_hue": colors.gray,
            "neutral_hue": colors.gray,
            "text_size": gr.themes.sizes.text_lg,
        }
        super().__init__(**hues_and_sizes)
        # No global background colors are set here; only the block border
        # and shadow are removed.
        for attribute, value in (("block_border_width", "0px"), ("block_shadow", "none")):
            setattr(self, attribute, value)
# Append ID-scoped overrides to the stylesheet: the #checkbox-panel and
# #search-panel containers get a light (#F0F0F0) background, their direct
# children are made transparent, and so are the checkbox / text inputs.
# NOTE(review): the class-based rules appended earlier set #27272A on the same
# panels; presumably these ID selectors are meant to win -- confirm.
custom_css += """
/* Only override specific panels by ID */
#checkbox-panel,
#search-panel {
background-color: #F0F0F0 !important;
}
/* Only affect immediate children of these specific panels */
#checkbox-panel > *,
#search-panel > * {
background-color: transparent !important;
}
/* Target checkbox inputs specifically */
#checkbox-panel input[type="checkbox"],
#search-panel input[type="text"] {
background-color: transparent !important;
}
"""
def strip_timestamp(name):
    """Return *name* without its leading ``<timestamp>-`` prefix.

    Everything up to and including the first hyphen is dropped; a name that
    contains no hyphen is returned unchanged.
    """
    _, hyphen, remainder = name.partition('-')
    if not hyphen:
        return name
    return remainder
# Static grouping mapping for the 10 general submissions.
# Each entry ties together the submission names of one model across the three
# variants: "mwoz" is looked up in the mwoz results, while "tau_airline" and
# "tau_retail" are both looked up in the tau results. The part of each name
# before the first hyphen is stripped for display by strip_timestamp().
GROUPS = [
{"mwoz": "20250214_193236-o1", "tau_airline": "20250215_115156-tau-o1-airline", "tau_retail": "20250215_121147-tau-o1-retail"},
{"mwoz": "20250131_012338-llama405", "tau_airline": "20250204_144222-tau-llama-405b-airline", "tau_retail": "20250205_033820-tau-llama405b-retail"},
{"mwoz": "20250130_140218-4o", "tau_airline": "20250131_152503-tau-4o-airline", "tau_retail": "20250131_152422-tau-4o-retail"},
{"mwoz": "20250130_183030-claude", "tau_airline": "20250205_030422-tau-sonnet-airline", "tau_retail": "20250131_152807-tau-sonnet-retail"},
{"mwoz": "20250131_012449-llama70", "tau_airline": "20250208_024344-tau-llama70b-airline", "tau_retail": "20250208_030407-tau-llama70b-retail"},
{"mwoz": "20250131_013711-qwen72b", "tau_airline": "20250202_112945-qwen72b-airline", "tau_retail": "20250202_140527-qwen72b-retail"},
{"mwoz": "20250130_184905-mistrallarge", "tau_airline": "20250205_024823-tau-mistrallarge-airline", "tau_retail": "20250205_044403-tau-mistrallarge-retail"},
{"mwoz": "20250131_010143-o1mini", "tau_airline": "20250214_180731-tau-o1-mini-airline", "tau_retail": "20250214_142736-tau-o1-mini-retail"},
{"mwoz": "20250130_140439-4omini", "tau_airline": "20250131_152226-tau-4o-mini-airline", "tau_retail": "20250131_152338-tau-4o-mini-retail"},
{"mwoz": "20250130_145202-gpt35", "tau_airline": "20250131_152708-tau-gpt35-airline", "tau_retail": "20250131_152610-tau-gpt35-retail"}
]
def load_mwoz_results():
    """Load the mwoz leaderboard results from ``data/mwoz_leaderboard_results.json``.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    path = os.path.join("data", "mwoz_leaderboard_results.json")
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
def load_tau_results():
    """Load the tau leaderboard results from ``data/tau_leaderboard_results.json``.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    path = os.path.join("data", "tau_leaderboard_results.json")
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
def create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query=""):
    """Aggregate per-variant results into one ranked leaderboard DataFrame.

    For every entry of ``GROUPS`` the three metrics are averaged over the
    selected variants that have a record in the loaded results. The model
    label is the timestamp-stripped submission names joined with " / ".

    Args:
        selected_mwoz: Include the mwoz variant.
        selected_tau_airline: Include the tau-airline variant.
        selected_tau_retail: Include the tau-retail variant.
        sort_state: Dict (or a JSON string encoding one) with ``sort_by`` and
            ``ascending`` keys. Anything else, or an unknown column, falls
            back to "Average Score" descending.
        search_query: Case-insensitive substring filter on the model label.

    Returns:
        A DataFrame with the columns Rank, Model, Average Score,
        Conversation Consistency, Backend Consistency, Policy Completeness
        and Judge Model, or an empty DataFrame when nothing matches.
    """
    # With no variant ticked, fall back to mwoz so the table is never empty
    # merely because of the checkbox state.
    if not (selected_mwoz or selected_tau_airline or selected_tau_retail):
        selected_mwoz = True

    mwoz_lookup = {entry["model_name"]: entry for entry in load_mwoz_results()}
    tau_lookup = {entry["model_name"]: entry for entry in load_tau_results()}

    metric_keys = ("avg_conv_consistency", "avg_backend_consistency", "avg_policy_completeness")
    # (is selected, key inside a GROUPS entry, lookup table to consult)
    variants = (
        (selected_mwoz, "mwoz", mwoz_lookup),
        (selected_tau_airline, "tau_airline", tau_lookup),
        (selected_tau_retail, "tau_retail", tau_lookup),
    )

    aggregated = []
    for group in GROUPS:
        totals = dict.fromkeys(metric_keys, 0)
        title_parts = []
        judge_model = ""
        for selected, variant, lookup in variants:
            if not selected:
                continue
            key = group[variant]
            if key not in lookup:
                continue
            record = lookup[key]
            for metric in metric_keys:
                totals[metric] += record.get(metric, 0)
            title_parts.append(strip_timestamp(key))
            # The judge model of the last contributing variant is reported.
            judge_model = record.get("judge_model", "")

        count = len(title_parts)
        if count == 0:
            # No record for any selected variant: skip the group instead of
            # emitting a blank-named, all-zero row that would be ranked.
            continue

        avg_conv = totals["avg_conv_consistency"] / count
        avg_backend = totals["avg_backend_consistency"] / count
        avg_policy = totals["avg_policy_completeness"] / count
        overall_avg = (avg_conv + avg_backend + avg_policy) / 3

        model_name = " / ".join(title_parts)
        # Apply search filter
        if search_query and search_query.lower() not in model_name.lower():
            continue

        aggregated.append({
            "Model": model_name,
            "Average Score": round(overall_avg, 4),
            "Conversation Consistency": round(avg_conv, 4),
            "Backend Consistency": round(avg_backend, 4),
            "Policy Completeness": round(avg_policy, 4),
            "Judge Model": judge_model
        })

    df = pd.DataFrame(aggregated)
    # If no results found after filtering
    if df.empty:
        return df

    # Rank always reflects the average score, whatever the display order is.
    df["Rank"] = df["Average Score"].rank(ascending=False, method="min").astype(int)

    allowed_sort_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
    default_sort = {"sort_by": "Average Score", "ascending": False}

    # Accept a JSON-encoded sort state; json.JSONDecodeError is a ValueError.
    if isinstance(sort_state, str):
        try:
            sort_state = json.loads(sort_state)
        except ValueError:
            sort_state = default_sort
    if not isinstance(sort_state, dict):
        sort_state = default_sort

    sort_by = sort_state.get("sort_by", "Average Score")
    if sort_by in allowed_sort_cols:
        df = df.sort_values(sort_by, ascending=bool(sort_state.get("ascending", False)))
    else:
        # Default sort if column not found
        df = df.sort_values("Average Score", ascending=False)

    # Show the rank as the first column.
    ordered_cols = ["Rank"] + [col for col in df.columns if col != "Rank"]
    return df[ordered_cols]
def update_sort_state(current_state, clicked_column):
    """Compute the sort state that results from clicking *clicked_column*.

    *current_state* may be a dict, a JSON string encoding one, or anything
    else (treated as "no previous state"). Clicking the column that is
    already the sort key flips its direction; every other case yields a
    descending sort on the clicked column.

    Returns:
        A new dict ``{"sort_by": clicked_column, "ascending": bool}``.
    """
    previous = current_state
    if isinstance(previous, str):
        try:
            previous = json.loads(previous)
        except (json.JSONDecodeError, TypeError):
            # Unparseable string: behave as if there were no previous state.
            previous = None

    same_column = (
        isinstance(previous, dict)
        and "sort_by" in previous
        and previous["sort_by"] == clicked_column
    )
    if same_column:
        ascending = not previous.get("ascending", False)
    else:
        ascending = False
    return {"sort_by": clicked_column, "ascending": ascending}
def sort_by_avg(sort_state):
    """Return the sort state after clicking the "Average Score" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Average Score")
def sort_by_conv(sort_state):
    """Return the sort state after clicking the "Conversation Consistency" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Conversation Consistency")
def sort_by_backend(sort_state):
    """Return the sort state after clicking the "Backend Consistency" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Backend Consistency")
def sort_by_policy(sort_state):
    """Return the sort state after clicking the "Policy Completeness" column."""
    return update_sort_state(current_state=sort_state, clicked_column="Policy Completeness")
def get_color_for_value(value, min_val, max_val):
    """Map *value* to a hex colour on a dark red -> olive -> dark green scale.

    Args:
        value: The number to colour.
        min_val: Value mapped to the red end of the scale.
        max_val: Value mapped to the green end of the scale.

    Returns:
        A ``#RRGGBB`` string. When ``min_val == max_val`` the midpoint colour
        is returned. Values outside ``[min_val, max_val]`` are clamped to the
        nearest end so a channel can never become negative or exceed its
        cap, which would otherwise produce an invalid colour string.
    """
    if max_val == min_val:
        norm = 0.5
    else:
        norm = (value - min_val) / (max_val - min_val)
        # Clamp to the scale; out-of-range input must still yield valid hex.
        norm = min(1.0, max(0.0, norm))

    if norm < 0.5:
        ratio = norm / 0.5
        # Darker red for lower values
        r = 180
        g = int(140 * ratio)
        b = 0
    else:
        ratio = (norm - 0.5) / 0.5
        # Darker green for higher values
        r = int(140 * (1 - ratio))
        g = 140
        b = 0
    return f"#{r:02X}{g:02X}{b:02X}"
def generate_html_table(df):
    """Render the leaderboard DataFrame as a borderless HTML table.

    The score columns that are present ("Average Score", "Conversation
    Consistency", "Backend Consistency", "Policy Completeness") are coloured
    per cell with ``get_color_for_value`` relative to that column's min and
    max. Header and cell text is HTML-escaped so characters such as ``<`` or
    ``&`` in a model name are displayed literally instead of being parsed as
    markup.

    Args:
        df: The DataFrame to render.

    Returns:
        An HTML string: the table, or a "no results" ``<div>`` when *df* is
        empty.
    """
    # Imported locally (and not as the bare module name) so the block does
    # not depend on a file-level import.
    from html import escape

    if df.empty:
        return "<div class='no-results'>No matching results found.</div>"

    numeric_cols = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
    # Per-column (min, max) used to normalise the colour scale; columns that
    # are absent from df are simply not coloured.
    bounds = {col: (df[col].min(), df[col].max()) for col in numeric_cols if col in df.columns}

    # Build a simple HTML table without borders or JavaScript sorting.
    # Fragments are collected and joined once instead of repeated "+=".
    parts = ["<table style='border: none; border-collapse: collapse;'>"]

    # Header row
    parts.append("<tr>")
    for col in df.columns:
        parts.append(f"<th style='padding:8px; border: none;'>{escape(str(col))}</th>")
    parts.append("</tr>")

    # Table rows
    for _, row in df.iterrows():
        parts.append("<tr style='border: none;'>")
        for col in df.columns:
            cell_value = row[col]
            cell_text = escape(str(cell_value))
            if col in bounds:
                low, high = bounds[col]
                color = get_color_for_value(cell_value, low, high)
                parts.append(f"<td style='padding: 8px; border: none; color: {color}; font-weight: bold;'>{cell_text}</td>")
            else:
                parts.append(f"<td style='padding: 8px; border: none;'>{cell_text}</td>")
        parts.append("</tr>")

    parts.append("</table>")
    return "".join(parts)
def update_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query=""):
    """Build the HTML shown in the leaderboard panel.

    Args:
        selected_mwoz: Include the mwoz variant.
        selected_tau_airline: Include the tau-airline variant.
        selected_tau_retail: Include the tau-retail variant.
        sort_state: Dict (or JSON string encoding one) with ``sort_by`` and
            ``ascending`` keys; anything else is replaced by the default
            "Average Score" descending.
        search_query: Substring filter forwarded to the aggregation.

    Returns:
        An HTML string: a "Sorted by" line followed by the table. If building
        the table fails, the error is printed and a table using the default
        sort (with the same search filter) is returned under an error notice.
    """
    default_sort = {"sort_by": "Average Score", "ascending": False}
    try:
        # Convert sort_state to dict if it's a string; json.JSONDecodeError
        # is a ValueError, so only parse failures are absorbed here.
        if isinstance(sort_state, str):
            try:
                sort_state = json.loads(sort_state)
            except ValueError:
                sort_state = dict(default_sort)
        # Ensure sort_state is a dict
        if not isinstance(sort_state, dict):
            sort_state = dict(default_sort)

        # Generate the data and table
        df = create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail, sort_state, search_query)
        html_table = generate_html_table(df)

        # Get sort info with fallbacks
        sort_col = sort_state.get("sort_by", "Average Score")
        sort_dir = "▼" if not sort_state.get("ascending", False) else "▲"
        html_output = f"""
<div class="sort-info">
<p>Sorted by: {sort_col} {sort_dir}</p>
</div>
{html_table}
"""
        return html_output
    except Exception as e:
        # Top-level UI boundary: report the problem and retry with the
        # default sort, keeping the user's search filter applied.
        print(f"Error in update_leaderboard: {str(e)}")
        df = create_grouped_leaderboard(selected_mwoz, selected_tau_airline, selected_tau_retail,
                                        dict(default_sort), search_query)
        html_table = generate_html_table(df)
        return f"""
<div class="sort-info" style="color: #ff6b6b;">
<p>Error in sorting. Using default sort: Average Score (descending)</p>
</div>
{html_table}
"""
# Create our custom theme instance
custom_theme = CustomTheme()

with gr.Blocks(css=custom_css, title="TD-EVAL Leaderboard", theme=custom_theme) as demo:
    gr.Markdown("# 🏆 TD-EVAL Model Evaluation Leaderboard")
    gr.HTML('<div class="subtitle">This leaderboard displays aggregated model performance across multiple evaluation metrics.</div>')

    # Inline script that injects a <style> tag and restyles the checkbox and
    # search panels; it schedules fixBackgrounds() 500 ms, 1 s and 2 s after
    # the script runs to catch delayed rendering.
    # NOTE(review): browsers generally do not execute <script> elements that
    # are inserted as HTML strings -- confirm this script actually runs under
    # the Gradio version in use.
    gr.HTML("""
<script>
// Function to fix background colors
function fixBackgrounds() {
// Add a style tag to force all block-info spans to be black
var styleEl = document.createElement('style');
styleEl.textContent = `
span[data-testid="block-info"] { color: #000000 !important; }
.svelte-1gfkn6j { color: #000000 !important; }
.search-panel label,
.search-panel .label-wrap,
.search-panel span,
#search-panel span,
div[id="search-panel"] span { color: #000000 !important; }
`;
document.head.appendChild(styleEl);
// Only fix specific panels by ID
var checkboxPanel = document.getElementById('checkbox-panel');
if (checkboxPanel) {
checkboxPanel.style.backgroundColor = '#F0F0F0';
// Only make checkboxes and their direct containers transparent
var checkboxes = checkboxPanel.querySelectorAll('input[type="checkbox"]');
checkboxes.forEach(function(checkbox) {
var parent = checkbox.parentElement;
if (parent) parent.style.backgroundColor = 'transparent';
checkbox.style.backgroundColor = 'transparent';
// Find and style the associated label to be black
var label = checkbox.nextElementSibling;
if (label && label.tagName === 'LABEL') {
label.style.color = '#000000';
}
// Also find any span elements that might contain the label text
var spans = parent.querySelectorAll('span');
spans.forEach(function(span) {
span.style.color = '#000000';
});
// Find label elements in the parent container
var labels = parent.querySelectorAll('label');
labels.forEach(function(label) {
label.style.color = '#000000';
});
// Apply custom styling for the checkbox to show orange checkmark
if (checkbox.checked) {
checkbox.style.position = 'relative';
checkbox.style.appearance = 'none';
checkbox.style.backgroundColor = '#F0F0F0';
checkbox.style.border = '1px solid #CCCCCC';
checkbox.style.borderRadius = '3px';
// Create or update the checkmark element
var checkmark = checkbox.querySelector('.orange-checkmark');
if (!checkmark) {
checkmark = document.createElement('span');
checkmark.className = 'orange-checkmark';
checkmark.style.position = 'absolute';
checkmark.style.left = '50%';
checkmark.style.top = '50%';
checkmark.style.transform = 'translate(-50%, -50%)';
checkmark.style.color = '#c34700';
checkmark.style.fontSize = '14px';
checkmark.style.fontWeight = 'bold';
checkmark.innerText = '✓';
checkbox.appendChild(checkmark);
}
}
});
}
var searchPanel = document.getElementById('search-panel');
if (searchPanel) {
searchPanel.style.backgroundColor = '#F0F0F0';
// Only make search input and its direct container transparent
var searchInput = searchPanel.querySelector('input[type="text"]');
if (searchInput) {
var parent = searchInput.parentElement;
if (parent) parent.style.backgroundColor = 'transparent';
searchInput.style.backgroundColor = '#FFFFFF';
// Ensure the border is visible and matches text color
searchInput.style.border = '2px solid #000000';
searchInput.style.color = '#000000';
}
// Make sure the label is black
var searchLabels = searchPanel.querySelectorAll('label, .label-wrap, .label-wrap span');
searchLabels.forEach(function(label) {
label.style.color = '#000000';
});
// Target the specific span element that contains the label text
var blockInfoSpans = document.querySelectorAll('span[data-testid="block-info"]');
blockInfoSpans.forEach(function(span) {
span.style.color = '#000000';
});
// Also target elements with the svelte class
var svelteElements = document.querySelectorAll('.svelte-1gfkn6j');
svelteElements.forEach(function(element) {
if (element.textContent.includes('Search models')) {
element.style.color = '#000000';
}
});
}
}
// Run on page load and every second for 3 seconds to catch any delayed rendering
setTimeout(fixBackgrounds, 500);
setTimeout(fixBackgrounds, 1000);
setTimeout(fixBackgrounds, 2000);
</script>
""")

    # Static explanation of the three variants.
    gr.HTML('''
<div class="variants_container">
<div class="variants_title">Variants:</div>
<ul style="list-style: none; padding: 0; margin: 8px 0;">
<li>mwoz: Baseline variant.</li>
<li>tau-airline: Airline specialty variant.</li>
<li>tau-retail: Retail specialty variant.</li>
</ul>
<p>Use the checkboxes below to select which variants to include. At least one variant must be active.</p>
</div>
''')

    # Variant selection; all three start ticked.
    with gr.Row(elem_classes="checkbox-panel", elem_id="checkbox-panel"):
        cb_mwoz = gr.Checkbox(label="mwoz", value=True)
        cb_tau_airline = gr.Checkbox(label="tau-airline", value=True)
        cb_tau_retail = gr.Checkbox(label="tau-retail", value=True)

    # Free-text model filter.
    with gr.Row(elem_classes="search-panel", elem_id="search-panel"):
        search_input = gr.Textbox(
            label="Search models",
            placeholder="Type to filter…",
            elem_classes="search-input",
            elem_id="search-input"
        )

    # Current sort column and direction, kept in a State component.
    hidden_sort_state = gr.State(value={"sort_by": "Average Score", "ascending": False})

    # Add sorting buttons
    gr.Markdown("### Sort by:")
    with gr.Row():
        btn_avg = gr.Button("Average Score ▼")
        btn_conv = gr.Button("Conversation Consistency")
        btn_backend = gr.Button("Backend Consistency")
        btn_policy = gr.Button("Policy Completeness")

    leaderboard_display = gr.HTML(label="Aggregated Model Rankings")

    def toggle_sort(column, current_state, btn_avg, btn_conv, btn_backend, btn_policy):
        """Return the new sort state plus the four sort-button labels.

        Clicking the column that is already the sort key flips the
        direction; any other column starts descending. Only the active
        column's label carries a direction arrow. The ``btn_*`` arguments
        receive the buttons' current values and are not read.
        """
        already_active = isinstance(current_state, dict) and current_state.get("sort_by") == column
        if already_active:
            new_ascending = not current_state.get("ascending", False)
        else:
            new_ascending = False

        arrow = "▲" if new_ascending else "▼"
        button_names = (
            "Average Score",
            "Conversation Consistency",
            "Backend Consistency",
            "Policy Completeness",
        )
        labels = [f"{name} {arrow}" if name == column else name for name in button_names]
        return ({"sort_by": column, "ascending": new_ascending}, *labels)

    def _leaderboard_inputs():
        """Return a fresh input list for update_leaderboard."""
        return [cb_mwoz, cb_tau_airline, cb_tau_retail, hidden_sort_state, search_input]

    # Connect sort buttons with the toggle function. Each button gets its
    # own hidden Textbox carrying the column name, then refreshes the table.
    sort_buttons = [btn_avg, btn_conv, btn_backend, btn_policy]
    sort_columns = ["Average Score", "Conversation Consistency", "Backend Consistency", "Policy Completeness"]
    for sort_button, sort_column in zip(sort_buttons, sort_columns):
        sort_button.click(
            fn=toggle_sort,
            inputs=[gr.Textbox(value=sort_column, visible=False), hidden_sort_state, *sort_buttons],
            outputs=[hidden_sort_state, *sort_buttons]
        ).then(
            fn=update_leaderboard,
            inputs=_leaderboard_inputs(),
            outputs=leaderboard_display
        )

    # Connect dataflow for variant checkboxes and search
    for control in (cb_mwoz, cb_tau_airline, cb_tau_retail, search_input):
        control.change(fn=update_leaderboard, inputs=_leaderboard_inputs(), outputs=leaderboard_display)

    # Render the table once when the page loads.
    demo.load(fn=update_leaderboard, inputs=_leaderboard_inputs(), outputs=leaderboard_display)

if __name__ == "__main__":
    demo.launch()