Spaces:
Running
Running
| import gradio as gr | |
| from typing import Optional | |
| import pandas as pd | |
| from huggingface_hub import HfApi, hf_hub_download, CommitOperationAdd | |
| import json | |
| import os | |
| import requests | |
| # PR function remains the same | |
def create_pr_in_hf_dataset(new_entry, oauth_token: gr.OAuthToken):
    """Open a pull request on the Hub dataset that adds ``new_entry`` as a JSON file.

    The entry is uploaded as ``data/<arxiv_id>.json`` to the dataset
    ``IAMJB/paper-central-pr`` using the caller's OAuth token.

    Args:
        new_entry: JSON-serialisable dict; must contain the key ``'arxiv_id'``.
        oauth_token: Gradio OAuth token of the logged-in user; its ``token``
            attribute is used for every Hub call.

    Returns:
        The URL of the created pull request on success, otherwise a string
        starting with ``"Error"`` describing the failure.
    """
    # Dataset and directory
    REPO_ID = 'IAMJB/paper-central-pr'
    DATA_DIR = 'data'

    api = HfApi()
    token = oauth_token.token

    # Ensure the repository exists and the data directory is initialized.
    try:
        api.create_repo(repo_id=REPO_ID, token=token, repo_type='dataset', exist_ok=True)
        files = api.list_repo_files(REPO_ID, repo_type='dataset', token=token)
        # list_repo_files returns file paths (e.g. "data/x.json"), never bare
        # directory names, so the directory exists iff some path lives under it.
        data_dir_exists = any(path.startswith(f"{DATA_DIR}/") for path in files)
        if not data_dir_exists:
            # Git cannot store empty directories; seed it with an empty .gitkeep.
            # Content is passed as bytes so no temporary file is needed.
            gitkeep = CommitOperationAdd(
                path_in_repo=f"{DATA_DIR}/.gitkeep",
                path_or_fileobj=b"",
            )
            api.create_commit(
                repo_id=REPO_ID,
                operations=[gitkeep],
                commit_message="Initialize data directory",
                repo_type="dataset",
                token=token,
            )
    except Exception as e:
        return f"Error creating or accessing repository: {e}"

    # Serialise in memory: avoids leaking temp files on failure, name clashes
    # between concurrent users, and local-path problems when the arXiv ID
    # contains characters such as "/".
    filename = f"{new_entry['arxiv_id']}.json"
    payload = json.dumps(new_entry, indent=2).encode("utf-8")
    commit = CommitOperationAdd(
        path_in_repo=f"{DATA_DIR}/{filename}",
        path_or_fileobj=payload,
    )

    # Create PR
    try:
        res = api.create_commit(
            repo_id=REPO_ID,
            operations=[commit],
            commit_message=f"Add new entry for arXiv ID {new_entry['arxiv_id']}",
            repo_type="dataset",
            create_pr=True,
            token=token,
        )
    except Exception as e:
        print(f"Error creating PR: {e}")
        return "Error creating PR."
    return res.pr_url
def _blank_if_missing(value):
    """Return ``""`` for None/NaN-like scalars, otherwise ``value`` unchanged."""
    # pandas yields NaN/None for empty cells; NaN is truthy and would render
    # as "nan" in a textbox, so normalise it to an empty string.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return value


