import pandas as pd import gradio as gr from transformers import pipeline import nltk from retrieval_bm25s import retrieve_with_bm25s from retrieval_bert import retrieve_with_deberta from retrieval_gpt import retrieve_with_gpt import os import json from datetime import datetime from pathlib import Path from uuid import uuid4 import spaces def is_running_in_hf_spaces(): """ Detects if app is running in Hugging Face Spaces """ return "SPACE_ID" in os.environ if gr.NO_RELOAD: # Resource punkt_tab not found during application startup on HF spaces nltk.download("punkt_tab") # Keep track of the model name in a global variable so correct model is shown after page refresh # https://github.com/gradio-app/gradio/issues/3173 MODEL_NAME = "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint" pipe = pipeline( "text-classification", model=MODEL_NAME, ) # Setup user feedback file for uploading to HF dataset # https://huggingface.co/spaces/Wauplin/space_to_dataset_saver # https://huggingface.co/docs/huggingface_hub/v0.16.3/en/guides/upload#scheduled-uploads USER_FEEDBACK_DIR = Path("user_feedback") USER_FEEDBACK_DIR.mkdir(parents=True, exist_ok=True) USER_FEEDBACK_PATH = USER_FEEDBACK_DIR / f"train-{uuid4()}.json" if is_running_in_hf_spaces(): from huggingface_hub import CommitScheduler scheduler = CommitScheduler( repo_id="AI4citations-feedback", repo_type="dataset", folder_path=USER_FEEDBACK_DIR, path_in_repo="data", ) # Setup theme without background image my_theme = gr.Theme.from_hub("NoCrypt/miku") my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000") # Define the HTML for Font Awesome font_awesome_html = '' # Gradio interface setup with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo: # Layout with gr.Row(): with gr.Column(scale=3): with gr.Row(): gr.Markdown("# AI4citations") gr.Markdown( "## *AI-powered citation verification* ([more info](https://github.com/jedick/AI4citations))" ) claim = gr.Textbox( label="Claim", info="aka hypothesis", placeholder="Input claim", ) with gr.Row(): with gr.Column(scale=2): with gr.Accordion("Get Evidence from PDF"): pdf_file = gr.File( label="Upload PDF", type="filepath", height=120 ) with gr.Row(): retrieval_method = gr.Radio( choices=["BM25S", "DeBERTa", "GPT"], value="BM25S", label="Retrieval Method", info="Keyword search (BM25S) or AI (DeBERTa, GPT)", ) get_evidence = gr.Button(value="Get Evidence") top_k = gr.Slider( 1, 10, value=5, step=1, label="Top k sentences", ) with gr.Column(scale=3): evidence = gr.TextArea( label="Evidence", info="aka premise", placeholder="Input evidence or use Get Evidence from PDF", ) with gr.Row(): prompt_tokens = gr.Number(label="Prompt tokens", visible=False) completion_tokens = gr.Number( label="Completion tokens", visible=False ) gr.Markdown( """ ### App Usage: - Input a **Claim**, then: - Upload a PDF and click **Get Evidence** OR - Input **Evidence** statements yourself - Make the **Prediction**: - Hit 'Enter' in the **Claim** text box OR - Hit 'Shift-Enter' in the **Evidence** text box OR - Click **Get Evidence** """ ) with gr.Accordion("Sources", open=False): gr.Markdown( """ #### *Capstone project* - [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo) - [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo) #### *Text Classification* - [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned) - [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base) #### *Evidence Retrieval* - [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S) - [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa) - [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT) #### *Datasets for fine-tuning* - [allenai/SciFact](https://github.com/allenai/scifact) (SciFact) - [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt) #### *Other sources* - [Medicine](https://doi.org/10.1371/journal.pmed.0030197), [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples) - [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example) - [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme) """ ) with gr.Column(scale=2): prediction = gr.Label(label="Prediction") with gr.Accordion("Feedback"): gr.Markdown( "*Provide the correct label to help improve this app*
**NOTE:** The claim and evidence will also be saved" ), with gr.Row(): flag_support = gr.Button("Support") flag_nei = gr.Button("NEI") flag_refute = gr.Button("Refute") gr.Markdown( "Feedback is uploaded every 5 minutes to [AI4citations-feedback](https://huggingface.co/datasets/jedick/AI4citations-feedback)" ), with gr.Accordion("Examples"): gr.Markdown("*Examples are run when clicked*"), with gr.Row(): support_example = gr.Examples( examples="examples/Support", label="Support", inputs=[claim, evidence], example_labels=pd.read_csv("examples/Support/log.csv")[ "label" ].tolist(), ) nei_example = gr.Examples( examples="examples/NEI", label="NEI", inputs=[claim, evidence], example_labels=pd.read_csv("examples/NEI/log.csv")[ "label" ].tolist(), ) refute_example = gr.Examples( examples="examples/Refute", label="Refute", inputs=[claim, evidence], example_labels=pd.read_csv("examples/Refute/log.csv")[ "label" ].tolist(), ) retrieval_example = gr.Examples( examples="examples/retrieval", label="Get Evidence from PDF", inputs=[pdf_file, claim], example_labels=pd.read_csv("examples/retrieval/log.csv")[ "label" ].tolist(), ) # Create dropdown menu to select the model model = gr.Dropdown( choices=[ # TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline? # (num_labels is available in AutoModelForSequenceClassification.from_pretrained) # "bert-base-uncased", "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint", ], value=MODEL_NAME, label="Model", info="Text classification model used for claim verification", ) # Functions @spaces.GPU(duration=10) def query_model(claim, evidence): """ Get prediction for a claim and evidence pair """ prediction = { # Send a dictionary containing {"text", "text_pair"} keys; use top_k=3 to get results for all classes # https://huggingface.co/docs/transformers/v4.51.3/en/main_classes/pipelines#transformers.TextClassificationPipeline.__call__.inputs # Put evidence before claim # https://github.com/jedick/MLE-capstone-project # Output {label: confidence} dictionary format as expected by gr.Label() # https://github.com/gradio-app/gradio/issues/11170 d["label"]: d["score"] for d in pipe({"text": evidence, "text_pair": claim}, top_k=3) } # Rename dictionary keys to use consistent labels across models prediction = { ("SUPPORT" if k in ["SUPPORT", "entailment"] else k): v for k, v in prediction.items() } prediction = { ("NEI" if k in ["NEI", "neutral"] else k): v for k, v in prediction.items() } prediction = { ("REFUTE" if k in ["REFUTE", "contradiction"] else k): v for k, v in prediction.items() } return prediction def select_model(model_name): """ Select the specified model """ global pipe, MODEL_NAME MODEL_NAME = model_name pipe = pipeline( "text-classification", model=MODEL_NAME, ) # From gradio/client/python/gradio_client/utils.py def is_http_url_like(possible_url) -> bool: """ Check if the given value is a string that looks like an HTTP(S) URL. """ if not isinstance(possible_url, str): return False return possible_url.startswith(("http://", "https://")) def select_example(value, evt: gr.EventData): # Get the PDF file and claim from the event data claim, evidence = value[1] # Add the directory path return claim, evidence def select_retrieval_example(value, evt: gr.EventData): """ Get the PDF file and claim from the event data. """ pdf_file, claim = value[1] # Add the directory path if not is_http_url_like(pdf_file): pdf_file = f"examples/retrieval/{pdf_file}" return pdf_file, claim @spaces.GPU() def _retrieve_with_deberta(pdf_file, claim, top_k): """ Retrieve evidence using DeBERTa """ return retrieve_with_deberta(pdf_file, claim, top_k) def retrieve_evidence(pdf_file, claim, top_k, method): """ Retrieve evidence using the selected method """ if method == "BM25S": # Append 0 for number of prompt and completion tokens return retrieve_with_bm25s(pdf_file, claim, top_k), 0, 0 elif method == "DeBERTa": return _retrieve_with_deberta(pdf_file, claim, top_k), 0, 0 elif method == "GPT": return retrieve_with_gpt(pdf_file, claim) else: return f"Unknown retrieval method: {method}" def append_feedback( claim: str, evidence: str, model: str, prediction: str, user_label: str ) -> None: """ Append input/outputs and user feedback to a JSON Lines file. """ # Get the first label (prediction with highest probability) _prediction = next(iter(prediction)) with USER_FEEDBACK_PATH.open("a") as f: f.write( json.dumps( { "claim": claim, "evidence": evidence, "model": model, "prediction": _prediction, "user_label": user_label, "datetime": datetime.now().isoformat(), } ) ) f.write("\n") gr.Success(f"Saved your feedback: {user_label}", duration=2, title="Thank you!") def save_feedback_support(*args) -> None: """ Save user feedback: Support """ if is_running_in_hf_spaces(): # Use a thread lock to avoid concurrent writes from different users. with scheduler.lock: append_feedback(*args, user_label="SUPPORT") else: append_feedback(*args, user_label="SUPPORT") def save_feedback_nei(*args) -> None: """ Save user feedback: NEI """ if is_running_in_hf_spaces(): # Use a thread lock to avoid concurrent writes from different users. with scheduler.lock: append_feedback(*args, user_label="NEI") else: append_feedback(*args, user_label="NEI") def save_feedback_refute(*args) -> None: """ Save user feedback: Refute """ if is_running_in_hf_spaces(): # Use a thread lock to avoid concurrent writes from different users. with scheduler.lock: append_feedback(*args, user_label="REFUTE") else: append_feedback(*args, user_label="REFUTE") def number_visible(value): """ Show numbers (token counts) if GPT is selcted for retrieval """ if value == "GPT": return gr.Number(visible=True) else: return gr.Number(visible=False) def slider_visible(value): """ Hide slider (top_k) if GPT is selcted for retrieval """ if value == "GPT": return gr.Slider(visible=False) else: return gr.Slider(visible=True) # Event listeners # Press Enter or Shift-Enter to submit gr.on( triggers=[claim.submit, evidence.submit], fn=query_model, inputs=[claim, evidence], outputs=prediction, ) # Get evidence from PDF and run the model gr.on( triggers=[get_evidence.click], fn=retrieve_evidence, inputs=[pdf_file, claim, top_k, retrieval_method], outputs=[evidence, prompt_tokens, completion_tokens], ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Handle "Support" examples gr.on( triggers=[support_example.dataset.select], fn=select_example, inputs=support_example.dataset, outputs=[claim, evidence], api_name=False, ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Handle "NEI" examples gr.on( triggers=[nei_example.dataset.select], fn=select_example, inputs=nei_example.dataset, outputs=[claim, evidence], api_name=False, ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Handle "Refute" examples gr.on( triggers=[refute_example.dataset.select], fn=select_example, inputs=refute_example.dataset, outputs=[claim, evidence], api_name=False, ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Handle evidence retrieval examples: get evidence from PDF and run the model gr.on( triggers=[retrieval_example.dataset.select], fn=select_retrieval_example, inputs=retrieval_example.dataset, outputs=[pdf_file, claim], api_name=False, ).then( fn=retrieve_evidence, inputs=[pdf_file, claim, top_k, retrieval_method], outputs=[evidence, prompt_tokens, completion_tokens], api_name=False, ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Change the model then update the predictions model.change( fn=select_model, inputs=model, ).then( fn=query_model, inputs=[claim, evidence], outputs=prediction, api_name=False, ) # Log user feedback when button is clicked flag_support.click( fn=save_feedback_support, inputs=[claim, evidence, model, prediction], outputs=None, api_name=False, ) flag_nei.click( fn=save_feedback_nei, inputs=[claim, evidence, model, prediction], outputs=None, api_name=False, ) flag_refute.click( fn=save_feedback_refute, inputs=[claim, evidence, model, prediction], outputs=None, api_name=False, ) # Change visibility of top-k slider and token counts if GPT is selected for retrieval retrieval_method.change( number_visible, retrieval_method, prompt_tokens, api_name=False ) retrieval_method.change( number_visible, retrieval_method, completion_tokens, api_name=False ) retrieval_method.change(slider_visible, retrieval_method, top_k, api_name=False) if __name__ == "__main__": # allowed_paths is needed to upload PDFs from specific example directory demo.launch(allowed_paths=[f"{os.getcwd()}/examples/retrieval"])