import gradio as gr
import json
import pandas as pd
from urllib.request import urlopen
from urllib.error import URLError
import re
from datetime import datetime

# BibTeX entry displayed (with a copy button) in the "Citation" accordion of
# the Results tab, together with its label.
CITATION_BUTTON_TEXT = r"""@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished = {\url{https://github.com/open-compass/opencompass}},
    year={2023}
}"""
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"


# Extra CSS: on viewports at least 1536px wide, let the Gradio container use
# the full available width.
# NOTE(review): gr.Blocks() is created without this, so it currently has no effect.
head_style = """
<style>
@media (min-width: 1536px)
{
    .gradio-container {
        min-width: var(--size-full) !important;
    }
}
</style>
"""

# Base URL of the leaderboard JSON files; findfile() appends '<name>.json'.
# NOTE(review): plain HTTP — consider https if the bucket supports it.
DATA_URL_BASE = "http://opencompass.oss-cn-shanghai.aliyuncs.com/dev-assets/hf-research/"

# Markdown header rendered at the top of the app.
MAIN_LEADERBOARD_DESCRIPTION = """## Compass Academic Leaderboard (Full Version)
--WIP--

"""
# NOTE(review): not referenced by any live code in this file.
Initial_title = 'Compass Academic Leaderboard'

# Choices for the "Model Size" / "Model Type" filter checkboxes. The literal
# values are matched by the filtering logic in show_results_tab.
MODEL_SIZE = ['<10B', '10B-70B', '>70B', 'Unknown']
MODEL_TYPE = ['API', 'OpenSource']



def _fetch_json(name, timeout=30):
    """Download ``{DATA_URL_BASE}{name}.json`` and return the decoded JSON.

    The HTTP response is closed deterministically by the ``with`` block.
    Network errors (``urllib.error.URLError`` and subclasses, or a timeout)
    and JSON/UTF-8 decode errors propagate to the caller.
    """
    url = f"{DATA_URL_BASE}{name}.json"
    with urlopen(url, timeout=timeout) as response:
        return json.loads(response.read().decode('utf-8'))


def findfile(timeout=30):
    """Fetch the model metadata and the aggregated benchmark results.

    Args:
        timeout: Seconds to wait on each HTTP request before failing. This
            keeps a stalled server from hanging start-up indefinitely.

    Returns:
        A ``(model_info, results)`` tuple:

        * ``model_info`` is the decoded ``model-meta-info.json``.
          ``make_results_tab`` iterates it as a sequence of per-model dicts.
        * ``results`` is the decoded ``hf-academic.json``.
          ``make_results_tab`` reads it as a mapping of dataset name to
          per-model scores keyed by the model's ``abbr``.

    Raises:
        urllib.error.URLError: If either download fails.
        json.JSONDecodeError: If a payload is not valid JSON.
    """
    model_info = _fetch_json('model-meta-info', timeout=timeout)
    results = _fetch_json('hf-academic', timeout=timeout)
    return model_info, results

# Fetch the leaderboard data once, at import time. create_interface() builds
# the results table from these module-level values. Any network or decode
# error raised by findfile() propagates and aborts the import.
model_info, results = findfile()





def make_results_tab(model_info, results):
    """Build the leaderboard table as a DataFrame with one row per model.

    Args:
        model_info: Iterable of per-model dicts. Each must provide the keys
            ``'display_name'``, ``'num_param'``, ``'release_type'`` and
            ``'abbr'``.
        results: Mapping of dataset name to a mapping of model ``abbr`` to
            score.

    Returns:
        A DataFrame with these columns, in order:

        * ``Index``: 1-based position in ``model_info``.
        * ``Model Name``, ``Parameters`` and ``OpenSource``.
        * One column per dataset, in ``results`` iteration order.

        A model missing from a dataset's mapping gets ``None`` (NaN in numeric
        columns) for that cell instead of raising ``KeyError``.
    """
    datasets = list(results)
    rows = []
    for index, model in enumerate(model_info, start=1):
        row = {
            'Index': index,
            'Model Name': model['display_name'],
            'Parameters': model['num_param'],
            'OpenSource': model['release_type'],
        }
        abbr = model['abbr']
        for dataset in datasets:
            # .get(): one dataset missing one model must not crash the whole app.
            row[dataset] = results[dataset].get(abbr)
        rows.append(row)
    return pd.DataFrame(rows)



def calculate_column_widths(df, header_char_px=10, content_char_px=8,
                            padding_px=20, min_px=160, max_px=400):
    """Estimate a pixel width for every column of ``df``, in column order.

    Each width is computed in three steps:

    1. Take the larger of the header length times ``header_char_px`` and the
       longest cell (as ``str``) times ``content_char_px``.
    2. Add ``padding_px``.
    3. Clamp the result to ``[min_px, max_px]``.

    Args:
        df: The table to size.
        header_char_px, content_char_px, padding_px, min_px, max_px: Sizing
            constants. The defaults reproduce the original fixed values.

    Returns:
        A list of plain ``int`` widths, one per column.
    """
    widths = []
    # Positional access so duplicate column names still yield one Series each.
    for position, column in enumerate(df.columns):
        header_px = len(str(column)) * header_char_px
        lengths = df.iloc[:, position].astype(str).map(len)
        # An empty column has no content. Treat it as 0 characters rather than
        # relying on how max() happens to order a NaN.
        content_px = int(lengths.max()) * content_char_px if len(lengths) else 0
        width = max(header_px, content_px) + padding_px
        widths.append(max(min_px, min(max_px, width)))
    return widths



# Placeholder shown in the model-name search box. The filter also treats this
# exact text as "no search", so a stale default value never filters anything.
_SEARCH_PLACEHOLDER = 'Input the Model Name'