def pr_paper_central_tab(paper_central_df):
    """Build the "Edit papers" tab and wire up its event handlers.

    Must be called inside an active ``gr.Blocks`` context. The tab lets a
    logged-in user look up an arXiv ID in ``paper_central_df``, edit its
    metadata fields, open a dataset PR with the edited values, and (when the
    ``paper_space_pr_token`` environment variable is set) create a paper page.

    Args:
        paper_central_df: DataFrame with an ``'arxiv_id'`` column; the columns
            named in ``fields`` and ``'paper_page'`` are read when present.
    """
    hf_logo = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"

    # Define the fields dynamically (removed 'paper_page')
    fields = [
        {'name': 'github', 'label': 'GitHub URL'},
        {'name': 'conference_name', 'label': 'Conference Name'},
        {'name': 'type_', 'label': 'Type'},  # Renamed from 'type' to 'type_'
        {'name': 'proceedings', 'label': 'Proceedings'},
        {'name': 'project_page', 'label': 'Project page'},
        # Add or remove fields here as needed
    ]

    with gr.Column():
        gr.Markdown("## Edit papers")

        # Hidden login prompt placeholder (kept so the layout is unchanged).
        gr.Markdown("Please log in to proceed.", visible=False)

        # Input for arXiv ID
        arxiv_id_input = gr.Textbox(label="Enter arXiv ID")
        arxiv_id_button = gr.Button("Submit")

        # Message to display errors or information
        message = gr.Markdown("", visible=False)

        # Button to create paper page
        create_paper_page_button = gr.Button("Create Paper Page", visible=False, icon=hf_logo)

        input_fields = {
            field['name']: gr.Textbox(label=field['label'], visible=False)
            for field in fields
        }

        # Button to create PR
        create_pr_button = gr.Button("Create PR", visible=False, icon=hf_logo)

        # Output message
        pr_message = gr.Markdown("", visible=False)

        # Loading message
        loading_message = gr.Markdown("Creating PR, please wait...", visible=False)

    field_components = [input_fields[field['name']] for field in fields]

    def _panel_updates(message_update, field_values, *, show_create_pr, show_create_page):
        """Assemble updates in output order: message, fields, both buttons, pr_message."""
        if field_values is None:
            field_updates = [gr.update(visible=False) for _ in fields]
        else:
            field_updates = [gr.update(value=value, visible=True) for value in field_values]
        return [
            message_update,
            *field_updates,
            gr.update(visible=show_create_pr),    # create_pr_button
            gr.update(visible=show_create_page),  # create_paper_page_button
            gr.update(visible=False),             # pr_message
        ]

    # Function to handle arxiv_id submission and check login
    def check_login_and_handle_arxiv_id(arxiv_id, oauth_token: Optional[gr.OAuthToken]):
        if oauth_token is None:
            # Not logged in: show the prompt and hide everything else.
            return _panel_updates(
                gr.update(value="Please log in to proceed.", visible=True),
                None,
                show_create_pr=False,
                show_create_page=False,
            )

        arxiv_id = (arxiv_id or "").strip()
        if not arxiv_id:
            return _panel_updates(
                gr.update(value="Please enter an arXiv ID.", visible=True),
                None,
                show_create_pr=False,
                show_create_page=False,
            )

        # Same truthiness test as create_paper_page, so the button is only
        # offered when that handler can actually use the token.
        access_token_exists = bool(os.getenv('paper_space_pr_token'))

        matches = paper_central_df[paper_central_df['arxiv_id'] == arxiv_id]
        if matches.empty:
            # arXiv ID not found: offer empty fields.
            return _panel_updates(
                gr.update(value="arXiv ID not found. You can create a paper page.", visible=True),
                ["" for _ in fields],
                show_create_pr=True,
                show_create_page=access_token_exists,
            )

        # arXiv ID found: pre-fill the fields from the first matching row.
        row = matches.iloc[0]
        values = [_blank_if_missing(row.get(field['name'], "")) for field in fields]
        paper_page = _blank_if_missing(row.get('paper_page', ""))
        if not paper_page:
            # paper_page missing or empty
            return _panel_updates(
                gr.update(value="Paper page not found. You can create one.", visible=True),
                values,
                show_create_pr=True,
                show_create_page=access_token_exists,
            )
        # paper_page exists
        return _panel_updates(
            gr.update(value="", visible=False),
            values,
            show_create_pr=True,
            show_create_page=False,
        )

    arxiv_id_button.click(
        fn=check_login_and_handle_arxiv_id,
        inputs=[arxiv_id_input],
        outputs=[message] + field_components + [create_pr_button,
                                                create_paper_page_button,
                                                pr_message],
        api_name=False
    )

    # Function to create PR
    def create_pr(arxiv_id,
                  github,
                  conference_name,
                  type_,
                  proceedings,
                  project_page,
                  oauth_token: Optional[gr.OAuthToken] = None):
        if oauth_token is None:
            return gr.update(value="Please log in first.", visible=True)

        arxiv_id = (arxiv_id or "").strip()
        if not arxiv_id:
            return gr.update(value="Please enter an arXiv ID.", visible=True)

        new_entry = {
            'arxiv_id': arxiv_id,
            'github': github,
            'conference_name': conference_name,
            'type': type_,
            'proceedings': proceedings,  # previously collected but never saved
            'project_page': project_page
        }
        # Now add this to the dataset and create a PR
        pr_url = create_pr_in_hf_dataset(new_entry, oauth_token)
        # The helper signals failure with a message starting with "Error";
        # show it as-is rather than claiming a PR was created.
        if str(pr_url).startswith("Error"):
            return gr.update(value=pr_url, visible=True)
        return gr.update(value=f"PR created: {pr_url}", visible=True)

    create_pr_button.click(
        fn=lambda: gr.update(visible=True),  # Show loading message
        inputs=[],
        outputs=[loading_message],
        api_name=False
    ).then(
        fn=create_pr,
        inputs=[arxiv_id_input] + field_components,
        outputs=[pr_message],
        api_name=False
    ).then(
        fn=lambda: gr.update(visible=False),  # Hide loading message
        inputs=[],
        outputs=[loading_message],
        api_name=False
    )

    # Function to create paper page
    def create_paper_page(arxiv_id):
        INDEX_URL = "https://huggingface.co/api/papers/index"
        SUBMIT_URL = "https://huggingface.co/api/papers/submit"
        REQUEST_TIMEOUT = 30  # seconds; keeps a stalled request from hanging the worker

        ACCESS_TOKEN = os.getenv('paper_space_pr_token')
        if not ACCESS_TOKEN:
            return gr.update(value="Server error: Access token not found.", visible=True)

        arxiv_id = (arxiv_id or "").strip()
        if not arxiv_id:
            return gr.update(value="Please enter an arXiv ID.", visible=True)

        headers = {
            "Authorization": f"Bearer {ACCESS_TOKEN}",
            "Content-Type": "application/json"
        }

        try:
            # Index the paper
            response_index = requests.post(
                INDEX_URL, json={"arxivId": arxiv_id}, headers=headers, timeout=REQUEST_TIMEOUT)
            if response_index.status_code != 200:
                return gr.update(
                    value=f"Failed to index paper: {response_index.status_code}, {response_index.text}",
                    visible=True)

            # Successfully indexed, now submit the paper
            payload_submit = {
                "paperId": arxiv_id,  # Assuming paperId is the same as arxivId
                "comment": "",
                "mediaUrls": []
            }
            response_submit = requests.post(
                SUBMIT_URL, json=payload_submit, headers=headers, timeout=REQUEST_TIMEOUT)
        except requests.RequestException as e:
            return gr.update(value=f"Failed to reach the papers API: {e}", visible=True)

        if response_submit.status_code != 200:
            return gr.update(
                value=f"Failed to submit paper: {response_submit.status_code}, {response_submit.text}",
                visible=True)
        return gr.update(value="Paper page created successfully.", visible=True)

    create_paper_page_button.click(
        fn=create_paper_page,
        inputs=[arxiv_id_input],
        outputs=[message],
        api_name=False
    )