Spaces:
Sleeping
Sleeping
cosmetics
Browse files- document_qa/document_qa_engine.py +22 -8
- streamlit_app.py +1 -0
document_qa/document_qa_engine.py
CHANGED
|
@@ -23,7 +23,13 @@ class DocumentQAEngine:
|
|
| 23 |
embeddings_map_from_md5 = {}
|
| 24 |
embeddings_map_to_md5 = {}
|
| 25 |
|
| 26 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
self.embedding_function = embedding_function
|
| 28 |
self.llm = llm
|
| 29 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
@@ -81,14 +87,14 @@ class DocumentQAEngine:
|
|
| 81 |
return self.embeddings_map_from_md5[md5]
|
| 82 |
|
| 83 |
def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
|
| 84 |
-
verbose=False) -> (
|
| 85 |
Any, str):
|
| 86 |
# self.load_embeddings(self.embeddings_root_path)
|
| 87 |
|
| 88 |
if verbose:
|
| 89 |
print(query)
|
| 90 |
|
| 91 |
-
response = self._run_query(doc_id, query, context_size=context_size)
|
| 92 |
response = response['output_text'] if 'output_text' in response else response
|
| 93 |
|
| 94 |
if verbose:
|
|
@@ -138,9 +144,15 @@ class DocumentQAEngine:
|
|
| 138 |
|
| 139 |
return parsed_output
|
| 140 |
|
| 141 |
-
def _run_query(self, doc_id, query, context_size=4):
|
| 142 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
|
| 145 |
|
| 146 |
def _get_context(self, doc_id, query, context_size=4):
|
|
@@ -150,6 +162,7 @@ class DocumentQAEngine:
|
|
| 150 |
return relevant_documents
|
| 151 |
|
| 152 |
def get_all_context_by_document(self, doc_id):
    """Return every stored text chunk for the document identified by `doc_id`.

    Looks up the per-document vector store in `self.embeddings_dict` and
    returns its 'documents' payload unchanged.
    """
    collection = self.embeddings_dict[doc_id]
    return collection.get()['documents']
|
|
@@ -161,6 +174,7 @@ class DocumentQAEngine:
|
|
| 161 |
return relevant_documents
|
| 162 |
|
| 163 |
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
|
|
|
|
| 164 |
if verbose:
|
| 165 |
print("File", pdf_file_path)
|
| 166 |
filename = Path(pdf_file_path).stem
|
|
@@ -215,12 +229,11 @@ class DocumentQAEngine:
|
|
| 215 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
| 216 |
collection_name=hash)
|
| 217 |
|
| 218 |
-
|
| 219 |
self.embeddings_root_path = None
|
| 220 |
|
| 221 |
return hash
|
| 222 |
|
| 223 |
-
def create_embeddings(self, pdfs_dir_path: Path):
|
| 224 |
input_files = []
|
| 225 |
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
|
| 226 |
for file_ in files:
|
|
@@ -238,7 +251,8 @@ class DocumentQAEngine:
|
|
| 238 |
print(data_path, "exists. Skipping it ")
|
| 239 |
continue
|
| 240 |
|
| 241 |
-
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=
|
|
|
|
| 242 |
filename = metadata[0]['filename']
|
| 243 |
|
| 244 |
vector_db_document = Chroma.from_texts(texts,
|
|
|
|
| 23 |
embeddings_map_from_md5 = {}
|
| 24 |
embeddings_map_to_md5 = {}
|
| 25 |
|
| 26 |
+
def __init__(self,
|
| 27 |
+
llm,
|
| 28 |
+
embedding_function,
|
| 29 |
+
qa_chain_type="stuff",
|
| 30 |
+
embeddings_root_path=None,
|
| 31 |
+
grobid_url=None,
|
| 32 |
+
):
|
| 33 |
self.embedding_function = embedding_function
|
| 34 |
self.llm = llm
|
| 35 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
|
|
| 87 |
return self.embeddings_map_from_md5[md5]
|
| 88 |
|
| 89 |
def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
|
| 90 |
+
verbose=False, memory=None) -> (
|
| 91 |
Any, str):
|
| 92 |
# self.load_embeddings(self.embeddings_root_path)
|
| 93 |
|
| 94 |
if verbose:
|
| 95 |
print(query)
|
| 96 |
|
| 97 |
+
response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
|
| 98 |
response = response['output_text'] if 'output_text' in response else response
|
| 99 |
|
| 100 |
if verbose:
|
|
|
|
| 144 |
|
| 145 |
return parsed_output
|
| 146 |
|
| 147 |
+
def _run_query(self, doc_id, query, memory=None, context_size=4):
|
| 148 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
| 149 |
+
if memory:
|
| 150 |
+
return self.chain.run(input_documents=relevant_documents,
|
| 151 |
+
question=query)
|
| 152 |
+
else:
|
| 153 |
+
return self.chain.run(input_documents=relevant_documents,
|
| 154 |
+
question=query,
|
| 155 |
+
memory=memory)
|
| 156 |
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
|
| 157 |
|
| 158 |
def _get_context(self, doc_id, query, context_size=4):
|
|
|
|
| 162 |
return relevant_documents
|
| 163 |
|
| 164 |
def get_all_context_by_document(self, doc_id):
    """Return the full context from the document.

    Fetches the vector store registered under `doc_id` and hands back its
    stored 'documents' list as-is.
    """
    stored = self.embeddings_dict[doc_id].get()
    return stored['documents']
|
|
|
|
| 174 |
return relevant_documents
|
| 175 |
|
| 176 |
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
|
| 177 |
+
"""Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
|
| 178 |
if verbose:
|
| 179 |
print("File", pdf_file_path)
|
| 180 |
filename = Path(pdf_file_path).stem
|
|
|
|
| 229 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
| 230 |
collection_name=hash)
|
| 231 |
|
|
|
|
| 232 |
self.embeddings_root_path = None
|
| 233 |
|
| 234 |
return hash
|
| 235 |
|
| 236 |
+
def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
|
| 237 |
input_files = []
|
| 238 |
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
|
| 239 |
for file_ in files:
|
|
|
|
| 251 |
print(data_path, "exists. Skipping it ")
|
| 252 |
continue
|
| 253 |
|
| 254 |
+
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
|
| 255 |
+
perc_overlap=perc_overlap)
|
| 256 |
filename = metadata[0]['filename']
|
| 257 |
|
| 258 |
vector_db_document = Chroma.from_texts(texts,
|
streamlit_app.py
CHANGED
|
@@ -97,6 +97,7 @@ def init_qa(model, api_key=None):
|
|
| 97 |
else:
|
| 98 |
st.error("The model was not loaded properly. Try reloading. ")
|
| 99 |
st.stop()
|
|
|
|
| 100 |
|
| 101 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
| 102 |
|
|
|
|
| 97 |
else:
|
| 98 |
st.error("The model was not loaded properly. Try reloading. ")
|
| 99 |
st.stop()
|
| 100 |
+
return
|
| 101 |
|
| 102 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
| 103 |
|