jedick committed
Commit: 5cdd81a (parent: 00c763e)

Add LLM retrieval

Files changed (2):
  1. app.py +33 -19
  2. llm_retrieval.py +237 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 from transformers import pipeline
 import nltk
 from retrieval import retrieve_from_pdf
+from llm_retrieval import retrieve_from_pdf_llm, retrieve_from_pdf_llm_fast
 import os
 import json
 from datetime import datetime
@@ -93,7 +94,9 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
 with gr.Column(scale=3):
     with gr.Row():
         gr.Markdown("# AI4citations")
-        gr.Markdown("## *AI-powered citation verification*")
+        gr.Markdown(
+            "## *AI-powered citation verification* ([more info](https://github.com/jedick/AI4citations))"
+        )
     claim = gr.Textbox(
         label="Claim",
         info="aka hypothesis",
@@ -105,6 +108,13 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
 pdf_file = gr.File(
     label="Upload PDF", type="filepath", height=120
 )
+with gr.Row():
+    retrieval_method = gr.Radio(
+        choices=["BM25S", "LLM (Large)", "LLM (Fast)"],
+        value="BM25S",
+        label="Retrieval Method",
+        info="Choose between keyword-based (BM25S) or AI-based (LLM) evidence retrieval",
+    )
 get_evidence = gr.Button(value="Get Evidence")
 top_k = gr.Slider(
     1,
@@ -193,7 +203,7 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
 ### Usage:

 - Input a **Claim**, then:
-  - Upload a PDF and click **Get Evidence** OR
+  - Upload a PDF, select retrieval method, and click **Get Evidence** OR
   - Input **Evidence** statements yourself
 """
 )
@@ -232,24 +242,15 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
 #### *Capstone project*
 - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
 - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
-"""
-)
-gr.Markdown(
-"""
-#### *Models*
+#### *Claim Verification Models (text classification)*
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
 - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
+#### *Evidence Retrieval Models (question answering)*
+- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (Large)
+- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [distilbert-base-cased-distilled-squad](https://huggingface.co/distilbert/distilbert-base-cased-distilled-squad) (Fast)
 #### *Datasets for fine-tuning*
 - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
 - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
-"""
-)
-gr.Markdown(
-"""
 #### *Other sources*
 - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (evidence retrieval)
 - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
@@ -335,6 +336,19 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
     pdf_file = f"examples/retrieval/{pdf_file}"
     return pdf_file, claim

+def retrieve_evidence_with_method(pdf_file, claim, top_k, method):
+    """
+    Retrieve evidence using the selected method
+    """
+    if method == "BM25S":
+        return retrieve_from_pdf(pdf_file, claim, k=top_k)
+    elif method == "LLM (Large)":
+        return retrieve_from_pdf_llm(pdf_file, claim, k=top_k)
+    elif method == "LLM (Fast)":
+        return retrieve_from_pdf_llm_fast(pdf_file, claim, k=top_k)
+    else:
+        return f"Unknown retrieval method: {method}"
+
 def append_feedback(
     claim: str, evidence: str, model: str, label: str, user_label: str
 ) -> None:
@@ -405,8 +419,8 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
 # Get evidence from PDF and run the model
 gr.on(
     triggers=[get_evidence.click],
-    fn=retrieve_from_pdf,
-    inputs=[pdf_file, claim, top_k],
+    fn=retrieve_evidence_with_method,
+    inputs=[pdf_file, claim, top_k, retrieval_method],
     outputs=evidence,
 ).then(
     fn=query_model,
@@ -465,8 +479,8 @@ with gr.Blocks(theme=my_theme, css=custom_css, head=font_awesome_html) as demo:
     outputs=[pdf_file, claim],
     api_name=False,
 ).then(
-    fn=retrieve_from_pdf,
-    inputs=[pdf_file, claim, top_k],
+    fn=retrieve_evidence_with_method,
+    inputs=[pdf_file, claim, top_k, retrieval_method],
     outputs=evidence,
     api_name=False,
 ).then(
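Note (not part of the commit): the if/elif dispatch in retrieve_evidence_with_method mirrors the gr.Radio choices by hand. A minimal dict-based sketch, using only functions that exist in this commit, shows one way to keep the choices and the retrieval functions defined in a single place:

# Sketch (not part of the commit): dict-based dispatch keeps the Radio choices
# and the retrieval functions in one mapping.
from retrieval import retrieve_from_pdf
from llm_retrieval import retrieve_from_pdf_llm, retrieve_from_pdf_llm_fast

RETRIEVERS = {
    "BM25S": retrieve_from_pdf,
    "LLM (Large)": retrieve_from_pdf_llm,
    "LLM (Fast)": retrieve_from_pdf_llm_fast,
}

def retrieve_evidence_with_method(pdf_file, claim, top_k, method):
    # Fall back to an error message for unknown methods, as in the committed code
    fn = RETRIEVERS.get(method)
    if fn is None:
        return f"Unknown retrieval method: {method}"
    return fn(pdf_file, claim, k=top_k)

# The gr.Radio choices could then be built as list(RETRIEVERS), so adding a
# retrieval method requires touching only one place.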
llm_retrieval.py ADDED
@@ -0,0 +1,237 @@
import re
import fitz  # pip install pymupdf
from unidecode import unidecode
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer
import torch
from typing import List, Tuple, Optional
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LLMEvidenceRetriever:
    """
    LLM-based evidence retrieval using extractive question answering
    """

    def __init__(self, model_name: str = "deepset/deberta-v3-large-squad2"):
        """
        Initialize the LLM evidence retriever

        Args:
            model_name: HuggingFace model for question answering
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.qa_pipeline = pipeline(
            "question-answering",
            model=model_name,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        # Maximum context length for the model
        self.max_length = self.tokenizer.model_max_length
        logger.info(f"Initialized LLM retriever with model: {model_name}")

    def _extract_and_clean_text(self, pdf_file: str) -> str:
        """
        Extract and clean text from PDF file

        Args:
            pdf_file: Path to PDF file

        Returns:
            Cleaned text from PDF
        """
        # Get PDF file as binary
        with open(pdf_file, mode="rb") as f:
            pdf_file_bytes = f.read()

        # Extract text from the PDF
        pdf_doc = fitz.open(stream=pdf_file_bytes, filetype="pdf")
        pdf_text = ""
        for page_num in range(pdf_doc.page_count):
            page = pdf_doc.load_page(page_num)
            pdf_text += page.get_text("text")

        # Clean text
        # Remove hyphens at end of lines
        clean_text = re.sub("-\n", "", pdf_text)
        # Replace remaining newline characters with space
        clean_text = re.sub("\n", " ", clean_text)
        # Replace unicode with ascii
        clean_text = unidecode(clean_text)

        return clean_text

    def _chunk_text(self, text: str, max_chunk_size: int = 3000) -> List[str]:
        """
        Split text into chunks that fit within model context window

        Args:
            text: Input text to chunk
            max_chunk_size: Maximum size per chunk

        Returns:
            List of text chunks
        """
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            # Check if adding this sentence would exceed the limit
            if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence

        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def _format_claim_as_question(self, claim: str) -> str:
        """
        Convert a claim into a question format for better QA performance

        Args:
            claim: Input claim

        Returns:
            Question formatted for QA model
        """
        # Simple heuristics to convert claims to questions
        claim = claim.strip()

        # If already a question, return as is
        if claim.endswith("?"):
            return claim

        # Convert common claim patterns to questions
        if claim.lower().startswith(("the ", "a ", "an ")):
            return f"What evidence supports that {claim.lower()}?"
        elif "is" in claim.lower() or "are" in claim.lower():
            return f"Is it true that {claim.lower()}?"
        elif "can" in claim.lower() or "could" in claim.lower():
            return f"{claim}?"
        else:
            return f"What evidence supports the claim that {claim.lower()}?"

    def retrieve_evidence(self, pdf_file: str, claim: str, k: int = 5) -> str:
        """
        Retrieve evidence from PDF using LLM-based question answering

        Args:
            pdf_file: Path to PDF file
            claim: Claim to find evidence for
            k: Number of evidence passages to retrieve

        Returns:
            Combined evidence text
        """
        try:
            # Extract and clean text from PDF
            clean_text = self._extract_and_clean_text(pdf_file)

            # Convert claim to question format
            question = self._format_claim_as_question(claim)

            # Split text into manageable chunks
            chunks = self._chunk_text(clean_text)

            # Get answers from each chunk
            answers = []
            for i, chunk in enumerate(chunks):
                try:
                    result = self.qa_pipeline(
                        question=question, context=chunk, max_answer_len=200, top_k=1
                    )

                    # Handle both single answer and list of answers
                    if isinstance(result, list):
                        result = result[0]

                    if result["score"] > 0.1:  # Confidence threshold
                        # Extract surrounding context for better evidence
                        answer_text = result["answer"]
                        start_idx = max(0, chunk.find(answer_text) - 100)
                        end_idx = min(
                            len(chunk), chunk.find(answer_text) + len(answer_text) + 100
                        )
                        context = chunk[start_idx:end_idx].strip()

                        answers.append(
                            {"text": context, "score": result["score"], "chunk_idx": i}
                        )

                except Exception as e:
                    logger.warning(f"Error processing chunk {i}: {str(e)}")
                    continue

            # Sort by confidence score and take top k
            answers.sort(key=lambda x: x["score"], reverse=True)
            top_answers = answers[:k]

            # Combine evidence passages
            if top_answers:
                evidence_texts = [answer["text"] for answer in top_answers]
                combined_evidence = " ".join(evidence_texts)
                return combined_evidence
            else:
                logger.warning("No evidence found with sufficient confidence")
                return "No relevant evidence found in the document."

        except Exception as e:
            logger.error(f"Error in LLM evidence retrieval: {str(e)}")
            return f"Error retrieving evidence: {str(e)}"


def retrieve_from_pdf_llm(pdf_file: str, query: str, k: int = 5) -> str:
    """
    Wrapper function for LLM-based evidence retrieval
    Compatible with the existing BM25S interface

    Args:
        pdf_file: Path to PDF file
        query: Query/claim to find evidence for
        k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    # Initialize retriever (in production, this should be cached)
    retriever = LLMEvidenceRetriever()
    return retriever.retrieve_evidence(pdf_file, query, k)


# Alternative lightweight model for faster inference
class LightweightLLMRetriever(LLMEvidenceRetriever):
    """
    Lightweight version using smaller, faster models
    """

    def __init__(self):
        super().__init__(model_name="distilbert-base-cased-distilled-squad")


def retrieve_from_pdf_llm_fast(pdf_file: str, query: str, k: int = 5) -> str:
    """
    Fast LLM-based evidence retrieval using lightweight model

    Args:
        pdf_file: Path to PDF file
        query: Query/claim to find evidence for
        k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    retriever = LightweightLLMRetriever()
    return retriever.retrieve_evidence(pdf_file, query, k)
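Note (not part of the commit): retrieve_from_pdf_llm and retrieve_from_pdf_llm_fast construct a new retriever, and therefore reload the QA model, on every call; the in-code comment already flags that production use should cache it. A minimal sketch of one possible caching layer on top of the committed classes:

# Sketch (not part of the commit): cache retriever instances so each QA model
# is loaded only once per process instead of on every "Get Evidence" click.
from functools import lru_cache

from llm_retrieval import LLMEvidenceRetriever, LightweightLLMRetriever


@lru_cache(maxsize=2)
def get_retriever(fast: bool = False) -> LLMEvidenceRetriever:
    # First call loads the model; later calls reuse the same instance
    return LightweightLLMRetriever() if fast else LLMEvidenceRetriever()


def retrieve_from_pdf_llm_cached(pdf_file: str, query: str, k: int = 5, fast: bool = False) -> str:
    # Hypothetical wrapper name; mirrors the committed function signatures
    return get_retriever(fast).retrieve_evidence(pdf_file, query, k)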