import streamlit as st
import pandas as pd
import json
from os.path import split as path_split, splitext as path_splitext

st.set_page_config(
    page_title="PPE Metrics Explorer",
    layout="wide",  # This makes the app use the entire screen width
    initial_sidebar_state="expanded",
)

# Set the title of the app
st.title("PPE Metrics Explorer")

@st.cache_data
def load_data(file_path):
    """
    Load JSON data from a file. Cached by Streamlit so repeated reruns reuse the parsed result.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data

def contains_list(column):
    """Return True if any cell in the column is a list (such columns are dropped before display)."""
    return column.apply(lambda x: isinstance(x, list)).any()

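# Metrics where lower is better; they are displayed as (1 - value) so that higher
# is better in every column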
INVERT = {'brier', 'loss'}

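# Metrics multiplied by 100 before display (assumed to be 0-1 fractions in results.json)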
SCALE = {'accuracy', 'row-wise pearson', 'confidence_agreement', 'spearman', 'kendalltau', 'arena_under_curve', 'mean_max_score', 'mean_end_score'}

def main():
    # Load the JSON data
    data = load_data('results.json')
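    # results.json is assumed (inferred from the parsing below) to map each benchmark
    # to per-model metrics, roughly:
    #   {"benchmark": {"path/to/model[.jsonl]": {"subcategory": {"metric": value, ...}, ...}, ...}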

    # Extract the list of benchmarks, forcing "human_preference_v1" to sort first
    benchmarks = sorted(data.keys(), key=lambda s: "A" + s if s == "human_preference_v1" else s)

    # Dropdown for selecting benchmark
    selected_benchmark = st.selectbox("Select a Benchmark", benchmarks)

    # Extract data for the selected benchmark
    benchmark_data = data[selected_benchmark]

    # Prepare a list to store records
    records = []

    # Iterate over each model in the selected benchmark
    for model, metrics in benchmark_data.items():

        # Judge outputs are stored as .jsonl files; everything else is treated as a reward model
        model_type = "LLM Judge" if model.endswith(".jsonl") else "Reward Model"

        # Reduce the model path to its bare name (drop directories and the file extension)
        model = path_split(path_splitext(model)[0])[-1]
        # Flatten the metrics dictionary when the metrics are nested.
        # For example, "human_preference_v1" has subcategories such as "overall" and
        # "hard_prompt"; each nested metric becomes its own "subcategory - metric" column.
        if isinstance(metrics, dict):
            flattened_metrics = {}
            for subkey, submetrics in metrics.items():
                if isinstance(submetrics, dict):
                    for metric_name, value in submetrics.items():
                        # Rescale by a factor of 100 for display (see SCALE above)
                        if metric_name in SCALE:
                            value = 100 * value

                        # Report "lower is better" metrics as (1 - value) so that
                        # higher is better for every displayed column
                        if metric_name in INVERT:
                            key = f"{subkey} - (1 - {metric_name})"
                            flattened_metrics[key] = 1 - value
                        else:
                            # Compound "subkey - metric" column name
                            key = f"{subkey} - {metric_name}"
                            flattened_metrics[key] = value
                else:
                    flattened_metrics[subkey] = submetrics

            records.append({
                "Model": model,
                "Type": model_type,
                **flattened_metrics
            })
        else:
            # If metrics are not nested, just add them directly
            records.append({
                "Model": model,
                "Type": model_type,
                "Value": metrics
            })

    # Create a DataFrame
    df = pd.DataFrame(records)

    # Drop columns that contain lists
    df = df.loc[:, ~df.apply(contains_list)]

    if "human" not in selected_benchmark:
        df = df[sorted(df.columns, key=lambda s: s.replace("(1", "l").lower() if s != "Type" else "A")]

    # Set 'Model' as the index
    df.set_index(["Model"], inplace=True)


    # Create three columns: two narrow ones for the search inputs and a wider one as a spacer
    col1, col2, col3 = st.columns([1, 1, 2])  # Adjust the ratios as needed
    with col1:
        column_search = st.text_input(
            "Metric search",
            placeholder="Search metrics...",
            key="search",
            label_visibility="collapsed",
        )

    with col2:
        model_search = st.text_input(
            "Model filter",
            placeholder="Filter Models (separate criteria with ,) ...",
            key="search2",
            label_visibility="collapsed",
        )

        # Turn the comma-separated criteria into a regex alternation for str.contains
        model_search_crit = model_search.replace(", ", "|").replace(",", "|")

    if column_search:
        # Keep "Type" plus any metric columns whose name contains the search term (case-insensitive)
        matched_columns = [col for col in df.columns if col != "Type" and column_search.lower() in col.lower()]
        if matched_columns:
            df_display = df[["Type"] + matched_columns]
        else:
            st.warning("No columns match your search.")
            df_display = pd.DataFrame()  # Empty DataFrame
    else:
        # If no search term, display all columns
        df_display = df

    if model_search and len(df_display):
        # Keep only the models whose name matches any of the comma-separated criteria
        df_display = df_display[df_display.index.str.contains(model_search_crit, case=False)]

        if len(df_display) == 0:
            st.warning("No models match your filter.")
            df_display = pd.DataFrame()  # Empty DataFrame

    # Display the DataFrame, sorted descending by its first metric column and styled with a color gradient
    if len(df_display):
        styled = (
            df_display.sort_values(df_display.columns[1], ascending=False)
            .style.background_gradient(cmap='summer_r', axis=0)
            .format(precision=4)
        )
    else:
        styled = df_display
    st.dataframe(styled, use_container_width=True, height=500)

    # Optional: Allow user to download the data as CSV
    csv = df_display.to_csv()
    st.download_button(
        label="Download data as CSV",
        data=csv,
        file_name=f"{selected_benchmark}_metrics.csv",
        mime='text/csv',
    )

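# Launch with Streamlit (script filename assumed here): streamlit run ppe_metrics_explorer.py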
if __name__ == "__main__":
    main()