Spaces:
Sleeping
Sleeping
add merger that preserve the coordinates and aggregate them meaningfully
Browse files
document_qa/document_qa_engine.py
CHANGED
|
@@ -3,18 +3,89 @@ import os
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Union, Any
|
| 5 |
|
| 6 |
-
|
| 7 |
from grobid_client.grobid_client import GrobidClient
|
| 8 |
-
from langchain.chains import create_extraction_chain
|
| 9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
| 10 |
map_rerank_prompt
|
| 11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 12 |
from langchain.retrievers import MultiQueryRetriever
|
| 13 |
from langchain.schema import Document
|
| 14 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 15 |
from langchain.vectorstores import Chroma
|
| 16 |
from tqdm import tqdm
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class DocumentQAEngine:
|
|
@@ -44,6 +115,7 @@ class DocumentQAEngine:
|
|
| 44 |
self.llm = llm
|
| 45 |
self.memory = memory
|
| 46 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
|
|
| 47 |
|
| 48 |
if embeddings_root_path is not None:
|
| 49 |
self.embeddings_root_path = embeddings_root_path
|
|
@@ -157,7 +229,9 @@ class DocumentQAEngine:
|
|
| 157 |
|
| 158 |
def _run_query(self, doc_id, query, context_size=4):
|
| 159 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
| 160 |
-
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
|
|
|
|
|
|
|
| 161 |
response = self.chain.run(input_documents=relevant_documents,
|
| 162 |
question=query)
|
| 163 |
|
|
@@ -196,7 +270,7 @@ class DocumentQAEngine:
|
|
| 196 |
if verbose:
|
| 197 |
print("File", pdf_file_path)
|
| 198 |
filename = Path(pdf_file_path).stem
|
| 199 |
-
coordinates = True if chunk_size == -1 else False
|
| 200 |
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
|
| 201 |
|
| 202 |
biblio = structure['biblio']
|
|
@@ -209,29 +283,25 @@ class DocumentQAEngine:
|
|
| 209 |
metadatas = []
|
| 210 |
ids = []
|
| 211 |
|
| 212 |
-
if chunk_size
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
texts.append(passage['text'])
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
metadatas.append(biblio_copy)
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
metadatas = [biblio for _ in range(len(texts))]
|
| 234 |
-
ids = [id for id, t in enumerate(texts)]
|
| 235 |
|
| 236 |
return texts, metadatas, ids
|
| 237 |
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Union, Any
|
| 5 |
|
| 6 |
+
import tiktoken
|
| 7 |
from grobid_client.grobid_client import GrobidClient
|
| 8 |
+
from langchain.chains import create_extraction_chain
|
| 9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
| 10 |
map_rerank_prompt
|
| 11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 12 |
from langchain.retrievers import MultiQueryRetriever
|
| 13 |
from langchain.schema import Document
|
|
|
|
| 14 |
from langchain.vectorstores import Chroma
|
| 15 |
from tqdm import tqdm
|
| 16 |
|
| 17 |
+
from document_qa.grobid_processors import GrobidProcessor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TextMerger:
|
| 21 |
+
def __init__(self, model_name=None, encoding_name="gpt2"):
|
| 22 |
+
if model_name is not None:
|
| 23 |
+
self.enc = tiktoken.encoding_for_model(model_name)
|
| 24 |
+
else:
|
| 25 |
+
self.enc = tiktoken.get_encoding(encoding_name)
|
| 26 |
+
|
| 27 |
+
def encode(self, text, allowed_special=set(), disallowed_special="all"):
|
| 28 |
+
return self.enc.encode(
|
| 29 |
+
text,
|
| 30 |
+
allowed_special=allowed_special,
|
| 31 |
+
disallowed_special=disallowed_special,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def merge_passages(self, passages, chunk_size, tolerance=0.2):
|
| 35 |
+
new_passages = []
|
| 36 |
+
new_coordinates = []
|
| 37 |
+
current_texts = []
|
| 38 |
+
current_coordinates = []
|
| 39 |
+
for idx, passage in enumerate(passages):
|
| 40 |
+
text = passage['text']
|
| 41 |
+
coordinates = passage['coordinates']
|
| 42 |
+
current_texts.append(text)
|
| 43 |
+
current_coordinates.append(coordinates)
|
| 44 |
+
|
| 45 |
+
accumulated_text = " ".join(current_texts)
|
| 46 |
+
|
| 47 |
+
encoded_accumulated_text = self.encode(accumulated_text)
|
| 48 |
+
|
| 49 |
+
if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
|
| 50 |
+
if len(current_texts) > 1:
|
| 51 |
+
new_passages.append(current_texts[:-1])
|
| 52 |
+
new_coordinates.append(current_coordinates[:-1])
|
| 53 |
+
current_texts = [current_texts[-1]]
|
| 54 |
+
current_coordinates = [current_coordinates[-1]]
|
| 55 |
+
else:
|
| 56 |
+
new_passages.append(current_texts)
|
| 57 |
+
new_coordinates.append(current_coordinates)
|
| 58 |
+
current_texts = []
|
| 59 |
+
current_coordinates = []
|
| 60 |
+
|
| 61 |
+
elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
|
| 62 |
+
new_passages.append(current_texts)
|
| 63 |
+
new_coordinates.append(current_coordinates)
|
| 64 |
+
current_texts = []
|
| 65 |
+
current_coordinates = []
|
| 66 |
+
else:
|
| 67 |
+
print("bao")
|
| 68 |
+
|
| 69 |
+
if len(current_texts) > 0:
|
| 70 |
+
new_passages.append(current_texts)
|
| 71 |
+
new_coordinates.append(current_coordinates)
|
| 72 |
+
|
| 73 |
+
new_passages_struct = []
|
| 74 |
+
for i, passages in enumerate(new_passages):
|
| 75 |
+
text = " ".join(passages)
|
| 76 |
+
coordinates = ";".join(new_coordinates[i])
|
| 77 |
+
|
| 78 |
+
new_passages_struct.append(
|
| 79 |
+
{
|
| 80 |
+
"text": text,
|
| 81 |
+
"coordinates": coordinates,
|
| 82 |
+
"type": "aggregated chunks",
|
| 83 |
+
"section": "mixed",
|
| 84 |
+
"subSection": "mixed"
|
| 85 |
+
}
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return new_passages_struct
|
| 89 |
|
| 90 |
|
| 91 |
class DocumentQAEngine:
|
|
|
|
| 115 |
self.llm = llm
|
| 116 |
self.memory = memory
|
| 117 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
| 118 |
+
self.text_merger = TextMerger()
|
| 119 |
|
| 120 |
if embeddings_root_path is not None:
|
| 121 |
self.embeddings_root_path = embeddings_root_path
|
|
|
|
| 229 |
|
| 230 |
def _run_query(self, doc_id, query, context_size=4):
|
| 231 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
| 232 |
+
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
|
| 233 |
+
for doc in
|
| 234 |
+
relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
|
| 235 |
response = self.chain.run(input_documents=relevant_documents,
|
| 236 |
question=query)
|
| 237 |
|
|
|
|
| 270 |
if verbose:
|
| 271 |
print("File", pdf_file_path)
|
| 272 |
filename = Path(pdf_file_path).stem
|
| 273 |
+
coordinates = True # if chunk_size == -1 else False
|
| 274 |
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
|
| 275 |
|
| 276 |
biblio = structure['biblio']
|
|
|
|
| 283 |
metadatas = []
|
| 284 |
ids = []
|
| 285 |
|
| 286 |
+
if chunk_size > 0:
|
| 287 |
+
new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
|
| 288 |
+
else:
|
| 289 |
+
new_passages = structure['passages']
|
|
|
|
| 290 |
|
| 291 |
+
for passage in new_passages:
|
| 292 |
+
biblio_copy = copy.copy(biblio)
|
| 293 |
+
if len(str.strip(passage['text'])) > 0:
|
| 294 |
+
texts.append(passage['text'])
|
|
|
|
| 295 |
|
| 296 |
+
biblio_copy['type'] = passage['type']
|
| 297 |
+
biblio_copy['section'] = passage['section']
|
| 298 |
+
biblio_copy['subSection'] = passage['subSection']
|
| 299 |
+
biblio_copy['coordinates'] = passage['coordinates']
|
| 300 |
+
metadatas.append(biblio_copy)
|
| 301 |
+
|
| 302 |
+
# ids.append(passage['passage_id'])
|
| 303 |
+
|
| 304 |
+
ids = [id for id, t in enumerate(new_passages)]
|
|
|
|
|
|
|
| 305 |
|
| 306 |
return texts, metadatas, ids
|
| 307 |
|
tests/test_document_qa_engine.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from document_qa.document_qa_engine import TextMerger
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_merge_passages_small_chunk():
|
| 5 |
+
merger = TextMerger()
|
| 6 |
+
|
| 7 |
+
passages = [
|
| 8 |
+
{
|
| 9 |
+
'text': "The quick brown fox jumps over the tree",
|
| 10 |
+
'coordinates': '1'
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
'text': "and went straight into the mouth of a bear.",
|
| 14 |
+
'coordinates': '2'
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
'text': "The color of the colors is a color with colors",
|
| 18 |
+
'coordinates': '3'
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
'text': "the main colors are not the colorw we show",
|
| 22 |
+
'coordinates': '4'
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)
|
| 26 |
+
|
| 27 |
+
assert len(new_passages) == 4
|
| 28 |
+
assert new_passages[0]['coordinates'] == "1"
|
| 29 |
+
assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"
|
| 30 |
+
|
| 31 |
+
assert new_passages[1]['coordinates'] == "2"
|
| 32 |
+
assert new_passages[1]['text'] == "and went straight into the mouth of a bear."
|
| 33 |
+
|
| 34 |
+
assert new_passages[2]['coordinates'] == "3"
|
| 35 |
+
assert new_passages[2]['text'] == "The color of the colors is a color with colors"
|
| 36 |
+
|
| 37 |
+
assert new_passages[3]['coordinates'] == "4"
|
| 38 |
+
assert new_passages[3]['text'] == "the main colors are not the colorw we show"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_merge_passages_big_chunk():
|
| 42 |
+
merger = TextMerger()
|
| 43 |
+
|
| 44 |
+
passages = [
|
| 45 |
+
{
|
| 46 |
+
'text': "The quick brown fox jumps over the tree",
|
| 47 |
+
'coordinates': '1'
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
'text': "and went straight into the mouth of a bear.",
|
| 51 |
+
'coordinates': '2'
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
'text': "The color of the colors is a color with colors",
|
| 55 |
+
'coordinates': '3'
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
'text': "the main colors are not the colorw we show",
|
| 59 |
+
'coordinates': '4'
|
| 60 |
+
}
|
| 61 |
+
]
|
| 62 |
+
new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)
|
| 63 |
+
|
| 64 |
+
assert len(new_passages) == 2
|
| 65 |
+
assert new_passages[0]['coordinates'] == "1;2"
|
| 66 |
+
assert new_passages[0][
|
| 67 |
+
'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."
|
| 68 |
+
|
| 69 |
+
assert new_passages[1]['coordinates'] == "3;4"
|
| 70 |
+
assert new_passages[1][
|
| 71 |
+
'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"
|