jedick committed
Commit 0ae0ade · 1 Parent(s): adcbc55

Add GPT retrieval
app.py CHANGED
@@ -2,8 +2,9 @@ import pandas as pd
 import gradio as gr
 from transformers import pipeline
 import nltk
-from retrieval import retrieve_from_pdf
-from llm_retrieval import retrieve_from_pdf_llm, retrieve_from_pdf_llm_fast
+from retrieval_bm25s import retrieve_with_bm25s
+from retrieval_bert import retrieve_with_deberta
+from retrieval_gpt import retrieve_with_gpt
 import os
 import json
 from datetime import datetime
@@ -102,10 +103,10 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
             )
             with gr.Row():
                 retrieval_method = gr.Radio(
-                    choices=["BM25S", "LLM (Large)", "LLM (Fast)"],
+                    choices=["BM25S", "DeBERTa", "GPT"],
                     value="BM25S",
                     label="Retrieval Method",
-                    info="Choose between keyword-based (BM25S) or AI-based (LLM) evidence retrieval",
+                    info="Keyword search (BM25S) or AI (DeBERTa, GPT)",
                 )
                 get_evidence = gr.Button(value="Get Evidence")
                 top_k = gr.Slider(
@@ -113,7 +114,6 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
                     10,
                     value=5,
                     step=1,
-                    interactive=True,
                     label="Top k sentences",
                 )
         with gr.Column(scale=3):
@@ -122,7 +122,11 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
                 info="aka premise",
                 placeholder="Input evidence or use Get Evidence from PDF",
             )
-            submit = gr.Button("3. Submit", visible=False)
+            with gr.Row():
+                prompt_tokens = gr.Number(label="Prompt tokens", visible=False)
+                completion_tokens = gr.Number(
+                    label="Completion tokens", visible=False
+                )

         with gr.Column(scale=2):
             # Keep the prediction textbox hidden
@@ -234,17 +238,17 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
 #### *Capstone project*
 - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
 - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
-#### *Claim Verification Models (text classification)*
+#### *Text Classification*
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
-#### *Evidence Retrieval Models (question answering)*
-- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (Large)
-- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [distilbert-base-cased-distilled-squad](https://huggingface.co/distilbert/distilbert-base-cased-distilled-squad) (Fast)
+#### *Evidence Retrieval*
+- <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
+- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
+- <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
 #### *Datasets for fine-tuning*
 - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
 - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
 #### *Other sources*
-- <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (evidence retrieval)
 - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
@@ -329,16 +333,23 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         return pdf_file, claim

     @spaces.GPU()
-    def retrieve_evidence_with_method(pdf_file, claim, top_k, method):
+    def _retrieve_with_deberta(pdf_file, claim, top_k):
+        """
+        Retrieve evidence using DeBERTa
+        """
+        return retrieve_with_deberta(pdf_file, claim, top_k)
+
+    def retrieve_evidence(pdf_file, claim, top_k, method):
         """
         Retrieve evidence using the selected method
         """
         if method == "BM25S":
-            return retrieve_from_pdf(pdf_file, claim, k=top_k)
-        elif method == "LLM (Large)":
-            return retrieve_from_pdf_llm(pdf_file, claim, k=top_k)
-        elif method == "LLM (Fast)":
-            return retrieve_from_pdf_llm_fast(pdf_file, claim, k=top_k)
+            # Append 0 for number of prompt and completion tokens
+            return retrieve_with_bm25s(pdf_file, claim, top_k), 0, 0
+        elif method == "DeBERTa":
+            return _retrieve_with_deberta(pdf_file, claim, top_k), 0, 0
+        elif method == "GPT":
+            return retrieve_with_gpt(pdf_file, claim)
         else:
             return f"Unknown retrieval method: {method}"

@@ -399,11 +410,29 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         else:
             append_feedback(*args, user_label="REFUTE")

+    def number_visible(value):
+        """
+        Show numbers (token counts) if GPT is selected for retrieval
+        """
+        if value == "GPT":
+            return gr.Number(visible=True)
+        else:
+            return gr.Number(visible=False)
+
+    def slider_visible(value):
+        """
+        Hide slider (top_k) if GPT is selected for retrieval
+        """
+        if value == "GPT":
+            return gr.Slider(visible=False)
+        else:
+            return gr.Slider(visible=True)
+
     # Event listeners

-    # Click the submit button or press Enter to submit
+    # Press Enter or Shift-Enter to submit
     gr.on(
-        triggers=[claim.submit, evidence.submit, submit.click],
+        triggers=[claim.submit, evidence.submit],
         fn=query_model,
         inputs=[claim, evidence],
         outputs=[prediction, label],
@@ -412,9 +441,9 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
     # Get evidence from PDF and run the model
     gr.on(
         triggers=[get_evidence.click],
-        fn=retrieve_evidence_with_method,
+        fn=retrieve_evidence,
         inputs=[pdf_file, claim, top_k, retrieval_method],
-        outputs=evidence,
+        outputs=[evidence, prompt_tokens, completion_tokens],
     ).then(
         fn=query_model,
         inputs=[claim, evidence],
@@ -472,9 +501,9 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         outputs=[pdf_file, claim],
         api_name=False,
     ).then(
-        fn=retrieve_evidence_with_method,
+        fn=retrieve_evidence,
         inputs=[pdf_file, claim, top_k, retrieval_method],
-        outputs=evidence,
+        outputs=[evidence, prompt_tokens, completion_tokens],
         api_name=False,
     ).then(
         fn=query_model,
@@ -515,17 +544,29 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         fn=save_feedback_support,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
     )
     flag_nei.click(
         fn=save_feedback_nei,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
     )
     flag_refute.click(
         fn=save_feedback_refute,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
+    )
+
+    # Change visibility of top-k slider and token counts if GPT is selected for retrieval
+    retrieval_method.change(
+        number_visible, retrieval_method, prompt_tokens, api_name=False
     )
+    retrieval_method.change(
+        number_visible, retrieval_method, completion_tokens, api_name=False
+    )
+    retrieval_method.change(slider_visible, retrieval_method, top_k, api_name=False)


 if __name__ == "__main__":
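Note on the new dispatcher: retrieve_evidence returns a uniform (evidence, prompt_tokens, completion_tokens) tuple so a single Gradio event can feed the evidence textbox and both token-count boxes; the BM25S and DeBERTa branches pad with zeros. A minimal sketch of that contract with stub retrievers (the stubs are illustrative, not part of the commit):

    def stub_bm25s(pdf_file, claim, top_k):
        # Keyword-based retrieval returns text only
        return "sentence one. sentence two."

    def stub_gpt(pdf_file, claim):
        # GPT retrieval also reports token usage
        return "verbatim evidence sentences", 1200, 85

    def retrieve_evidence(pdf_file, claim, top_k, method):
        if method == "BM25S":
            # Pad with zeros so the result always unpacks into three outputs
            return stub_bm25s(pdf_file, claim, top_k), 0, 0
        elif method == "GPT":
            return stub_gpt(pdf_file, claim)
        else:
            return f"Unknown retrieval method: {method}", 0, 0

    evidence, n_prompt, n_completion = retrieve_evidence("paper.pdf", "X causes Y", 5, "GPT")
    print(n_prompt, n_completion)  # 1200 85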
requirements.txt CHANGED
@@ -8,3 +8,4 @@ nltk
 bm25s
 huggingface_hub
 spaces
+openai
llm_retrieval.py → retrieval_bert.py RENAMED
@@ -12,14 +12,14 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-class LLMEvidenceRetriever:
+class BERTRetriever:
     """
-    LLM-based evidence retrieval using extractive question answering
+    BERT-based evidence retrieval using extractive question answering
     """

     def __init__(self, model_name: str = "deepset/deberta-v3-large-squad2"):
         """
-        Initialize the LLM evidence retriever
+        Initialize the BERT evidence retriever

         Args:
             model_name: HuggingFace model for question answering
@@ -34,7 +34,7 @@ class LLMEvidenceRetriever:
         )
         # Maximum context length for the model
         self.max_length = self.tokenizer.model_max_length
-        logger.info(f"Initialized LLM retriever with model: {model_name}")
+        logger.info(f"Initialized BERT retriever with model: {model_name}")

     def _extract_and_clean_text(self, pdf_file: str) -> str:
         """
@@ -124,9 +124,9 @@
         else:
             return f"What evidence supports the claim that {claim.lower()}?"

-    def retrieve_evidence(self, pdf_file: str, claim: str, k: int = 5) -> str:
+    def retrieve_evidence(self, pdf_file: str, claim: str, top_k: int = 5) -> str:
         """
-        Retrieve evidence from PDF using LLM-based question answering
+        Retrieve evidence from PDF using BERT-based question answering

         Args:
             pdf_file: Path to PDF file
@@ -177,7 +177,7 @@

         # Sort by confidence score and take top k
         answers.sort(key=lambda x: x["score"], reverse=True)
-        top_answers = answers[:k]
+        top_answers = answers[:top_k]

         # Combine evidence passages
         if top_answers:
@@ -189,30 +189,30 @@
             return "No relevant evidence found in the document."

         except Exception as e:
-            logger.error(f"Error in LLM evidence retrieval: {str(e)}")
+            logger.error(f"Error in BERT evidence retrieval: {str(e)}")
             return f"Error retrieving evidence: {str(e)}"


-def retrieve_from_pdf_llm(pdf_file: str, query: str, k: int = 5) -> str:
+def retrieve_with_deberta(pdf_file: str, claim: str, top_k: int = 5) -> str:
     """
-    Wrapper function for LLM-based evidence retrieval
+    Wrapper function for DeBERTa-based evidence retrieval
     Compatible with the existing BM25S interface

     Args:
         pdf_file: Path to PDF file
-        query: Query/claim to find evidence for
-        k: Number of evidence passages to retrieve
+        claim: Claim to find evidence for
+        top_k: Number of evidence passages to retrieve

     Returns:
         Retrieved evidence text
     """
     # Initialize retriever (in production, this should be cached)
-    retriever = LLMEvidenceRetriever()
-    return retriever.retrieve_evidence(pdf_file, query, k)
+    retriever = BERTRetriever()
+    return retriever.retrieve_evidence(pdf_file, claim, top_k)


 # Alternative lightweight model for faster inference
-class LightweightLLMRetriever(LLMEvidenceRetriever):
+class DistilBERTRetriever(BERTRetriever):
     """
     Lightweight version using smaller, faster models
     """
@@ -221,17 +221,17 @@ class LightweightLLMRetriever(LLMEvidenceRetriever):
         super().__init__(model_name="distilbert-base-cased-distilled-squad")


-def retrieve_from_pdf_llm_fast(pdf_file: str, query: str, k: int = 5) -> str:
+def retrieve_with_distilbert(pdf_file: str, claim: str, top_k: int = 5) -> str:
     """
-    Fast LLM-based evidence retrieval using lightweight model
+    Fast DistilBERT-based evidence retrieval

     Args:
         pdf_file: Path to PDF file
-        query: Query/claim to find evidence for
-        k: Number of evidence passages to retrieve
+        claim: Claim to find evidence for
+        top_k: Number of evidence passages to retrieve

     Returns:
         Retrieved evidence text
     """
-    retriever = LightweightLLMRetriever()
-    return retriever.retrieve_evidence(pdf_file, query, k)
+    retriever = DistilBERTRetriever()
+    return retriever.retrieve_evidence(pdf_file, claim, top_k)
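The wrapper still builds a fresh BERTRetriever on every call, and the in-code comment flags caching as a production concern. A possible caching sketch (hypothetical helper, not part of this commit), assuming BERTRetriever is importable from retrieval_bert:

    from functools import lru_cache

    from retrieval_bert import BERTRetriever

    @lru_cache(maxsize=2)
    def get_retriever(model_name: str = "deepset/deberta-v3-large-squad2") -> BERTRetriever:
        # Load the QA model once per model name; later calls reuse the instance
        return BERTRetriever(model_name)

    def retrieve_with_deberta_cached(pdf_file: str, claim: str, top_k: int = 5) -> str:
        return get_retriever().retrieve_evidence(pdf_file, claim, top_k)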
retrieval.py → retrieval_bm25s.py RENAMED
@@ -5,7 +5,7 @@ from nltk.tokenize import sent_tokenize
 import bm25s


-def retrieve_from_pdf(pdf_file, query, k=10):
+def retrieve_with_bm25s(pdf_file, claim, top_k=10):

     # Get PDF file as binary
     with open(pdf_file, mode="rb") as f:
@@ -35,12 +35,12 @@ def retrieve_from_pdf(pdf_file, query, k=10):
     # Initialize the BM25 model
     retriever = bm25s.BM25()
     retriever.index(corpus_tokens, show_progress=False)
-    # Tokenize the query
-    query_tokens = bm25s.tokenize(query)
+    # Tokenize the claim
+    query_tokens = bm25s.tokenize(claim)

-    # Get top-k results
+    # Get top k results
     # Use int(k) in case we get str value (as in retrieval example)
-    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=int(k))
+    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=int(top_k))
     ## Print results
     # for i in range(results.shape[1]):
     #     doc, score = results[0, i], scores[0, i]
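For reference, the renamed function follows the standard bm25s flow from the xhluca/bm25s project: tokenize a corpus, index it, then retrieve the top k entries for a tokenized query. A self-contained sketch with a toy corpus (illustrative values only, not from the app):

    import bm25s

    # Toy corpus standing in for the sentence-tokenized PDF text
    corpus = [
        "a cat is a feline and likes to purr",
        "a dog is the human's best friend and loves to play",
        "a bird is a beautiful animal that can fly",
    ]

    retriever = bm25s.BM25()
    retriever.index(bm25s.tokenize(corpus), show_progress=False)

    # Tokenize the claim, then fetch the top 2 sentences with scores
    query_tokens = bm25s.tokenize("does the fish purr like a cat?")
    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=2)
    print(results[0, 0], scores[0, 0])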
retrieval_gpt.py ADDED
@@ -0,0 +1,58 @@
+from openai import OpenAI
+import os
+from typing import Tuple
+
+
+def retrieve_with_gpt(pdf_file: str, claim: str) -> Tuple[str, int, int]:
+    """
+    Retrieve evidence from PDF using GPT
+
+    Args:
+        pdf_file: Path to PDF file
+        claim: Claim to find evidence for
+
+    Returns:
+        Tuple with retrieved evidence text, prompt tokens, and completion tokens
+    """
+
+    model = "gpt-4o-mini-2024-07-18"
+
+    prompt = """Retrieve sentences from the PDF (title, abstract, text, sections, not References/Bibliography) to support or refute this claim. \
+Summarize any information from images. \
+Respond only with verbatim sentences from the text and/or summarized sentences from images. \
+If no conclusive evidence is found, respond with the five sentences that are most relevant to the claim. \
+Combine all sentences into one response without quotation marks or line numbers. \
+"""
+
+    prompt = "".join([prompt, f"CLAIM: {claim}"])
+
+    client = OpenAI()
+
+    file = client.files.create(file=open(pdf_file, "rb"), purpose="user_data")
+
+    completion = client.chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "file",
+                        "file": {
+                            "file_id": file.id,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                ],
+            }
+        ],
+    )
+
+    return (
+        completion.choices[0].message.content,
+        completion.usage.prompt_tokens,
+        completion.usage.completion_tokens,
+    )
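The new module relies on the openai client reading OPENAI_API_KEY from the environment; it uploads the PDF via the Files API and returns the model's answer together with the usage counts shown in the new token boxes. A hypothetical usage sketch (file path and claim are placeholders):

    # Assumes OPENAI_API_KEY is exported in the environment
    from retrieval_gpt import retrieve_with_gpt

    evidence, prompt_tokens, completion_tokens = retrieve_with_gpt(
        "paper.pdf", "CRISPR-Cas9 enables targeted genome editing"
    )
    print(evidence)
    print(f"Tokens used: {prompt_tokens} prompt, {completion_tokens} completion")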