|
""" |
|
Gradio demo – visualise benchmark accuracy curves. |
|
|
|
Required CSV files (place in the *same* folder as app.py): |
|
|
|
├── aggregated_accuracy.csv |
|
├── qa_accuracy.csv |
|
├── ocr_accuracy.csv |
|
└── temporal_accuracy.csv |
|
|
|
Each file has the columns |
|
|
|
Model,<context‑length‑1>,<context‑length‑2>,… |
|
|
|
where the context‑length headers are strings such as `30min`, `60min`, `120min`, … |
|
|
|
No further cleaning / renaming is done apart from the cosmetic model-name
replacements listed in RENAME (for example “gpt4.1” → “ChatGPT 4.1”).
|
""" |
|
|
|
from pathlib import Path |
|
|
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
import gradio as gr |
|
import math |
|
|
|
|
|
|
|
|
|
|
|
# Benchmark key -> CSV file name holding that benchmark's accuracy table.
FILES = dict(
    aggregated="aggregated_accuracy.csv",
    qa="qa_accuracy.csv",
    ocr="ocr_accuracy.csv",
    temporal="temporal_accuracy.csv",
)

# Benchmark key -> label shown in the UI (dropdown entries and chart title).
DISPLAY_LABELS = dict(
    aggregated="Aggregated",
    qa="QA",
    ocr="OCR",
    temporal="Temporal",
)

# Benchmark key -> models that are pre-selected when that benchmark is shown.
# Benchmarks without an entry fall back to whatever default_models() picks.
DEFAULT_MODELS: dict[str, list[str]] = dict(
    aggregated=[
        "Gemini 2.5 Pro",
        "ChatGPT 4.1",
        "Qwen2.5-VL-7B",
        "InternVL2.5-8B",
        "LLaMA-3.2-11B-Vision",
    ],
)
|
|
|
RENAME = { |
|
r"gpt4\.1": "ChatGPT 4.1", |
|
r"Gemini\s2\.5\spro": "Gemini 2.5 Pro", |
|
r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision", |
|
} |
|
|
|
|
|
|
|
|
|
|
|
def _read_csv(path: str | Path) -> pd.DataFrame: |
|
df = pd.read_csv(path) |
|
df["Model"] = df["Model"].replace(RENAME, regex=True).astype(str) |
|
return df |
|
|
|
# Every benchmark table, loaded once at import time and keyed like FILES.
dfs: dict[str, pd.DataFrame] = dict(zip(FILES, map(_read_csv, FILES.values())))
|
|
|
|
|
|
|
|
|
|
|
import plotly.express as px |
|
|
|
# Plotly's built-in "Safe" qualitative colour sequence, used for the traces.
SAFE_PALETTE = px.colors.qualitative.Safe

# Sorted union of the model names found in any benchmark table.  A model's
# position in this list decides its colour and marker in render_chart().
ALL_MODELS: list[str] = sorted(
    set().union(*(frame["Model"].unique() for frame in dfs.values()))
)

# Marker shapes cycled through alongside the palette.
MARKER_SYMBOLS = [
    "circle", "square", "triangle-up", "diamond", "cross",
    "triangle-down", "x", "triangle-right", "triangle-left", "pentagon",
]

# Context-length columns, i.e. every header of the aggregated table except
# the model-name column (matched case-insensitively).
TIME_COLS = [
    header for header in dfs["aggregated"].columns if header.lower() != "model"
]
|
|
|
|
|
def _pretty_time(label: str) -> str:
    """Turn a ``<N>min`` column header into a compact axis label.

    Durations of an hour or more are shown in hours (``'120min'`` -> ``'2hr'``,
    ``'90min'`` -> ``'1.5hr'``); shorter ones are left as they are
    (``'30min'`` -> ``'30min'``).  Any label that is not an integer followed
    by ``min`` is returned unchanged.

    Args:
        label: Column header taken from the CSV.

    Returns:
        The display label for the x axis.
    """
    if not label.endswith("min"):
        return label
    try:
        minutes = int(label[:-3])
    except ValueError:
        # e.g. "admin" or a bare "min": not a duration, keep it verbatim
        # instead of crashing at import time.
        return label
    if minutes < 60:
        return label
    whole_hours, leftover = divmod(minutes, 60)
    if leftover == 0:
        return f"{whole_hours}hr"
    return f"{minutes / 60:.1f}hr"
