File size: 19,801 Bytes
a95b710
 
908a00f
 
0ae0ade
 
 
908a00f
9d59e2b
 
 
 
ef0d090
9d59e2b
 
 
 
 
 
 
 
a95b710
 
9b489f6
 
a95b710
 
 
 
 
 
 
 
 
9d59e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a95b710
 
 
 
 
13753a4
 
 
a95b710
73bbe4a
a95b710
 
 
 
 
13753a4
5cdd81a
 
 
a95b710
9d59e2b
a95b710
13753a4
a95b710
 
9d59e2b
 
 
 
 
5cdd81a
 
0ae0ade
5cdd81a
 
0ae0ade
5cdd81a
9d59e2b
 
 
 
 
 
 
 
 
 
 
 
 
a95b710
0ae0ade
 
 
 
 
7879fc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a95b710
 
7879fc7
9d59e2b
 
00c763e
9d59e2b
 
 
 
 
00c763e
 
 
13753a4
908a00f
a95b710
908a00f
 
 
 
 
 
 
 
 
 
 
a95b710
908a00f
a95b710
 
 
908a00f
 
 
a95b710
908a00f
a95b710
 
 
908a00f
a95b710
13753a4
908a00f
a95b710
 
 
 
7879fc7
 
 
 
 
 
 
 
 
 
 
 
 
a95b710
908a00f
 
ef0d090
908a00f
 
 
 
 
 
 
 
 
 
 
 
 
 
feb987c
 
00c763e
feb987c
 
 
 
 
 
00c763e
feb987c
 
7879fc7
908a00f
13753a4
908a00f
13753a4
908a00f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59eeb53
0ae0ade
 
 
 
 
 
 
5cdd81a
 
 
 
0ae0ade
 
 
 
 
 
5cdd81a
 
 
9d59e2b
7879fc7
9d59e2b
 
 
 
6444d2c
7879fc7
9d59e2b
 
 
 
 
 
 
7879fc7
9d59e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00c763e
9d59e2b
00c763e
9d59e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00c763e
9d59e2b
00c763e
9d59e2b
0ae0ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a95b710
 
0ae0ade
a95b710
0ae0ade
a95b710
 
7879fc7
a95b710
 
908a00f
a95b710
908a00f
0ae0ade
5cdd81a
0ae0ade
908a00f
 
 
7879fc7
a95b710
 
 
908a00f
 
 
 
 
 
 
 
 
 
7879fc7
908a00f
 
 
 
 
 
 
 
 
 
 
 
 
7879fc7
908a00f
 
 
 
 
 
 
 
 
 
a95b710
 
 
7879fc7
a95b710
 
 
908a00f
a95b710
908a00f
 
 
 
 
 
0ae0ade
5cdd81a
0ae0ade
908a00f
 
 
 
7879fc7
908a00f
 
 
7879fc7
9d59e2b
13753a4
9d59e2b
908a00f
 
 
7879fc7
908a00f
 
 
9d59e2b
 
 
7879fc7
9d59e2b
0ae0ade
9d59e2b
 
 
7879fc7
9d59e2b
0ae0ade
9d59e2b
 
 
7879fc7
9d59e2b
0ae0ade
 
 
 
 
 
 
 
 
9d59e2b
0ae0ade
9d59e2b
908a00f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
import pandas as pd
import gradio as gr
from transformers import pipeline
import nltk
from retrieval_bm25s import retrieve_with_bm25s
from retrieval_bert import retrieve_with_deberta
from retrieval_gpt import retrieve_with_gpt
import os
import json
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import spaces


def is_running_in_hf_spaces():
    """
    Detect whether the app is running inside Hugging Face Spaces.

    Spaces sets the SPACE_ID environment variable, so its presence is used
    as the indicator.
    """
    return os.environ.get("SPACE_ID") is not None


if gr.NO_RELOAD:
    # One-time process initialization: this block is skipped during Gradio's
    # auto-reload cycles so models and schedulers are not re-created on reload.

    # Resource punkt_tab not found during application startup on HF spaces
    nltk.download("punkt_tab")

    # Keep track of the model name in a global variable so correct model is shown after page refresh
    # https://github.com/gradio-app/gradio/issues/3173
    MODEL_NAME = "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint"
    pipe = pipeline(
        "text-classification",
        model=MODEL_NAME,
    )

    # Setup user feedback file for uploading to HF dataset
    # https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
    # https://huggingface.co/docs/huggingface_hub/v0.16.3/en/guides/upload#scheduled-uploads
    USER_FEEDBACK_DIR = Path("user_feedback")
    USER_FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
    # Unique filename per process so concurrent workers don't clobber each other
    USER_FEEDBACK_PATH = USER_FEEDBACK_DIR / f"train-{uuid4()}.json"

    if is_running_in_hf_spaces():
        from huggingface_hub import CommitScheduler

        # Periodically commits the feedback folder to the HF dataset repo
        scheduler = CommitScheduler(
            repo_id="AI4citations-feedback",
            repo_type="dataset",
            folder_path=USER_FEEDBACK_DIR,
            path_in_repo="data",
        )


