Hugues Sibille commited on
Commit
edb334d
β€’
1 Parent(s): 187990b

feat : metrics dropdown added to gradio

Browse files
Files changed (1) hide show
  1. app.py +62 -239
app.py CHANGED
@@ -1,263 +1,86 @@
1
- import json
2
- import os
3
-
4
  import gradio as gr
5
- import pandas as pd
6
- from huggingface_hub import HfApi, hf_hub_download, get_collection
7
- from huggingface_hub.repocard import metadata_load
8
- from typing import Dict
9
-
10
-
11
- def get_datasets_nickname() -> Dict:
12
-
13
- datasets_nickname = {}
14
-
15
- collection = get_collection("vidore/vidore-benchmark-667173f98e70a1c0fa4db00d")
16
-
17
- collection_items = collection.items
18
-
19
- for item in collection_items:
20
- dataset_name = item.item_id
21
-
22
- if 'arxivqa' in dataset_name:
23
- datasets_nickname[dataset_name] = 'ArxivQA'
24
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'ArxivQA'
25
- datasets_nickname[dataset_name + '_captioning'] = 'ArxivQA'
26
-
27
- elif 'docvqa' in dataset_name:
28
- datasets_nickname[dataset_name] = 'DocVQA'
29
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'DocVQA'
30
- datasets_nickname[dataset_name + '_captioning'] = 'DocVQA'
31
-
32
- elif 'infovqa' in dataset_name:
33
- datasets_nickname[dataset_name] = 'InfoVQA'
34
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'InfoVQA'
35
- datasets_nickname[dataset_name + '_captioning'] = 'InfoVQA'
36
-
37
- elif 'tabfquad' in dataset_name:
38
- datasets_nickname[dataset_name] = 'TabFQuad'
39
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'TabFQuad'
40
- datasets_nickname[dataset_name + '_captioning'] = 'TabFQuad'
41
-
42
- elif 'tatdqa' in dataset_name:
43
- datasets_nickname[dataset_name] = 'TATDQA'
44
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'TATDQA'
45
- datasets_nickname[dataset_name + '_captioning'] = 'TATDQA'
46
-
47
- elif 'shiftproject' in dataset_name:
48
- datasets_nickname[dataset_name] = 'ShiftProject'
49
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'ShiftProject'
50
- datasets_nickname[dataset_name + '_captioning'] = 'ShiftProject'
51
-
52
- elif 'artificial_intelligence' in dataset_name:
53
- datasets_nickname[dataset_name] = 'Artificial Intelligence'
54
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'Artificial Intelligence'
55
- datasets_nickname[dataset_name + '_captioning'] = 'Artificial Intelligence'
56
-
57
- elif 'energy' in dataset_name:
58
- datasets_nickname[dataset_name] = 'Energy'
59
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'Energy'
60
- datasets_nickname[dataset_name + '_captioning'] = 'Energy'
61
-
62
- elif 'government_reports' in dataset_name:
63
- datasets_nickname[dataset_name] = 'Government Reports'
64
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'Government Reports'
65
- datasets_nickname[dataset_name + '_captioning'] = 'Government Reports'
66
-
67
- elif 'healthcare' in dataset_name:
68
- datasets_nickname[dataset_name] = 'Healthcare'
69
- datasets_nickname[dataset_name + '_ocr_chunk'] = 'Healthcare'
70
- datasets_nickname[dataset_name + '_captioning'] = 'Healthcare'
71
-
72
- return datasets_nickname
73
-
74
-
75
- def make_clickable_model(model_name, link=None):
76
-
77
- if link is None:
78
- desanitized_model_name = model_name.replace("_", "/")
79
-
80
- if '/captioning' in desanitized_model_name:
81
- desanitized_model_name = desanitized_model_name.replace('/captioning', '')
82
- if '/ocr' in desanitized_model_name:
83
- desanitized_model_name = desanitized_model_name.replace('/ocr', '')
84
-
85
- link = "https://huggingface.co/" + desanitized_model_name
86
-
87
- # Remove user from model name
88
- # return (
89
- # f'<a target="_blank" style="text-decoration: underline" href="{link}">{model_name.split("/")[-1]}</a>'
90
- # )
91
- return f'<a target="_blank" style="text-decoration: underline" href="{link}">{model_name}</a>'
92
-
93
-
94
- def add_rank(df):
95
- cols_to_rank = [
96
- col
97
- for col in df.columns
98
- if col
99
- not in [
100
- "Model",
101
- "Model Size (Million Parameters)",
102
- "Memory Usage (GB, fp32)",
103
- "Embedding Dimensions",
104
- "Max Tokens",
105
- ]
106
- ]
107
- if len(cols_to_rank) == 1:
108
- df.sort_values(cols_to_rank[0], ascending=False, inplace=True)
109
- else:
110
- df.insert(len(df.columns) - len(cols_to_rank), "Average", df[cols_to_rank].mean(axis=1, skipna=False))
111
- df.sort_values("Average", ascending=False, inplace=True)
112
- df.insert(0, "Rank", list(range(1, len(df) + 1)))
113
- df = df.round(2)
114
- # Fill NaN after averaging
115
- df.fillna("", inplace=True)
116
- return df
117
-
118
-
119
- def get_vidore_data():
120
- api = HfApi()
121
 