|
|
|
|
|
# Raw column header -> label shown on the x axis.
TIME_LABELS = dict(zip(TIME_COLS, map(_pretty_time, TIME_COLS)))
|
|
|
|
|
|
|
|
|
|
|
def render_chart(
    benchmark: str,
    models: list[str],
    log_scale: bool,
) -> go.Figure:
    """Build the accuracy-versus-duration line chart for one benchmark.

    Args:
        benchmark: Benchmark name; it is lower-cased to look up ``dfs`` and
            ``DISPLAY_LABELS``, so both the key and the display label work.
        models: Names of the models to draw.  Names that do not occur in the
            benchmark's table are skipped; ``None`` is treated as empty.
        log_scale: Use a logarithmic y axis when true, otherwise a linear
            0-100 axis.

    Returns:
        A dark-themed Plotly figure with one ``lines+markers`` trace per
        model that was found.

    Raises:
        KeyError: If ``benchmark`` does not name a loaded table.
    """
    bench_key = benchmark.lower()
    df = dfs[bench_key]
    fig = go.Figure()

    # Identical for every trace, so build it once.
    x_labels = [TIME_LABELS[c] for c in TIME_COLS]

    # Smallest plotted value over all traces; anchors the log-axis lower bound.
    min_y_val = None

    for idx, m in enumerate(models or []):
        row = df.loc[df["Model"] == m]
        if row.empty:
            continue

        # reindex() yields NaN for a time column this benchmark lacks (instead
        # of a KeyError); iloc[0] keeps y the same length as x even when a
        # model name appears on more than one row.
        raw = row.reindex(columns=TIME_COLS).iloc[0].tolist()

        # Zero and missing cells become None so they show as gaps in the line
        # (connectgaps=False) and never reach the minimum computation below.
        y = [None if pd.isna(val) or val == 0 else val for val in raw]

        present = [val for val in y if val is not None]
        if present:
            cur_min = min(present)
            min_y_val = cur_min if min_y_val is None else min(min_y_val, cur_min)

        # Index into the global model list so a model keeps its colour and
        # marker no matter which other models are selected.
        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = SAFE_PALETTE[model_idx % len(SAFE_PALETTE)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]

        fig.add_trace(
            go.Scatter(
                x=x_labels,
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )

    if log_scale:
        if min_y_val is None or min_y_val <= 0:
            min_y_val = 0.1
        # A log axis takes its range as base-10 exponents; the top is 10**2.
        # Cap the lower exponent at 1 so the range cannot collapse to [2, 2]
        # when every plotted value is >= 100.
        lower_exp = min(math.floor(math.log10(min_y_val)), 1)
        yaxis_range = [lower_exp, 2]
        yaxis_type = "log"
    else:
        yaxis_range = [0, 100]
        yaxis_type = "linear"

    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        yaxis=dict(
            type=yaxis_type,
            range=yaxis_range,
            gridcolor="rgba(255,255,255,0.15)",
        ),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
# Extra stylesheet passed to gr.Blocks: pads the #controls element, makes
# .scrollbox elements scroll beyond 300px, sets fonts and the checkbox accent.
CSS = """
#controls {
    padding: 8px 12px;
}
.scrollbox {
    max-height: 300px;
    overflow-y: auto;
}
body, .gradio-container {
    font-family: 'Inter', 'Helvetica', sans-serif;
}
.gradio-container h1, .gradio-container h2 {
    font-weight: 600;
}

#controls, .scrollbox {
    background: rgba(255,255,255,0.02);
    border-radius: 6px;
}

input[type="checkbox"]:checked {
    accent-color: #FF715E;
}
"""
|
|
|
def available_models(bench: str) -> list[str]:
    """Return the distinct model names of benchmark ``bench`` in sorted order.

    ``bench`` must be a key of ``dfs``.
    """
    names = dfs[bench]["Model"].unique()
    return sorted(names)
|
|
|
|
|
def default_models(bench: str) -> list[str]:
    """Pick the models that start out ticked for benchmark ``bench``.

    The entries configured in ``DEFAULT_MODELS`` are used, restricted to the
    models the benchmark actually contains and kept in configured order.
    When none of them apply, the first six available models are used instead.
    """
    opts = available_models(bench)
    preferred = [name for name in DEFAULT_MODELS.get(bench, []) if name in opts]
    return preferred or opts[:6]
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI: benchmark selector + log-scale toggle on top, the chart below them and
# the model check-list at the bottom.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
        # 📈 TimeScope

        How long can your video model keep up?
        """
    )

    with gr.Row():
        benchmark_dd = gr.Dropdown(
            label="Type",
            choices=list(DISPLAY_LABELS.values()),
            value=DISPLAY_LABELS["aggregated"],
            scale=1,
        )
        log_cb = gr.Checkbox(
            label="Log-scale Y-axis",
            value=False,
            scale=1,
        )

    # Initial figure: aggregated benchmark, default models, linear axis.
    plot_out = gr.Plot(
        render_chart("Aggregated", default_models("aggregated"), False)
    )

    models_cb = gr.CheckboxGroup(
        label="Models",
        choices=available_models("aggregated"),
        value=default_models("aggregated"),
        interactive=True,
        elem_classes=["scrollbox"],
    )

    def _update_models(bench: str):
        """Refresh the check-list choices and selection for benchmark *bench*."""
        bench_key = bench.lower()
        opts = available_models(bench_key)
        defaults = default_models(bench_key)
        return gr.update(choices=opts, value=defaults)

    chart_inputs = [benchmark_dd, models_cb, log_cb]

    # Switching benchmark: refresh the model list first and redraw only after
    # that has finished.  Registering render_chart as a second, independent
    # listener on the dropdown would let it run with the previous benchmark's
    # selection, and that stale figure could arrive last.
    benchmark_dd.change(
        fn=_update_models,
        inputs=benchmark_dd,
        outputs=models_cb,
        queue=False,
    ).then(
        fn=render_chart,
        inputs=chart_inputs,
        outputs=plot_out,
        queue=False,
    )

    # Changing the model selection or the axis toggle redraws directly.
    for ctrl in (models_cb, log_cb):
        ctrl.change(
            fn=render_chart,
            inputs=chart_inputs,
            outputs=plot_out,
            queue=False,
        )

# Launched at import time, as before; share=True asks Gradio for a public link.
demo.launch(share=True)