def _parse_size_in_b(param):
    """Convert a parameter-count string such as ``'7B'`` into a float.

    Returns ``None`` for ``'N/A'`` and for any value that cannot be parsed
    (a non-string, or text that is not a number once ``'B'`` is removed).
    The size filter buckets such models as 'Unknown'.
    """
    if param == 'N/A':
        return None
    try:
        return float(param.replace('B', ''))
    except (AttributeError, TypeError, ValueError):
        return None


def _plain_model_name(display_name):
    """Return the lower-cased visible text of a display name.

    Display names appear to be wrapped in an HTML anchor (``<a ...>Name</a>``).
    The text after the last ``>`` preceding the first ``</a>`` is kept.
    """
    return display_name.split('</a>')[0].split('>')[-1].lower()


def _size_mask(sizes, size_ranges):
    """OR together the masks of the selected size buckets (values of MODEL_SIZE).

    ``sizes`` is a float Series of parameter counts where NaN means unknown.
    Comparisons with NaN are False, so unknown sizes only match 'Unknown'.
    Unrecognised bucket labels are ignored.
    """
    mask = pd.Series(False, index=sizes.index)
    for size_range in size_ranges:
        if size_range == '<10B':
            mask |= sizes < 10
        elif size_range == '10B-70B':
            mask |= (sizes >= 10) & (sizes < 70)
        elif size_range == '>70B':
            mask |= sizes >= 70
        elif size_range == 'Unknown':
            mask |= sizes.isna()
    return mask


def show_results_tab(df):
    """Render the 'Results' tab: filter controls, the leaderboard table and citation.

    Must be called inside an active Gradio Blocks context.

    Args:
        df: The full, unfiltered leaderboard table (as built by
            ``make_results_tab``). Every filter runs against this frame.
    """

    def filter_df(model_name, size_ranges, model_types):
        """Return the rows of ``df`` matching the name, size and type filters.

        An empty ``size_ranges`` or ``model_types`` selection disables that
        filter (all rows pass) rather than hiding every row.
        """
        if df.empty:
            return df

        keep = pd.Series(True, index=df.index)

        # Case-insensitive substring search on the visible model name. Blank
        # input and the placeholder text mean "no search".
        query = (model_name or '').strip()
        if query and query != _SEARCH_PLACEHOLDER:
            needle = query.lower()
            hits = [needle in _plain_model_name(name) for name in df['Model Name']]
            keep &= pd.Series(hits, index=df.index, dtype=bool)

        if size_ranges:
            # astype(float) maps unparseable sizes (None) to NaN so the numeric
            # comparisons also work when every size is unknown.
            sizes = df['Parameters'].map(_parse_size_in_b).astype(float)
            keep &= _size_mask(sizes, size_ranges)

        if model_types:
            wanted = [t for t in model_types if t in MODEL_TYPE]
            keep &= df['OpenSource'].isin(wanted)

        # NOTE: the 'Index' column keeps each model's position in the full
        # table; rows are not renumbered after filtering.
        return df[keep]

    with gr.Row():
        with gr.Column():
            model_name = gr.Textbox(
                value='',
                placeholder=_SEARCH_PLACEHOLDER,
                label='Search Model Name',
                interactive=True,
            )
        with gr.Column():
            size_filter = gr.CheckboxGroup(
                choices=MODEL_SIZE,
                value=MODEL_SIZE,
                label='Model Size',
                interactive=True,
            )
        with gr.Column():
            type_filter = gr.CheckboxGroup(
                choices=MODEL_TYPE,
                value=MODEL_TYPE,
                label='Model Type',
                interactive=True,
            )

    with gr.Column():
        table = gr.DataFrame(
            value=df,
            interactive=False,
            wrap=False,
            column_widths=calculate_column_widths(df),
        )

    # Re-filter when the search is submitted or either checkbox group changes.
    filter_inputs = [model_name, size_filter, type_filter]
    model_name.submit(fn=filter_df, inputs=filter_inputs, outputs=table)
    size_filter.change(fn=filter_df, inputs=filter_inputs, outputs=table)
    type_filter.change(fn=filter_df, inputs=filter_inputs, outputs=table)

    with gr.Row():
        with gr.Accordion("Citation", open=False):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                elem_id='citation-button',
                lines=6,  # taller box so the whole BibTeX entry is visible
                max_lines=8,  # cap the height
                show_copy_button=True  # copy button for convenience
            )


def show_predictions_tab(df):
    """Render the 'Predictions' tab.

    Placeholder: nothing is rendered yet and ``df`` is ignored.
    """
    return None


def create_interface():
    """Assemble the Gradio app: a Markdown header plus 'Results' and 'Predictions' tabs.

    Returns:
        The ``gr.Blocks`` app, not yet queued or launched.
    """
    leaderboard = make_results_tab(model_info, results)
    # (label, elem_id, tab id, renderer) for each tab, in display order.
    tab_specs = [
        ('Results', 'main', 0, show_results_tab),
        ('Predictions', 'notmain', 1, show_predictions_tab),
    ]
    with gr.Blocks() as demo:
        gr.Markdown(MAIN_LEADERBOARD_DESCRIPTION)
        # NOTE(review): head_style and Initial_title are not passed here, so
        # they currently have no effect on the page.
        with gr.Tabs(elem_classes='tab-buttons'):
            for label, elem_id, tab_id, render in tab_specs:
                with gr.TabItem(label, elem_id=elem_id, id=tab_id):
                    render(leaderboard)
    return demo

if __name__ == '__main__':
    # Script entry point: build the UI, enable request queuing, and serve on
    # all network interfaces.
    demo = create_interface()
    demo.queue()
    launch_kwargs = {'server_name': '0.0.0.0'}
    demo.launch(**launch_kwargs)