# Setup theme without background image
my_theme = gr.Theme.from_hub("NoCrypt/miku")
# Replace the theme's default background with plain white (light) / black (dark)
my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000")

# Define the HTML for Font Awesome (icons used in the Sources accordion markdown)
font_awesome_html = '<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css" rel="stylesheet">'

# Gradio interface setup
# Gradio interface setup
with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:

    # Layout
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                gr.Markdown("# AI4citations")
                gr.Markdown(
                    "## *AI-powered citation verification* ([more info](https://github.com/jedick/AI4citations))"
                )
            # Claim input (the hypothesis in NLI terms)
            claim = gr.Textbox(
                label="Claim",
                info="aka hypothesis",
                placeholder="Input claim",
            )
            with gr.Row():
                with gr.Column(scale=2):
                    # Evidence-retrieval controls: PDF upload, method choice, top-k
                    with gr.Accordion("Get Evidence from PDF"):
                        pdf_file = gr.File(
                            label="Upload PDF", type="filepath", height=120
                        )
                        with gr.Row():
                            retrieval_method = gr.Radio(
                                choices=["BM25S", "DeBERTa", "GPT"],
                                value="BM25S",
                                label="Retrieval Method",
                                info="Keyword search (BM25S) or AI (DeBERTa, GPT)",
                            )
                        get_evidence = gr.Button(value="Get Evidence")
                        # Hidden when GPT retrieval is selected (see slider_visible)
                        top_k = gr.Slider(
                            1,
                            10,
                            value=5,
                            step=1,
                            label="Top k sentences",
                        )
                with gr.Column(scale=3):
                    # Evidence input (the premise in NLI terms)
                    evidence = gr.TextArea(
                        label="Evidence",
                        info="aka premise",
                        placeholder="Input evidence or use Get Evidence from PDF",
                    )
                    with gr.Row():
                        # Token counters shown only for GPT retrieval (see number_visible)
                        prompt_tokens = gr.Number(label="Prompt tokens", visible=False)
                        completion_tokens = gr.Number(
                            label="Completion tokens", visible=False
                        )
                    gr.Markdown(
                        """
                    ### App Usage:

                    - Input a **Claim**, then:
                        - Upload a PDF and click **Get Evidence** OR
                        - Input **Evidence** statements yourself
                    - Make the **Prediction**:
                        - Hit 'Enter' in the **Claim** text box OR
                        - Hit 'Shift-Enter' in the **Evidence** text box OR
                        - Click **Get Evidence**
                    """
                    )
            with gr.Accordion("Sources", open=False):
                gr.Markdown(
                    """
                #### *Capstone project*
                - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
                - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
                #### *Text Classification*
                - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
                - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
                #### *Evidence Retrieval*
                - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
                - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
                - <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
                #### *Datasets for fine-tuning*
                - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
                - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
                #### *Other sources*
                - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
                - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
                - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
                """
                )

        with gr.Column(scale=2):
            # Prediction output and user-feedback buttons
            prediction = gr.Label(label="Prediction")
            with gr.Accordion("Feedback"):
                gr.Markdown(
                    "*Provide the correct label to help improve this app*<br>**NOTE:** The claim and evidence will also be saved"
                ),
                with gr.Row():
                    flag_support = gr.Button("Support")
                    flag_nei = gr.Button("NEI")
                    flag_refute = gr.Button("Refute")
                gr.Markdown(
                    "Feedback is uploaded every 5 minutes to [AI4citations-feedback](https://huggingface.co/datasets/jedick/AI4citations-feedback)"
                ),
            # Clickable example tables; labels come from each example dir's log.csv
            with gr.Accordion("Examples"):
                gr.Markdown("*Examples are run when clicked*"),
                with gr.Row():
                    support_example = gr.Examples(
                        examples="examples/Support",
                        label="Support",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/Support/log.csv")[
                            "label"
                        ].tolist(),
                    )
                    nei_example = gr.Examples(
                        examples="examples/NEI",
                        label="NEI",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/NEI/log.csv")[
                            "label"
                        ].tolist(),
                    )
                    refute_example = gr.Examples(
                        examples="examples/Refute",
                        label="Refute",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/Refute/log.csv")[
                            "label"
                        ].tolist(),
                    )
                retrieval_example = gr.Examples(
                    examples="examples/retrieval",
                    label="Get Evidence from PDF",
                    inputs=[pdf_file, claim],
                    example_labels=pd.read_csv("examples/retrieval/log.csv")[
                        "label"
                    ].tolist(),
                )
            # Create dropdown menu to select the model
            model = gr.Dropdown(
                choices=[
                    # TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline?
                    # (num_labels is available in AutoModelForSequenceClassification.from_pretrained)
                    # "bert-base-uncased",
                    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
                    "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint",
                ],
                value=MODEL_NAME,
                label="Model",
                info="Text classification model used for claim verification",
            )

    # Functions

    @spaces.GPU(duration=10)
    def query_model(claim, evidence):
        """
        Classify a claim/evidence pair and return a {label: confidence} dict.
        """
        # Send a dictionary containing {"text", "text_pair"} keys; use top_k=3 to get results for all classes
        #   https://huggingface.co/docs/transformers/v4.51.3/en/main_classes/pipelines#transformers.TextClassificationPipeline.__call__.inputs
        # Put evidence before claim
        #   https://github.com/jedick/MLE-capstone-project
        results = pipe({"text": evidence, "text_pair": claim}, top_k=3)
        # Map model-specific label names to the labels used consistently across
        # the app; labels already in canonical form pass through unchanged
        canonical = {
            "entailment": "SUPPORT",
            "neutral": "NEI",
            "contradiction": "REFUTE",
        }
        # Output {label: confidence} dictionary format as expected by gr.Label()
        #   https://github.com/gradio-app/gradio/issues/11170
        return {canonical.get(d["label"], d["label"]): d["score"] for d in results}

    def select_model(model_name):
        """
        Load the chosen text-classification model and make it the active one.
        """
        # Rebind the module-level pipeline so every handler uses the new model;
        # MODEL_NAME is also updated so the dropdown shows the right model
        # after a page refresh
        global pipe, MODEL_NAME
        MODEL_NAME = model_name
        pipe = pipeline("text-classification", model=model_name)

    # From gradio/client/python/gradio_client/utils.py
    def is_http_url_like(possible_url) -> bool:
        """
        Return True if the value is a string beginning with an HTTP(S) scheme.
        """
        return isinstance(possible_url, str) and possible_url.startswith(
            ("http://", "https://")
        )

    def select_example(value, evt: gr.EventData):
        """
        Get the claim and evidence from the selected example row.
        """
        # value[1] holds the example's [claim, evidence] pair
        claim, evidence = value[1]
        return claim, evidence

    def select_retrieval_example(value, evt: gr.EventData):
        """
        Get the PDF file and claim from the event data.
        """
        # value[1] holds the example's [pdf_file, claim] pair
        pdf_file, claim = value[1]
        # Add the directory path: local example files live under
        # examples/retrieval; URLs are passed through unchanged
        if not is_http_url_like(pdf_file):
            pdf_file = f"examples/retrieval/{pdf_file}"
        return pdf_file, claim

    @spaces.GPU()
    def _retrieve_with_deberta(pdf_file, claim, top_k):
        """
        Retrieve evidence using DeBERTa.

        Thin wrapper so the call is decorated with @spaces.GPU() and runs
        with GPU access on Hugging Face Spaces.
        """
        return retrieve_with_deberta(pdf_file, claim, top_k)

    def retrieve_evidence(pdf_file, claim, top_k, method):
        """
        Retrieve evidence from a PDF using the selected method.

        Always returns a 3-tuple (evidence_text, prompt_tokens,
        completion_tokens) to match the three Gradio outputs wired to this
        function; token counts are 0 for methods that don't use the GPT API.
        """
        if method == "BM25S":
            # Append 0 for number of prompt and completion tokens
            return retrieve_with_bm25s(pdf_file, claim, top_k), 0, 0
        elif method == "DeBERTa":
            return _retrieve_with_deberta(pdf_file, claim, top_k), 0, 0
        elif method == "GPT":
            # retrieve_with_gpt returns (evidence, prompt_tokens, completion_tokens)
            return retrieve_with_gpt(pdf_file, claim)
        else:
            # Bug fix: previously returned a bare string here, which does not
            # unpack into the three outputs [evidence, prompt_tokens,
            # completion_tokens] and would break the event handler
            return f"Unknown retrieval method: {method}", 0, 0

    def append_feedback(
        claim: str, evidence: str, model: str, prediction: str, user_label: str
    ) -> None:
        """
        Append input/outputs and user feedback to a JSON Lines file.
        """
        # The first key is the label with the highest probability
        top_label = next(iter(prediction))
        record = {
            "claim": claim,
            "evidence": evidence,
            "model": model,
            "prediction": top_label,
            "user_label": user_label,
            "datetime": datetime.now().isoformat(),
        }
        # One JSON object per line (JSON Lines format)
        with USER_FEEDBACK_PATH.open("a") as f:
            f.write(json.dumps(record))
            f.write("\n")
        gr.Success(f"Saved your feedback: {user_label}", duration=2, title="Thank you!")

    def _save_feedback(*args, user_label: str) -> None:
        """
        Save user feedback with the given label, shared by the three buttons.
        """
        if is_running_in_hf_spaces():
            # Use a thread lock to avoid concurrent writes from different users.
            with scheduler.lock:
                append_feedback(*args, user_label=user_label)
        else:
            append_feedback(*args, user_label=user_label)

    def save_feedback_support(*args) -> None:
        """
        Save user feedback: Support
        """
        _save_feedback(*args, user_label="SUPPORT")

    def save_feedback_nei(*args) -> None:
        """
        Save user feedback: NEI
        """
        _save_feedback(*args, user_label="NEI")

    def save_feedback_refute(*args) -> None:
        """
        Save user feedback: Refute
        """
        _save_feedback(*args, user_label="REFUTE")

    def number_visible(value):
        """
        Show number components (token counts) only when GPT retrieval is selected.
        """
        return gr.Number(visible=value == "GPT")

    def slider_visible(value):
        """
        Hide the top_k slider when GPT retrieval is selected.
        """
        return gr.Slider(visible=value != "GPT")

    # Event listeners

    # Press Enter or Shift-Enter to submit
    gr.on(
        triggers=[claim.submit, evidence.submit],
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
    )

    # Get evidence from PDF and run the model
    gr.on(
        triggers=[get_evidence.click],
        fn=retrieve_evidence,
        inputs=[pdf_file, claim, top_k, retrieval_method],
        outputs=[evidence, prompt_tokens, completion_tokens],
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Handle "Support" examples: fill in the claim/evidence then predict
    gr.on(
        triggers=[support_example.dataset.select],
        fn=select_example,
        inputs=support_example.dataset,
        outputs=[claim, evidence],
        api_name=False,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Handle "NEI" examples
    gr.on(
        triggers=[nei_example.dataset.select],
        fn=select_example,
        inputs=nei_example.dataset,
        outputs=[claim, evidence],
        api_name=False,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Handle "Refute" examples
    gr.on(
        triggers=[refute_example.dataset.select],
        fn=select_example,
        inputs=refute_example.dataset,
        outputs=[claim, evidence],
        api_name=False,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Handle evidence retrieval examples: get evidence from PDF and run the model
    gr.on(
        triggers=[retrieval_example.dataset.select],
        fn=select_retrieval_example,
        inputs=retrieval_example.dataset,
        outputs=[pdf_file, claim],
        api_name=False,
    ).then(
        fn=retrieve_evidence,
        inputs=[pdf_file, claim, top_k, retrieval_method],
        outputs=[evidence, prompt_tokens, completion_tokens],
        api_name=False,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Change the model then update the predictions
    model.change(
        fn=select_model,
        inputs=model,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

    # Log user feedback when button is clicked
    flag_support.click(
        fn=save_feedback_support,
        inputs=[claim, evidence, model, prediction],
        outputs=None,
        api_name=False,
    )
    flag_nei.click(
        fn=save_feedback_nei,
        inputs=[claim, evidence, model, prediction],
        outputs=None,
        api_name=False,
    )
    flag_refute.click(
        fn=save_feedback_refute,
        inputs=[claim, evidence, model, prediction],
        outputs=None,
        api_name=False,
    )

    # Change visibility of top-k slider and token counts if GPT is selected for retrieval
    retrieval_method.change(
        number_visible, retrieval_method, prompt_tokens, api_name=False
    )
    retrieval_method.change(
        number_visible, retrieval_method, completion_tokens, api_name=False
    )
    retrieval_method.change(slider_visible, retrieval_method, top_k, api_name=False)


if __name__ == "__main__":
    # Run the Gradio app directly (not under `gradio` CLI reload).
    # allowed_paths is needed to upload PDFs from specific example directory
    demo.launch(allowed_paths=[f"{os.getcwd()}/examples/retrieval"])