Spaces: Running on Zero

Commit 0ae0ade · Parent(s): adcbc55
jedick committed

Add GPT retrieval

Files changed:
- app.py +64 -23
- requirements.txt +1 -0
- llm_retrieval.py → retrieval_bert.py +21 -21
- retrieval.py → retrieval_bm25s.py +5 -5
- retrieval_gpt.py +58 -0
app.py
CHANGED

@@ -2,8 +2,9 @@ import pandas as pd
 import gradio as gr
 from transformers import pipeline
 import nltk
-from retrieval import retrieve_from_pdf
-from llm_retrieval import …
+from retrieval_bm25s import retrieve_with_bm25s
+from retrieval_bert import retrieve_with_deberta
+from retrieval_gpt import retrieve_with_gpt
 import os
 import json
 from datetime import datetime
@@ -102,10 +103,10 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
             )
             with gr.Row():
                 retrieval_method = gr.Radio(
-                    choices=["BM25S", "…
+                    choices=["BM25S", "DeBERTa", "GPT"],
                     value="BM25S",
                     label="Retrieval Method",
-                    info="…
+                    info="Keyword search (BM25S) or AI (DeBERTa, GPT)",
                 )
                 get_evidence = gr.Button(value="Get Evidence")
                 top_k = gr.Slider(
@@ -113,7 +114,6 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
                     10,
                     value=5,
                     step=1,
-                    interactive=True,
                     label="Top k sentences",
                 )
         with gr.Column(scale=3):
@@ -122,7 +122,11 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
                 info="aka premise",
                 placeholder="Input evidence or use Get Evidence from PDF",
             )
-            …
+            with gr.Row():
+                prompt_tokens = gr.Number(label="Prompt tokens", visible=False)
+                completion_tokens = gr.Number(
+                    label="Completion tokens", visible=False
+                )

         with gr.Column(scale=2):
             # Keep the prediction textbox hidden
@@ -234,17 +238,17 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         #### *Capstone project*
         - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
         - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
-        #### *…
+        #### *Text Classification*
         - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
         - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
-        #### *Evidence Retrieval…
-        - <…
-        - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> […
+        #### *Evidence Retrieval*
+        - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
+        - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
+        - <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
         #### *Datasets for fine-tuning*
         - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
         - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
         #### *Other sources*
-        - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (evidence retrieval)
         - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
         - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
         - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
@@ -329,16 +333,23 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         return pdf_file, claim

     @spaces.GPU()
-    def …
+    def _retrieve_with_deberta(pdf_file, claim, top_k):
+        """
+        Retrieve evidence using DeBERTa
+        """
+        return retrieve_with_deberta(pdf_file, claim, top_k)
+
+    def retrieve_evidence(pdf_file, claim, top_k, method):
         """
         Retrieve evidence using the selected method
         """
         if method == "BM25S":
-            …
-            …
-            …
-            …
-            …
+            # Append 0 for number of prompt and completion tokens
+            return retrieve_with_bm25s(pdf_file, claim, top_k), 0, 0
+        elif method == "DeBERTa":
+            return _retrieve_with_deberta(pdf_file, claim, top_k), 0, 0
+        elif method == "GPT":
+            return retrieve_with_gpt(pdf_file, claim)
         else:
             return f"Unknown retrieval method: {method}"

@@ -399,11 +410,29 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         else:
             append_feedback(*args, user_label="REFUTE")

+    def number_visible(value):
+        """
+        Show numbers (token counts) if GPT is selected for retrieval
+        """
+        if value == "GPT":
+            return gr.Number(visible=True)
+        else:
+            return gr.Number(visible=False)
+
+    def slider_visible(value):
+        """
+        Hide slider (top_k) if GPT is selected for retrieval
+        """
+        if value == "GPT":
+            return gr.Slider(visible=False)
+        else:
+            return gr.Slider(visible=True)
+
     # Event listeners

-    # …
+    # Press Enter or Shift-Enter to submit
     gr.on(
-        triggers=[claim.submit, evidence.submit…
+        triggers=[claim.submit, evidence.submit],
         fn=query_model,
         inputs=[claim, evidence],
         outputs=[prediction, label],
@@ -412,9 +441,9 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
     # Get evidence from PDF and run the model
     gr.on(
         triggers=[get_evidence.click],
-        fn=…
+        fn=retrieve_evidence,
         inputs=[pdf_file, claim, top_k, retrieval_method],
-        outputs=evidence,
+        outputs=[evidence, prompt_tokens, completion_tokens],
     ).then(
         fn=query_model,
         inputs=[claim, evidence],
@@ -472,9 +501,9 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         outputs=[pdf_file, claim],
         api_name=False,
     ).then(
-        fn=…
+        fn=retrieve_evidence,
         inputs=[pdf_file, claim, top_k, retrieval_method],
-        outputs=evidence,
+        outputs=[evidence, prompt_tokens, completion_tokens],
         api_name=False,
     ).then(
         fn=query_model,
@@ -515,17 +544,29 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
         fn=save_feedback_support,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
     )
     flag_nei.click(
         fn=save_feedback_nei,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
     )
     flag_refute.click(
         fn=save_feedback_refute,
         inputs=[claim, evidence, model, label],
         outputs=None,
+        api_name=False,
+    )
+
+    # Change visibility of top-k slider and token counts if GPT is selected for retrieval
+    retrieval_method.change(
+        number_visible, retrieval_method, prompt_tokens, api_name=False
+    )
+    retrieval_method.change(
+        number_visible, retrieval_method, completion_tokens, api_name=False
     )
+    retrieval_method.change(slider_visible, retrieval_method, top_k, api_name=False)


 if __name__ == "__main__":
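The `retrieval_method.change` handlers above use the Gradio pattern of returning a component with updated properties from an event callback. A minimal self-contained sketch of just that mechanism, assuming gradio 4.x (component names mirror the diff; the rest of the app is omitted):

```python
# Minimal sketch of the visibility toggling in app.py (assumes gradio 4.x).
import gradio as gr


def number_visible(value):
    # Token counters only apply to the GPT method
    return gr.Number(visible=(value == "GPT"))


def slider_visible(value):
    # GPT decides how many sentences to return, so hide the top-k slider
    return gr.Slider(visible=(value != "GPT"))


with gr.Blocks() as demo:
    retrieval_method = gr.Radio(
        choices=["BM25S", "DeBERTa", "GPT"], value="BM25S", label="Retrieval Method"
    )
    top_k = gr.Slider(1, 10, value=5, step=1, label="Top k sentences")
    prompt_tokens = gr.Number(label="Prompt tokens", visible=False)
    completion_tokens = gr.Number(label="Completion tokens", visible=False)
    retrieval_method.change(number_visible, retrieval_method, prompt_tokens)
    retrieval_method.change(number_visible, retrieval_method, completion_tokens)
    retrieval_method.change(slider_visible, retrieval_method, top_k)

if __name__ == "__main__":
    demo.launch()
```

Note also the uniform return contract in `retrieve_evidence`: BM25S and DeBERTa append `0, 0` so every method yields `(evidence, prompt_tokens, completion_tokens)`, letting a single `gr.on` event drive all three output components.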
requirements.txt
CHANGED

@@ -8,3 +8,4 @@ nltk
 bm25s
 huggingface_hub
 spaces
+openai
llm_retrieval.py → retrieval_bert.py
RENAMED

@@ -12,14 +12,14 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-class LLMEvidenceRetriever:
+class BERTRetriever:
     """
-    …
+    BERT-based evidence retrieval using extractive question answering
     """

     def __init__(self, model_name: str = "deepset/deberta-v3-large-squad2"):
         """
-        Initialize the …
+        Initialize the BERT evidence retriever

         Args:
             model_name: HuggingFace model for question answering
@@ -34,7 +34,7 @@ class LLMEvidenceRetriever:
         )
         # Maximum context length for the model
         self.max_length = self.tokenizer.model_max_length
-        logger.info(f"Initialized …
+        logger.info(f"Initialized BERT retriever with model: {model_name}")

     def _extract_and_clean_text(self, pdf_file: str) -> str:
         """
@@ -124,9 +124,9 @@ class LLMEvidenceRetriever:
         else:
             return f"What evidence supports the claim that {claim.lower()}?"

-    def retrieve_evidence(self, pdf_file: str, claim: str, …
+    def retrieve_evidence(self, pdf_file: str, claim: str, top_k: int = 5) -> str:
         """
-        Retrieve evidence from PDF using …
+        Retrieve evidence from PDF using BERT-based question answering

         Args:
             pdf_file: Path to PDF file
@@ -177,7 +177,7 @@ class LLMEvidenceRetriever:

         # Sort by confidence score and take top k
         answers.sort(key=lambda x: x["score"], reverse=True)
-        top_answers = answers[:…
+        top_answers = answers[:top_k]

         # Combine evidence passages
         if top_answers:
@@ -189,30 +189,30 @@ class LLMEvidenceRetriever:
             return "No relevant evidence found in the document."

         except Exception as e:
-            logger.error(f"Error in …
+            logger.error(f"Error in BERT evidence retrieval: {str(e)}")
             return f"Error retrieving evidence: {str(e)}"


-def …
+def retrieve_with_deberta(pdf_file: str, claim: str, top_k: int = 5) -> str:
     """
-    Wrapper function for …
+    Wrapper function for DeBERTa-based evidence retrieval
     Compatible with the existing BM25S interface

     Args:
         pdf_file: Path to PDF file
-        …
-        …
+        claim: Claim to find evidence for
+        top_k: Number of evidence passages to retrieve

     Returns:
         Retrieved evidence text
     """
     # Initialize retriever (in production, this should be cached)
-    retriever = LLMEvidenceRetriever()
-    return retriever.retrieve_evidence(pdf_file, …
+    retriever = BERTRetriever()
+    return retriever.retrieve_evidence(pdf_file, claim, top_k)


 # Alternative lightweight model for faster inference
-class LightweightLLMRetriever(LLMEvidenceRetriever):
+class DistilBERTRetriever(BERTRetriever):
     """
     Lightweight version using smaller, faster models
     """
@@ -221,17 +221,17 @@ class LightweightLLMRetriever(LLMEvidenceRetriever):
         super().__init__(model_name="distilbert-base-cased-distilled-squad")


-def …
+def retrieve_with_distilbert(pdf_file: str, claim: str, top_k: int = 5) -> str:
     """
-    Fast …
+    Fast DistilBERT-based evidence retrieval

     Args:
         pdf_file: Path to PDF file
-        …
-        …
+        claim: Claim to find evidence for
+        top_k: Number of evidence passages to retrieve

     Returns:
         Retrieved evidence text
     """
-    retriever = LightweightLLMRetriever()
-    return retriever.retrieve_evidence(pdf_file, …
+    retriever = DistilBERTRetriever()
+    return retriever.retrieve_evidence(pdf_file, claim, top_k)
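Behind `BERTRetriever.retrieve_evidence` is an extractive question-answering pipeline: the claim is rephrased as a question (the template `f"What evidence supports the claim that {claim.lower()}?"` appears in the diff) and the model extracts answer spans from the document text. A standalone sketch of that step, using the lightweight DistilBERT model named above and a toy context string in place of the extracted PDF text:

```python
# Sketch of the extractive-QA step; the model name and question template come
# from the diff, the context string is a stand-in for cleaned PDF text.
from transformers import pipeline

qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

claim = "CRISPR-Cas9 enables targeted genome editing"
question = f"What evidence supports the claim that {claim.lower()}?"
context = (
    "CRISPR-Cas9 uses a guide RNA to direct the Cas9 nuclease to a chosen "
    "DNA sequence, enabling targeted edits. Off-target effects remain a concern."
)

result = qa(question=question, context=context)
print(result["answer"], result["score"])  # extracted span and its confidence
```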
retrieval.py → retrieval_bm25s.py
RENAMED

@@ -5,7 +5,7 @@ from nltk.tokenize import sent_tokenize
 import bm25s


-def retrieve_from_pdf(pdf_file, query, k=10):
+def retrieve_with_bm25s(pdf_file, claim, top_k=10):

     # Get PDF file as binary
     with open(pdf_file, mode="rb") as f:
@@ -35,12 +35,12 @@ def retrieve_from_pdf(pdf_file, query, k=10):
     # Initialize the BM25 model
     retriever = bm25s.BM25()
     retriever.index(corpus_tokens, show_progress=False)
-    # Tokenize the query
-    query_tokens = bm25s.tokenize(query)
+    # Tokenize the claim
+    query_tokens = bm25s.tokenize(claim)

-    # Get top …
+    # Get top k results
     # Use int(k) in case we get str value (as in retrieval example)
-    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=int(k))
+    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=int(top_k))
     ## Print results
     # for i in range(results.shape[1]):
     #     doc, score = results[0, i], scores[0, i]
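For reference, the bm25s calls in `retrieve_with_bm25s` operate on any list of strings; a toy corpus stands in here for the sentence-tokenized PDF text:

```python
# Standalone sketch of the bm25s API used above (toy corpus, not a PDF).
import bm25s

corpus = [
    "CRISPR-Cas9 enables targeted genome editing.",
    "BM25 ranks texts by term frequency and inverse document frequency.",
    "The study reports off-target effects in two cell lines.",
]
corpus_tokens = bm25s.tokenize(corpus)

retriever = bm25s.BM25()
retriever.index(corpus_tokens, show_progress=False)

query_tokens = bm25s.tokenize("genome editing with CRISPR")
results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=2)
for i in range(results.shape[1]):
    print(results[0, i], scores[0, i])
```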
retrieval_gpt.py
ADDED

@@ -0,0 +1,58 @@
+from openai import OpenAI
+import os
+from typing import Tuple
+
+
+def retrieve_with_gpt(pdf_file: str, claim: str) -> Tuple[str, int, int]:
+    """
+    Retrieve evidence from PDF using GPT
+
+    Args:
+        pdf_file: Path to PDF file
+        claim: Claim to find evidence for
+
+    Returns:
+        Tuple with retrieved evidence text, prompt tokens, and completion tokens
+    """
+
+    model = "gpt-4o-mini-2024-07-18"
+
+    prompt = """Retrieve sentences from the PDF (title, abstract, text, sections, not References/Bibliography) to support or refute this claim. \
+Summarize any information from images. \
+Respond only with verbatim sentences from the text and/or summarized sentences from images. \
+If no conclusive evidence is found, respond with the five sentences that are most relevant to the claim. \
+Combine all sentences into one response without quotation marks or line numbers. \
+"""
+
+    prompt = "".join([prompt, f"CLAIM: {claim}"])
+
+    client = OpenAI()
+
+    file = client.files.create(file=open(pdf_file, "rb"), purpose="user_data")
+
+    completion = client.chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "file",
+                        "file": {
+                            "file_id": file.id,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                ],
+            }
+        ],
+    )
+
+    return (
+        completion.choices[0].message.content,
+        completion.usage.prompt_tokens,
+        completion.usage.completion_tokens,
+    )
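A hypothetical usage sketch (not part of the commit): `OpenAI()` reads `OPENAI_API_KEY` from the environment, so the key must be set before calling, and the file path and claim below are placeholders:

```python
# Assumes OPENAI_API_KEY is set in the environment and paper.pdf exists locally.
from retrieval_gpt import retrieve_with_gpt

evidence, prompt_tokens, completion_tokens = retrieve_with_gpt(
    "paper.pdf", "CRISPR-Cas9 enables targeted genome editing"
)
print(f"{prompt_tokens} prompt + {completion_tokens} completion tokens")
print(evidence)
```

The three-tuple return is what app.py's `retrieve_evidence` forwards to the evidence textbox and the two token counters.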