Spaces:
Running
Running
| import pandas as pd | |
| from typing import List, Dict, Optional | |
| from constants import ( | |
| DATASET_ARXIV_SCAN_PAPERS, | |
| DATASET_CONFERENCE_PAPERS, | |
| DATASET_COMMUNITY_SCIENCE, | |
| NEURIPS_ICO, | |
| DATASET_PAPER_CENTRAL, | |
| COLM_ICO, | |
| DEFAULT_ICO, | |
| MICCAI24ICO, | |
| ) | |
| import gradio as gr | |
| from utils import load_and_process | |
| import numpy as np | |
class PaperCentral:
    """
    A class to manage and process paper data for display in a Gradio Dataframe component.

    The raw data is loaded once at construction time (``df_raw``) and a
    display-ready copy with markdown/HTML links is kept in ``df_prettified``.
    ``filter`` derives a filtered, prettified view of ``df_raw`` on demand.
    """

    CONFERENCES = [
        "ACL2023",
        "ACL2024",
        "COLING2024",
        "CVPR2023",
        "CVPR2024",
        "ECCV2024",
        "EMNLP2023",
        "NAACL2023",
        "NeurIPS2023",
        "NeurIPS2023 D&B",
        "COLM2024",
        "MICCAI2024",
    ]

    # Icon shown next to the proceedings link, keyed by conference name.
    CONFERENCES_ICONS = {
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
        "COLM2024": COLM_ICO,
        "MICCAI2024": MICCAI24ICO,
    }

    # Class-level constants defining columns and their data types

    # Columns every filtered view starts with.
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    # Canonical left-to-right display order of all known columns.
    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    # Gradio Dataframe datatype for each (original, un-renamed) column.
    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'str',
        'num_comments': 'str',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'conference_name': 'venue',
    }

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and processing the datasets.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Columns not listed in COLUMNS_ORDER_PAPER_PAGE are dropped.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names.
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Args:
            columns (List[str]): List of column names.

        Returns:
            List[str]: List of data types corresponding to the columns.

        Raises:
            KeyError: If a column has no entry in DATATYPES.
        """
        return [PaperCentral.DATATYPES[c] for c in columns]

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load the paper-central dataset and keep only the columns used by the app.

        Returns:
            pd.DataFrame: The loaded DataFrame restricted to the known columns.
        """
        wanted_columns = [
            'arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github', 'num_models',
            'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
            'conference_name', 'title', 'paper_page', 'authors',
        ]
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[wanted_columns]
        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        The DataFrame is modified in place and also returned.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
        """
        # Replace the whole column rather than assigning through ``.loc[:, col]``:
        # the latter writes into the existing array and, on recent pandas
        # versions, may parse the strings back into the column's datetime dtype.
        df[date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Prettify the DataFrame by turning raw values into markdown/HTML links.

        The input DataFrame is not modified.

        Args:
            df (pd.DataFrame): The DataFrame to prettify.

        Returns:
            pd.DataFrame: The prettified DataFrame.
        """
        # (count column, huggingface.co listing section) pairs.
        hub_count_columns = (
            ('num_models', 'models'),
            ('num_datasets', 'datasets'),
            ('num_spaces', 'spaces'),
        )

        def has_value(row: pd.Series, column: str) -> bool:
            """Return True when ``column`` exists in ``row`` and is neither NaN nor empty."""
            return column in row and pd.notna(row[column]) and bool(row[column])

        def update_row(row: pd.Series) -> pd.Series:
            """
            Rewrite the link-bearing cells of a single row.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            # Capture the raw arXiv id before it is replaced by HTML below; the
            # hub links need it. NaN is truthy, so it must be excluded explicitly
            # or rows without an id would get "arxiv:nan" links.
            raw_arxiv_id = row['arxiv_id']
            arxiv_known = pd.notna(raw_arxiv_id) and bool(raw_arxiv_id)

            # Link each positive artifact count to the matching hub listing.
            for column, hub_section in hub_count_columns:
                if (
                    arxiv_known
                    and column in row
                    and pd.notna(row[column])
                    and float(row[column]) > 0
                ):
                    count = int(float(row[column]))
                    row[column] = (
                        f"[{count}](https://huggingface.co/{hub_section}?other=arxiv:{raw_arxiv_id})"
                    )

            if has_value(row, 'proceedings'):
                image_url = PaperCentral.CONFERENCES_ICONS.get(row["conference_name"], DEFAULT_ICO)
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['proceedings']}'>proc_page</a>"
                )

            ####
            ### The cells below are rewritten last: the code above still needs
            ### the raw values.
            ####

            # Add markdown link to 'paper_page' if it exists
            if has_value(row, 'paper_page'):
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            # Add image and link to 'arxiv_id' if it exists
            if arxiv_known:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='https://arxiv.org/abs/{raw_arxiv_id}'>arxiv_page</a>"
                )

            # Add image and link to 'github' if it exists
            if has_value(row, 'github'):
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['github']}'>github</a>"
                )

            return row

        # Work on a copy so the caller's DataFrame keeps its raw values.
        prettified_df: pd.DataFrame = df.copy().apply(update_row, axis=1)
        return prettified_df

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: A new DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    @staticmethod
    def _add_columns(columns_to_show: List[str], *names: str) -> None:
        """Append each of ``names`` to ``columns_to_show`` unless already present."""
        for name in names:
            if name not in columns_to_show:
                columns_to_show.append(name)

    def filter(
        self,
        selected_date: Optional[str] = None,
        cat_options: Optional[List[str]] = None,
        hf_options: Optional[List[str]] = None,
        conference_options: Optional[List[str]] = None,
        author_search_input: Optional[str] = None,
        title_search_input: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the DataFrame based on the selected options, and prepare it for display.

        All supplied filters are combined (logical AND). The date filter is
        ignored whenever ``conference_options`` is non-empty.

        Args:
            selected_date (Optional[str]): The date to filter the DataFrame.
            cat_options (Optional[List[str]]): Category patterns matched against
                'primary_category'; a ".*" in an option is removed before matching.
            hf_options (Optional[List[str]]): List of options selected by the user
                ("🤗 artifacts", "datasets", "models", "spaces", "github").
            conference_options (Optional[List[str]]): List of conference options selected
                by the user; "In proceedings" keeps every row that has a conference name.
            author_search_input (Optional[str]): Case-insensitive substring to look for
                in the authors of each paper.
            title_search_input (Optional[str]): Case-insensitive substring to look for
                in the paper title.

        Returns:
            gr.Update: An update object for the Gradio Dataframe component.
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        # Title search
        if title_search_input:
            self._add_columns(columns_to_show, 'title')
            title_needle = title_search_input.lower()

            def title_match(title) -> bool:
                # Non-string cells (e.g. NaN) never match.
                return isinstance(title, str) and title_needle in title.lower()

            filtered_df = filtered_df[filtered_df['title'].apply(title_match)]

        # Author search
        if author_search_input:
            self._add_columns(columns_to_show, 'authors')
            author_needle = author_search_input.lower()

            def author_matches(authors) -> bool:
                # Type checks come first so that scalar cells such as NaN are
                # rejected instead of failing on len().
                if isinstance(authors, str):
                    return author_needle in authors.lower()
                if isinstance(authors, (list, tuple, pd.Series, np.ndarray)):
                    return any(
                        isinstance(author, str) and author_needle in author.lower()
                        for author in authors
                    )
                return False

            filtered_df = filtered_df[filtered_df['authors'].apply(author_matches)]

        # Category
        if cat_options:
            category_filter = pd.Series(False, index=filtered_df.index)
            for option in (o.replace(".*", "") for o in cat_options):
                # NOTE(review): this is a case-insensitive *substring* (regex)
                # match, so e.g. "cs" also matches "physics.*" categories —
                # confirm whether a prefix match is intended.
                category_filter |= filtered_df['primary_category'].str.contains(
                    option, case=False, na=False
                )
            filtered_df = filtered_df[category_filter]

        # Date
        if selected_date and not conference_options:
            selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
            filtered_df = filtered_df[filtered_df['date'] == selected_date]

        # HF options
        if hf_options:
            if "🤗 artifacts" in hf_options:
                # Filter rows where 'paper_page' is not empty or NaN
                filtered_df = filtered_df[
                    (filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
                ]
                self._add_columns(columns_to_show, 'upvotes', 'num_models', 'num_datasets', 'num_spaces')

            # NOTE(review): ``!= 0`` keeps rows whose count is NaN — confirm
            # that this is the intended behaviour for missing counts.
            if "datasets" in hf_options:
                self._add_columns(columns_to_show, 'num_datasets')
                filtered_df = filtered_df[filtered_df['num_datasets'] != 0]

            if "models" in hf_options:
                self._add_columns(columns_to_show, 'num_models')
                filtered_df = filtered_df[filtered_df['num_models'] != 0]

            if "spaces" in hf_options:
                self._add_columns(columns_to_show, 'num_spaces')
                filtered_df = filtered_df[filtered_df['num_spaces'] != 0]

            if "github" in hf_options:
                self._add_columns(columns_to_show, 'github')
                filtered_df = filtered_df[(filtered_df['github'] != "") & (filtered_df['github'].notnull())]

        # Apply conference filtering
        if conference_options:
            # The conference view replaces date/arXiv columns with venue details.
            columns_to_show.remove("date")
            columns_to_show.remove("arxiv_id")
            self._add_columns(columns_to_show, 'conference_name', 'proceedings', 'type', 'id')

            # If "In proceedings" is selected
            if "In proceedings" in conference_options:
                # Filter rows where 'conference_name' is not None, not NaN, and not empty
                filtered_df = filtered_df[
                    filtered_df['conference_name'].notna() & (filtered_df['conference_name'] != "")
                ]

            # For other conference options
            other_conferences = [conf for conf in conference_options if conf != "In proceedings"]
            if other_conferences:
                conference_filter = pd.Series(False, index=filtered_df.index)
                for conference in other_conferences:
                    # Keep rows whose 'conference_name' equals the option (case-insensitive)
                    conference_filter |= (
                        filtered_df['conference_name'].notna() &
                        (filtered_df['conference_name'].str.lower() == conference.lower())
                    )
                filtered_df = filtered_df[conference_filter]

        # Prettify the DataFrame
        filtered_df = self.prettify(filtered_df)

        # Select the requested columns, ordered according to COLUMNS_ORDER_PAPER_PAGE
        filtered_df = filtered_df[self.get_columns_order(columns_to_show)]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the (renamed) columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Sort rows to display entries with 'paper_page' first. A stable sort
        # keeps the existing relative order within each group.
        if 'paper_page' in filtered_df.columns:
            has_paper_page = filtered_df['paper_page'].notna() & (filtered_df['paper_page'] != "")
            filtered_df = (
                filtered_df.assign(has_paper_page=has_paper_page)
                .sort_values(by='has_paper_page', ascending=False, kind='stable')
                .drop(columns='has_paper_page')
            )

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Names without an entry in COLUMN_RENAME_MAP are returned unchanged.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column.
        """
        for original_name, display_name in self.COLUMN_RENAME_MAP.items():
            if display_name == display_column_name:
                return original_name
        return display_column_name