122
- # local cache path
123
- model_infos_path = "model_infos.json"
124
- metric = "ndcg_at_5"
125
 
126
-
127
- MODEL_INFOS = {}
128
- if os.path.exists(model_infos_path):
129
- with open(model_infos_path) as f:
130
- MODEL_INFOS = json.load(f)
131
-
132
- models = api.list_models(filter="vidore")
133
- repositories = [model.modelId for model in models]
134
-
135
- datasets_nickname = get_datasets_nickname()
136
- for repo_id in repositories:
137
- files = [f for f in api.list_repo_files(repo_id) if f.endswith('_metrics.json')]
138
- if len(files) == 0:
139
- continue
140
- else :
141
- for file in files:
142
- model_name = file.split('_metrics.json')[0]
143
-
144
- if model_name not in MODEL_INFOS:
145
- readme_path = hf_hub_download(repo_id, filename="README.md")
146
- meta = metadata_load(readme_path)
147
- try:
148
- result_path = hf_hub_download(repo_id, filename= file)
149
-
150
- with open(result_path) as f:
151
- results = json.load(f)
152
-
153
- for dataset in results:
154
- results[dataset] = {key: value for key, value in results[dataset].items() if metric in key}
155
-
156
- MODEL_INFOS[model_name] = {"meta":meta, "results": results}
157
- except Exception as e:
158
- print(f"Error loading {model_name} - {e}")
159
- continue
160
 
161
- model_res = {}
162
- df = None
163
- if len(MODEL_INFOS) > 0:
164
- for model in MODEL_INFOS.keys():
165
- res = MODEL_INFOS[model]["results"]
166
- dataset_res = {}
167
- for dataset in res.keys():
168
- if "validation_set" == dataset:
169
- continue
170
- dataset_res[datasets_nickname[dataset]] = res[dataset][metric]
171
- model_res[model] = dataset_res
172
-
173
- df = pd.DataFrame(model_res).T
174
-
175
- #save model_infos
176
- with open(model_infos_path, "w") as f:
177
- json.dump(MODEL_INFOS, f)
178
-
179
- return df
180
-
181
 
182
- def add_rank_and_format(df):
183
- df = df.reset_index()
184
- df = df.rename(columns={"index": "Model"})
185
- df = add_rank(df)
186
- df["Model"] = df["Model"].apply(make_clickable_model)
187
- return df
188
 
 
 
 
 
189
 
190
- # 1. Force headers to wrap
191
- # 2. Force model column (maximum) width
192
- # 3. Prevent model column from overflowing, scroll instead
193
- # 4. Prevent checkbox groups from taking up too much space
194
 
195
- css = """
196
- table > thead {
197
- white-space: normal
198
- }
199
 
200
- table {
201
- --cell-width-1: 250px
202
- }
 
203
 
204
- table > tbody > tr > td:nth-child(2) > div {
205
- overflow-x: auto
206
- }
207
 
208
- .filter-checkbox-group {
209
- max-width: max-content;
210
- }
211
- """
 
 
 
212
 
 
 
 
213
 
214
- def get_refresh_function():
215
- def _refresh():
216
- data_task_category = get_vidore_data()
217
- return add_rank_and_format(data_task_category)
218
 
219
- return _refresh
 
 
220
 
221
- data = get_vidore_data()
222
- data = add_rank_and_format(data)
223
 
224
- NUM_DATASETS = len(data.columns) - 3
225
- NUM_SCORES = len(data) * NUM_DATASETS
226
- NUM_MODELS = len(data)
227
 
228
- with gr.Blocks(css=css) as block:
229
- gr.Markdown("# ViDoRe: The Visual Document Retrieval Benchmark πŸ“šπŸ”")
230
- gr.Markdown("## From the paper - ColPali: Efficient Document Retrieval with Vision Language Models πŸ‘€")
231
 
232
- gr.Markdown(
 
 
 
 
233
  """
234
- Visual Document Retrieval Benchmark leaderboard. To submit, refer to the <a href="https://github.com/tonywu71/vidore-benchmark/" target="_blank" style="text-decoration: underline">ViDoRe GitHub repository</a>. Refer to the [ColPali paper](https://arxiv.org/abs/XXXX.XXXXX) for details on metrics, tasks and models.
235
- """
236
- )
237
 
