# ArenaLite/pages/see_results.py
import pandas as pd
import streamlit as st
import analysis_utils as au
from analysis_utils import number_breakdown_from_df
from app import load_and_cache_data
# from app import VA_ROOT
from query_comp import QueryWrapper, get_base_url
from varco_arena.varco_arena_core.prompts import load_prompt
from view_utils import (
default_page_setting,
escape_markdown,
set_nav_bar,
show_linebreak_in_md,
)
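# Plotly layout overrides (font family and sizes) shared by every chart on this page.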
DEFAULT_LAYOUT_DICT = {
"title": {"font": {"size": 20, "family": "Gothic A1"}},
"font": {"size": 16, "family": "Gothic A1"},
"xaxis": {"tickfont": {"size": 12, "family": "Gothic A1"}},
"yaxis": {"tickfont": {"size": 12, "family": "Gothic A1"}},
"legend": {"font": {"size": 12, "family": "Gothic A1"}},
}
def navigate(t, source, key, val):
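    """Step the session-state value `key` to the previous/next item of list `t`.

    `source` is the currently selected item (None = nothing selected yet) and
    `val` is the step: -1 for the โ—€ button, +1 for โ–ถ. Out-of-range steps are
    ignored; valid steps update st.session_state[key] and trigger a rerun.
    """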
# print(key, val)
if source is None:
return
target_index = t.index(source) + val
if 0 <= target_index < len(t):
st.session_state[key] = t[target_index]
st.rerun()
def main():
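    """Render the results page: rating table, per-scenario tournament brackets,
    match details, and LLM-judge bias diagnostics for the selected Arena-Lite run."""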
sidebar_placeholder = default_page_setting(layout="wide")
set_nav_bar(
False,
sidebar_placeholder=sidebar_placeholder,
toggle_hashstr="see_results_init",
)
# load the data
# print(f"{st.session_state.get('result_file_path', None)=}")
most_recent_run = st.session_state.get("result_file_path", None)
most_recent_run = str(most_recent_run) if most_recent_run is not None else None
(
st.session_state["all_result_dict"],
st.session_state["df_dict"],
) = load_and_cache_data(result_file_path=most_recent_run)
# side bar
st.sidebar.title("Select Result:")
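    # QueryWrapper mirrors the widget value into the URL query string (here ?expname=...),
    # which is what the "Sharable link" at the bottom of the page is built from.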
result_select = QueryWrapper("expname")(
st.sidebar.selectbox,
list(st.session_state["all_result_dict"].keys()),
)
if result_select is None:
if st.session_state.korean:
st.markdown("๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜๋ ค๋ฉด ๋จผ์ € **๐Ÿ”ฅArena-Lite๋ฅผ ๊ตฌ๋™**ํ•˜์…”์•ผ ํ•ฉ๋‹ˆ๋‹ค")
else:
st.markdown("You should **๐Ÿ”ฅRun Arena-Lite** first to see results")
st.image("streamlit_app_local/page_result_1.png")
st.image("streamlit_app_local/page_result_2.png")
st.image("streamlit_app_local/page_result_3.png")
st.image("streamlit_app_local/page_result_3.png")
st.stop()
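    # the evaluation prompt name is encoded as the last path segment of the selected result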
eval_prompt_name = result_select.split("/")[-1].strip()
if st.sidebar.button("Clear Cache"):
st.cache_data.clear()
st.cache_resource.clear()
st.rerun()
if result_select:
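        # drop the cached letter -> model-name mapping so it is rebuilt for the newly selected run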
if "alpha2names" in st.session_state:
del st.session_state["alpha2names"]
fig_dict_per_task = st.session_state["all_result_dict"][result_select]
task_list = list(fig_dict_per_task.keys())
elo_rating_by_task = fig_dict_per_task["Overall"]["elo_rating_by_task"]
# tabs = st.tabs(task_list)
df_dict_per_task = st.session_state["df_dict"][result_select]
default_layout_dict = DEFAULT_LAYOUT_DICT
task = QueryWrapper("task", "Select Task")(st.selectbox, task_list)
if task is None:
st.stop()
figure_dict = fig_dict_per_task[task]
judgename = figure_dict["judgename"]
df = df_dict_per_task[task]
interpretation, n_models, size_testset = number_breakdown_from_df(df)
if st.session_state.korean:
st.markdown(f"## ๊ฒฐ๊ณผ ({task})")
st.markdown(f"##### Judge ๋ชจ๋ธ: {judgename} / ํ‰๊ฐ€ํ”„๋กฌ: {eval_prompt_name}")
st.markdown(f"##### ํ…Œ์ŠคํŠธ์…‹ ์‚ฌ์ด์ฆˆ: {int(size_testset)} ํ–‰")
else:
st.markdown(f"## Results ({task})")
st.markdown(f"##### Judge Model: {judgename} / prompt: {eval_prompt_name}")
st.markdown(f"##### Size of Testset: {int(size_testset)} rows")
col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.markdown(f"#### Ratings ({task})")
st.table(figure_dict["elo_rating"])
st.write(show_linebreak_in_md(escape_markdown(interpretation)))
with col2:
with st.container(border=True):
st.plotly_chart(
elo_rating_by_task.update_layout(**default_layout_dict),
use_container_width=True,
key=f"{task}_elo_rating_by_task",
)
st.divider()
if st.session_state.korean:
st.markdown("### ํ† ๋„ˆ๋จผํŠธ (ํ…Œ์ŠคํŠธ ์‹œ๋‚˜๋ฆฌ์˜ค) ๋ณ„๋กœ ๋ณด๊ธฐ")
else:
st.markdown("### Tournament Results by Test Scenario")
    # with st.expander("Choose a tournament to view"):
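    # one tournament per test scenario, keyed by the unique instruction/source index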
d = list(df.idx_inst_src.unique())
default_idx = st.session_state.get("selected_tournament", None)
cols = st.columns((1, 18, 1))
with cols[0]:
if st.button("โ—€", key="prev_tournament"):
navigate(d, default_idx, "selected_tournament", -1)
with cols[1]:
tournament_prm_select = QueryWrapper("tournament", "Select Tournament")(
st.selectbox,
d,
default_idx,
key=f"{task}_tournament_select",
on_change=lambda: st.session_state.update(
selected_tournament=st.session_state.get(f"{task}_tournament_select"),
selected_match=None,
),
label_visibility="collapsed",
)
with cols[2]:
if st.button("โ–ถ", key="next_tournament"):
navigate(d, default_idx, "selected_tournament", 1)
# tournament_prm_select = st.selectbox(
# "Select Tournament",
# df.idx_inst_src.unique(),
# index=d.index(st.session_state.get("selected_tournament")),
# key=f"{task}_tournament_{result_select}",
# )
# print(tournament_prm_select, type(tournament_prm_select))
st.session_state["selected_tournament"] = tournament_prm_select
df_now_processed = None
if tournament_prm_select:
df_now = df[df.idx_inst_src == tournament_prm_select]
df_now_processed, _alpha2names = au.init_tournament_dataframe(
df_now,
alpha2names=st.session_state["alpha2names"]
if "alpha2names" in st.session_state.keys()
else None,
)
if "alpha2names" not in st.session_state:
st.session_state["alpha2names"] = _alpha2names
try:
bracket_drawing = au.draw(
df_now_processed,
alpha2names=st.session_state["alpha2names"],
)
legend = au.make_legend_str(
df_now_processed, st.session_state["alpha2names"]
)
st.code(bracket_drawing + legend)
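            # human-readable match labels; each label starts with the numeric match index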
m = list(df_now_processed.human_readable_idx)
default_idx = st.session_state.get("selected_match", None)
cols = st.columns((1, 18, 1))
with cols[0]:
if st.button("โ—€", key="prev_match"):
navigate(m, default_idx, "selected_match", -1)
with cols[1]:
match_idx_human = QueryWrapper("match", "Select Match")(
st.selectbox,
m,
default_idx,
key=f"{task}_match_select",
label_visibility="collapsed",
)
with cols[2]:
if st.button("โ–ถ", key="next_match"):
navigate(m, default_idx, "selected_match", 1)
# match_idx_human = st.selectbox(
# "Select Match",
# df_now_processed.human_readable_idx,
# key=f"{task}_match_{result_select}",
# )
# print(match_idx_human)
st.session_state["selected_match"] = match_idx_human
if match_idx_human:
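                # recover the numeric index from the leading "<idx>: " prefix of the label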
match_idx = int(match_idx_human.split(": ")[0])
row = df_now_processed.loc[match_idx]
st.markdown("#### Current Test Scenario:")
with st.expander(
f"### Evaluation Prompt (evalprompt: {eval_prompt_name}--{task})"
):
prompt = load_prompt(eval_prompt_name, task=task)
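                    # fill the template with literal placeholder strings so the raw prompt layout is shown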
kwargs = dict(
inst="{inst}",
src="{src}",
out_a="{out_a}",
out_b="{out_b}",
task=task,
)
if eval_prompt_name == "translation_pair":
kwargs["source_lang"] = "{source_lang}"
kwargs["target_lang"] = "{target_lang}"
prompt_cmpl = prompt.complete_prompt(**kwargs)
for msg in prompt_cmpl:
st.markdown(f"**{msg['role']}**")
st.info(show_linebreak_in_md(escape_markdown(msg["content"])))
st.info(show_linebreak_in_md(tournament_prm_select))
winner = row.winner
col1, col2 = st.columns(2)
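                # the winning response renders in a green (success) box, the loser in a red (error) box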
winnerbox = st.success
loserbox = st.error
with col1:
iswinner = winner == "model_a"
writemsg = winnerbox if iswinner else loserbox
st.markdown(f"#### ({row.model_a}) {row.human_readable_model_a}")
writemsg(
show_linebreak_in_md(row.generated_a),
icon="โœ…" if iswinner else "โŒ",
)
with col2:
iswinner = winner == "model_b"
writemsg = winnerbox if iswinner else loserbox
st.markdown(f"#### ({row.model_b}) {row.human_readable_model_b}")
writemsg(
show_linebreak_in_md(row.generated_b),
icon="โœ…" if iswinner else "โŒ",
)
except Exception as e:
import traceback
traceback.print_exc()
st.markdown(
"**Bug: ์•„๋ž˜ ํ‘œ๋ฅผ ๋ณต์‚ฌํ•ด์„œ ์ด์Šˆ๋กœ ๋‚จ๊ฒจ์ฃผ์‹œ๋ฉด ๊ฐœ์„ ์— ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค. ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค๐Ÿ™**"
if st.session_state.korean
else "Bug: Please open issue and attach the table output below to help me out. Thanks in advance.๐Ÿ™"
)
st.error(e)
st.info(tournament_prm_select)
st.table(
df_now_processed[
[
"depth",
"round",
"winner_nodes",
"winner_resolved",
"winner",
"model_a",
"model_b",
]
]
)
st.write("Sharable link")
st.code(f"{get_base_url()}/see_results?{QueryWrapper.get_sharable_link()}")
st.divider()
if st.session_state.korean:
st.markdown("### ๋งค์น˜ ํ†ต๊ณ„")
else:
st.markdown("### Match Stats.")
col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict[
"fraction_of_model_a_wins_for_all_a_vs_b_matches"
].update_layout(autosize=True, **default_layout_dict),
use_container_width=True,
key=f"{task}_fraction_of_model_a_wins_for_all_a_vs_b_matches",
)
with col2:
with st.container(border=True):
st.plotly_chart(
figure_dict["match_count_of_each_combination_of_models"].update_layout(
autosize=True, **default_layout_dict
),
use_container_width=True,
key=f"{task}_match_count_of_each_combination_of_models",
)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict["match_count_for_each_model"].update_layout(
**default_layout_dict
),
use_container_width=True,
key=f"{task}_match_count_for_each_model",
)
with col2:
pass
if st.session_state.korean:
st.markdown("### ์ฐธ๊ณ ์šฉ LLM Judge ํŽธํ–ฅ ์ •๋ณด")
else:
st.markdown("### FYI: How biased is your LLM Judge?")
with st.expander("ํŽผ์ณ์„œ ๋ณด๊ธฐ" if st.session_state.korean else "Expand to show"):
st.info(
"""
Arena-Lite์—์„œ๋Š” position bias์˜ ์˜ํ–ฅ์„ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋ชจ๋“  ๋ชจ๋ธ์ด A๋‚˜ B์œ„์น˜์— ๋ฒˆ๊ฐˆ์•„ ์œ„์น˜ํ•˜๋„๋ก ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ LLM Judge ํ˜น์€ Prompt์˜ ์„ฑ๋Šฅ์ด ๋ถ€์กฑํ•˜๋‹ค๊ณ  ๋А๊ปด์ง„๋‹ค๋ฉด, ์•„๋ž˜ ์•Œ๋ ค์ง„ LLM Judge bias๊ฐ€ ์ฐธ๊ณ ๊ฐ€ ๋ ๊ฒ๋‹ˆ๋‹ค.
* position bias (์™ผ์ชฝ)
* length bias (์˜ค๋ฅธ์ชฝ)
๊ฒฐ๊ณผ์˜ ์™œ๊ณก์ด LLM Judge์˜ ๋ถ€์กฑํ•จ ๋–„๋ฌธ์ด์—ˆ๋‹ค๋Š” ์ ์„ ๊ทœ๋ช…ํ•˜๋ ค๋ฉด ์‚ฌ์šฉํ•˜์‹  LLM Judge์™€ Prompt์˜ binary classification ์ •ํ™•๋„๋ฅผ ์ธก์ •ํ•ด๋ณด์‹œ๊ธธ ๋ฐ”๋ž๋‹ˆ๋‹ค (Arena-Lite๋ฅผ ํ™œ์šฉํ•˜์—ฌ ์ด๋ฅผ ์ˆ˜ํ–‰ํ•ด๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!).""".strip()
if st.session_state.korean
else """
In Arena-Lite, to minimize the effect of position bias, all models are alternately positioned in either position A or B. However, if you feel the LLM Judge or Prompt performance is insufficient, the following known LLM Judge biases may be helpful to reference:
* position bias (left)
* length bias (right)
To determine if result distortion was due to LLM Judge limitations, please measure the binary classification accuracy of your LLM Judge and Prompt (You could use Arena-Lite for this purpose!).
""".strip()
)
st.markdown(f"#### {judgename} + prompt = {eval_prompt_name}")
col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict["counts_of_match_winners"].update_layout(
**default_layout_dict
),
use_container_width=True,
key=f"{task}_counts_of_match_winners",
)
with col2:
with st.container(border=True):
st.plotly_chart(
figure_dict["length_bias"].update_layout(**default_layout_dict),
use_container_width=True,
key=f"{task}_length_bias",
)
st.table(figure_dict["length_bias_df"].groupby("category").describe().T)
if __name__ == "__main__":
main()