import pandas as pd
import ast
import gradio as gr
from transformers import pipeline
import nltk
from retrieval import retrieve_from_pdf
from llm_retrieval import retrieve_from_pdf_llm, retrieve_from_pdf_llm_fast
import os
import json
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import spaces
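# `spaces` provides the @spaces.GPU decorator for ZeroGPU hardware on HF Spaces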
def is_running_in_hf_spaces():
"""
Detects if app is running in Hugging Face Spaces
"""
return "SPACE_ID" in os.environ
if gr.NO_RELOAD:
    # Download punkt_tab here: "Resource punkt_tab not found" errors occurred during application startup on HF Spaces
nltk.download("punkt_tab")
# Keep track of the model name in a global variable so correct model is shown after page refresh
# https://github.com/gradio-app/gradio/issues/3173
MODEL_NAME = "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint"
pipe = pipeline(
"text-classification",
model=MODEL_NAME,
)
# Setup user feedback file for uploading to HF dataset
# https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
# https://huggingface.co/docs/huggingface_hub/v0.16.3/en/guides/upload#scheduled-uploads
USER_FEEDBACK_DIR = Path("user_feedback")
USER_FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
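# Give each app instance its own feedback file so concurrent instances don't write to the same file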
USER_FEEDBACK_PATH = USER_FEEDBACK_DIR / f"train-{uuid4()}.json"
if is_running_in_hf_spaces():
from huggingface_hub import CommitScheduler
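    # CommitScheduler uploads the feedback folder to the dataset repo on a schedule (every 5 minutes by default)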
scheduler = CommitScheduler(
repo_id="AI4citations-feedback",
repo_type="dataset",
folder_path=USER_FEEDBACK_DIR,
path_in_repo="data",
)
def prediction_to_df(prediction=None):
"""
Convert prediction text to DataFrame for barplot
"""
if prediction is None or prediction == "":
# Show an empty plot for app initialization or auto-reload
prediction = {"SUPPORT": 0, "NEI": 0, "REFUTE": 0}
elif "Model" in prediction:
# Show full-height bars when the model is changed
prediction = {"SUPPORT": 1, "NEI": 1, "REFUTE": 1}
else:
        # Safely parse the predictions text (a dict repr) into a dictionary
        prediction = ast.literal_eval(prediction)
# Use custom order for labels (pipe() returns labels in descending order of softmax score)
labels = ["SUPPORT", "NEI", "REFUTE"]
prediction = {k: prediction[k] for k in labels}
# Convert dictionary to DataFrame with one column (Probability)
df = pd.DataFrame.from_dict(prediction, orient="index", columns=["Probability"])
# Move the index to the Class column
return df.reset_index(names="Class")
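# e.g. prediction_to_df("{'SUPPORT': 0.9, 'NEI': 0.07, 'REFUTE': 0.03}") returns a DataFrame
# with Class = SUPPORT/NEI/REFUTE and Probability = 0.9/0.07/0.03, ready for gr.BarPlot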
# Setup theme without background image
my_theme = gr.Theme.from_hub("NoCrypt/miku")
my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000")
# Custom CSS to center content
custom_css = """
.center-content {
text-align: center;
display:block;
}
"""
# HTML for the page head (placeholder for Font Awesome icons; currently empty)
font_awesome_html = ""
# Gradio interface setup
with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
# Layout
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
gr.Markdown("# AI4citations")
gr.Markdown(
"## *AI-powered citation verification* ([more info](https://github.com/jedick/AI4citations))"
)
claim = gr.Textbox(
label="Claim",
info="aka hypothesis",
placeholder="Input claim",
)
with gr.Row():
with gr.Column(scale=2):
with gr.Accordion("Get Evidence from PDF"):
pdf_file = gr.File(
label="Upload PDF", type="filepath", height=120
)
with gr.Row():
retrieval_method = gr.Radio(
choices=["BM25S", "LLM (Large)", "LLM (Fast)"],
value="BM25S",
label="Retrieval Method",
info="Choose between keyword-based (BM25S) or AI-based (LLM) evidence retrieval",
)
get_evidence = gr.Button(value="Get Evidence")
top_k = gr.Slider(
1,
10,
value=5,
step=1,
interactive=True,
label="Top k sentences",
)
with gr.Column(scale=3):
evidence = gr.TextArea(
label="Evidence",
info="aka premise",
placeholder="Input evidence or use Get Evidence from PDF",
)
submit = gr.Button("3. Submit", visible=False)
with gr.Column(scale=2):
# Keep the prediction textbox hidden
with gr.Accordion(visible=False):
prediction = gr.Textbox(label="Prediction")
barplot = gr.BarPlot(
prediction_to_df,
x="Class",
y="Probability",
color="Class",
color_map={"SUPPORT": "green", "NEI": "#888888", "REFUTE": "#FF8888"},
inputs=prediction,
                y_lim=[0, 1],
visible=False,
)
label = gr.Label(label="Prediction")
with gr.Accordion("Feedback"):
                gr.Markdown(
                    "*Provide the correct label to help improve this app.* **NOTE:** The claim and evidence will also be saved."
                )
with gr.Row():
flag_support = gr.Button("Support")
flag_nei = gr.Button("NEI")
flag_refute = gr.Button("Refute")
                gr.Markdown(
                    "Feedback is uploaded every 5 minutes to [AI4citations-feedback](https://huggingface.co/datasets/jedick/AI4citations-feedback)"
                )
with gr.Accordion("Examples"):
gr.Markdown("*Examples are run when clicked*"),
with gr.Row():
support_example = gr.Examples(
examples="examples/Support",
label="Support",
inputs=[claim, evidence],
example_labels=pd.read_csv("examples/Support/log.csv")[
"label"
].tolist(),
)
nei_example = gr.Examples(
examples="examples/NEI",
label="NEI",
inputs=[claim, evidence],
example_labels=pd.read_csv("examples/NEI/log.csv")[
"label"
].tolist(),
)
refute_example = gr.Examples(
examples="examples/Refute",
label="Refute",
inputs=[claim, evidence],
example_labels=pd.read_csv("examples/Refute/log.csv")[
"label"
].tolist(),
)
retrieval_example = gr.Examples(
examples="examples/retrieval",
label="Get Evidence from PDF",
inputs=[pdf_file, claim],
example_labels=pd.read_csv("examples/retrieval/log.csv")[
"label"
].tolist(),
)
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(
"""
### Usage:
- Input a **Claim**, then:
- Upload a PDF, select retrieval method, and click **Get Evidence** OR
- Input **Evidence** statements yourself
"""
)
with gr.Column(scale=2):
gr.Markdown(
"""
### To make the prediction:
- Hit 'Enter' in the **Claim** text box OR
- Hit 'Shift-Enter' in the **Evidence** text box
_The prediction is also made after clicking **Get Evidence**_
"""
)
with gr.Column(scale=2):
with gr.Accordion("Settings", open=False):
# Create dropdown menu to select the model
model = gr.Dropdown(
choices=[
# TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline?
# (num_labels is available in AutoModelForSequenceClassification.from_pretrained)
# "bert-base-uncased",
"MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
"jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint",
],
value=MODEL_NAME,
label="Model",
)
radio = gr.Radio(
["label", "barplot"], value="label", label="Prediction"
)
with gr.Accordion("Sources", open=False, elem_classes=["center_content"]):
gr.Markdown(
"""
#### *Capstone project*
- [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
- [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
#### *Claim Verification Models (text classification)*
- [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
- [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
#### *Evidence Retrieval Models (question answering)*
- [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (Large)
- [distilbert-base-cased-distilled-squad](https://huggingface.co/distilbert/distilbert-base-cased-distilled-squad) (Fast)
#### *Datasets for fine-tuning*
- [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
- [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
#### *Other sources*
- [xhluca/bm25s](https://github.com/xhluca/bm25s) (evidence retrieval)
- [Medicine](https://doi.org/10.1371/journal.pmed.0030197), [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
- [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
- [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
"""
)
# Functions
@spaces.GPU(duration=10)
def query_model(claim, evidence):
"""
Get prediction for a claim and evidence pair
"""
prediction = {
# Send a dictionary containing {"text", "text_pair"} keys; use top_k=3 to get results for all classes
# https://huggingface.co/docs/transformers/v4.51.3/en/main_classes/pipelines#transformers.TextClassificationPipeline.__call__.inputs
# Put evidence before claim
# https://github.com/jedick/MLE-capstone-project
# Output {label: confidence} dictionary format as expected by gr.Label()
# https://github.com/gradio-app/gradio/issues/11170
d["label"]: d["score"]
for d in pipe({"text": evidence, "text_pair": claim}, top_k=3)
}
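        # prediction maps each raw model label to its score,
        # e.g. {"entailment": 0.95, "neutral": 0.03, "contradiction": 0.02} for the base model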
# Rename dictionary keys to use consistent labels across models
prediction = {
("SUPPORT" if k in ["SUPPORT", "entailment"] else k): v
for k, v in prediction.items()
}
prediction = {
("NEI" if k in ["NEI", "neutral"] else k): v for k, v in prediction.items()
}
prediction = {
("REFUTE" if k in ["REFUTE", "contradiction"] else k): v
for k, v in prediction.items()
}
# Return two instances of the prediction to send to different Gradio components
return prediction, prediction
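    # e.g. query_model("CRISPR edits genomes", "CRISPR is a genome editing tool")
    # returns the same {label: probability} dict for the hidden prediction textbox and the label/barplot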
def select_model(model_name):
"""
Select the specified model
"""
global pipe, MODEL_NAME
MODEL_NAME = model_name
pipe = pipeline(
"text-classification",
model=MODEL_NAME,
)
    def change_visualization(choice):
        """
        Show the prediction as a barplot or a label
        """
if choice == "barplot":
barplot = gr.update(visible=True)
label = gr.update(visible=False)
elif choice == "label":
barplot = gr.update(visible=False)
label = gr.update(visible=True)
return barplot, label
# From gradio/client/python/gradio_client/utils.py
def is_http_url_like(possible_url) -> bool:
"""
Check if the given value is a string that looks like an HTTP(S) URL.
"""
if not isinstance(possible_url, str):
return False
return possible_url.startswith(("http://", "https://"))
    def select_example(value, evt: gr.EventData):
        """
        Get the claim and evidence from the event data.
        """
        claim, evidence = value[1]
        return claim, evidence
def select_retrieval_example(value, evt: gr.EventData):
"""
Get the PDF file and claim from the event data.
"""
pdf_file, claim = value[1]
# Add the directory path
if not is_http_url_like(pdf_file):
pdf_file = f"examples/retrieval/{pdf_file}"
return pdf_file, claim
def retrieve_evidence_with_method(pdf_file, claim, top_k, method):
"""
Retrieve evidence using the selected method
"""
if method == "BM25S":
return retrieve_from_pdf(pdf_file, claim, k=top_k)
elif method == "LLM (Large)":
return retrieve_from_pdf_llm(pdf_file, claim, k=top_k)
elif method == "LLM (Fast)":
return retrieve_from_pdf_llm_fast(pdf_file, claim, k=top_k)
else:
return f"Unknown retrieval method: {method}"
def append_feedback(
claim: str, evidence: str, model: str, label: str, user_label: str
) -> None:
"""
Append input/outputs and user feedback to a JSON Lines file.
"""
# Get the first label (prediction with highest probability)
prediction = next(iter(label))
with USER_FEEDBACK_PATH.open("a") as f:
f.write(
json.dumps(
{
"claim": claim,
"evidence": evidence,
"model": model,
"prediction": prediction,
"user_label": user_label,
"datetime": datetime.now().isoformat(),
}
)
)
f.write("\n")
gr.Success(f"Saved your feedback: {user_label}", duration=2, title="Thank you!")
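    # One thin wrapper per feedback button, so each click handler saves a fixed user label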
def save_feedback_support(*args) -> None:
"""
Save user feedback: Support
"""
if is_running_in_hf_spaces():
# Use a thread lock to avoid concurrent writes from different users.
with scheduler.lock:
append_feedback(*args, user_label="SUPPORT")
else:
append_feedback(*args, user_label="SUPPORT")
def save_feedback_nei(*args) -> None:
"""
Save user feedback: NEI
"""
if is_running_in_hf_spaces():
# Use a thread lock to avoid concurrent writes from different users.
with scheduler.lock:
append_feedback(*args, user_label="NEI")
else:
append_feedback(*args, user_label="NEI")
def save_feedback_refute(*args) -> None:
"""
Save user feedback: Refute
"""
if is_running_in_hf_spaces():
# Use a thread lock to avoid concurrent writes from different users.
with scheduler.lock:
append_feedback(*args, user_label="REFUTE")
else:
append_feedback(*args, user_label="REFUTE")
# Event listeners
# Click the submit button or press Enter to submit
gr.on(
triggers=[claim.submit, evidence.submit, submit.click],
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
)
# Get evidence from PDF and run the model
gr.on(
triggers=[get_evidence.click],
fn=retrieve_evidence_with_method,
inputs=[pdf_file, claim, top_k, retrieval_method],
outputs=evidence,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
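    # The .then() chain runs the model only after the evidence textbox has been updated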
# Handle "Support" examples
gr.on(
triggers=[support_example.dataset.select],
fn=select_example,
inputs=support_example.dataset,
outputs=[claim, evidence],
api_name=False,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
# Handle "NEI" examples
gr.on(
triggers=[nei_example.dataset.select],
fn=select_example,
inputs=nei_example.dataset,
outputs=[claim, evidence],
api_name=False,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
# Handle "Refute" examples
gr.on(
triggers=[refute_example.dataset.select],
fn=select_example,
inputs=refute_example.dataset,
outputs=[claim, evidence],
api_name=False,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
# Handle evidence retrieval examples: get evidence from PDF and run the model
gr.on(
triggers=[retrieval_example.dataset.select],
fn=select_retrieval_example,
inputs=retrieval_example.dataset,
outputs=[pdf_file, claim],
api_name=False,
).then(
fn=retrieve_evidence_with_method,
inputs=[pdf_file, claim, top_k, retrieval_method],
outputs=evidence,
api_name=False,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
# Change visualization
radio.change(
fn=change_visualization,
inputs=radio,
outputs=[barplot, label],
api_name=False,
)
# Clear the previous predictions when the model is changed
gr.on(
triggers=[model.select],
fn=lambda: "Model changed! Waiting for updated predictions...",
outputs=[prediction],
api_name=False,
)
    # Change the model, then update the predictions
model.change(
fn=select_model,
inputs=model,
).then(
fn=query_model,
inputs=[claim, evidence],
outputs=[prediction, label],
api_name=False,
)
# Log user feedback when button is clicked
flag_support.click(
fn=save_feedback_support,
inputs=[claim, evidence, model, label],
outputs=None,
)
flag_nei.click(
fn=save_feedback_nei,
inputs=[claim, evidence, model, label],
outputs=None,
)
flag_refute.click(
fn=save_feedback_refute,
inputs=[claim, evidence, model, label],
outputs=None,
)
if __name__ == "__main__":
# allowed_paths is needed to upload PDFs from specific example directory
demo.launch(allowed_paths=[f"{os.getcwd()}/examples/retrieval"])