238
- with gr.Row():
239
- datatype = ["number", "markdown"] + ["number"] * (NUM_DATASETS + 1)
240
- dataframe = gr.Dataframe(data, datatype=datatype, type="pandas", height=500)
241
-
242
- with gr.Row():
243
- refresh_button = gr.Button("Refresh")
244
- refresh_button.click(get_refresh_function(), inputs=None, outputs=dataframe, concurrency_limit=20)
245
 
246
- gr.Markdown(
247
- f"""
248
- - **Total Datasets**: {NUM_DATASETS}
249
- - **Total Scores**: {NUM_SCORES}
250
- - **Total Models**: {NUM_MODELS}
251
- """
252
- + r"""
253
- Please consider citing:
254
 
255
- ```bibtex
256
- INSERT LATER
257
- ```
258
- """
259
- )
260
 
 
 
261
 
262
- if __name__ == "__main__":
263
- block.queue(max_size=10).launch(debug=True)
 
1
+ from data.model_handler import ModelHandler
2
+ from app.utils import add_rank_and_format, get_refresh_function
 
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ METRICS = ["ndcg_at_5", "recall_at_1", "recall_at_5", "mrr_at_5"]
 
 
6
 
7
+ def main():
8
+ model_handler = ModelHandler()
9
+ initial_metric = "ndcg_at_5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ data = model_handler.get_vidore_data(initial_metric)
12
+ data = add_rank_and_format(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ NUM_DATASETS = len(data.columns) - 3
15
+ NUM_SCORES = len(data) * NUM_DATASETS
16
+ NUM_MODELS = len(data)
 
 
 
17
 
18
+ css = """
19
+ table > thead {
20
+ white-space: normal
21
+ }
22
 
23
+ table {
24
+ --cell-width-1: 250px
25
+ }
 
26
 
27
+ table > tbody > tr > td:nth-child(2) > div {
28
+ overflow-x: auto
29
+ }
 
30
 
31
+ .filter-checkbox-group {
32
+ max-width: max-content;
33
+ }
34
+ """
35
 
36
+ with gr.Blocks(css=css) as block:
37
+ gr.Markdown("# ViDoRe: The Visual Document Retrieval Benchmark πŸ“šπŸ”")
38
+ gr.Markdown("## From the paper - ColPali: Efficient Document Retrieval with Vision Language Models πŸ‘€")
39
 
40
+ gr.Markdown(
41
+ """
42
+ Visual Document Retrieval Benchmark leaderboard. To submit, refer to the <a href="https://github.com/tonywu71/vidore-benchmark/" target="_blank" style="text-decoration: underline">ViDoRe GitHub repository</a>. Refer to the [ColPali paper](https://arxiv.org/abs/XXXX.XXXXX) for details on metrics, tasks and models.
43
+ """
44
+ )
45
+ #all_columns = list(data.columns)
46
+ #default_columns = all_columns
47
 
48
+ with gr.Row():
49
+ metric_dropdown = gr.Dropdown(choices=METRICS, value=initial_metric, label="Select Metric")
50
+ #column_checkboxes = gr.CheckboxGroup(choices=all_columns, value=default_columns, label="Select Columns to Display")
51
 
52
+ with gr.Row():
53
+ datatype = ["number", "markdown"] + ["number"] * (NUM_DATASETS + 1)
54
+ dataframe = gr.Dataframe(data, datatype=datatype, type="pandas")
 
55
 
56
+ with gr.Row():
57
+ refresh_button = gr.Button("Refresh")
58
+ refresh_button.click(get_refresh_function(), inputs=[metric_dropdown], outputs=dataframe, concurrency_limit=20)
59
 
 
 
60
 
61
+ # Automatically refresh the dataframe when the dropdown value changes
62
+ metric_dropdown.change(get_refresh_function(), inputs=[metric_dropdown], outputs=dataframe)
63
+ #column_checkboxes.change(get_refresh_function(), inputs=[metric_dropdown, column_checkboxes], outputs=dataframe)
64
 
 
 
 
65
 
66
+ gr.Markdown(
67
+ f"""
68
+ - **Total Datasets**: {NUM_DATASETS}
69
+ - **Total Scores**: {NUM_SCORES}
70
+ - **Total Models**: {NUM_MODELS}
71
  """
72
+ + r"""
73
+ Please consider citing:
 
74
 
75
+ ```bibtex
76
+ INSERT LATER
77
+ ```
78
+ """
79
+ )
 
 
80
 
81
+ block.queue(max_size=10).launch(debug=True)
 
 
 
 
 
 
 
82
 
 
 
 
 
 
83
 
84
+ if __name__ == "__main__":
85
+ main()
86