Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	cosmetics
Browse files- document_qa/document_qa_engine.py +22 -8
- streamlit_app.py +1 -0
    	
        document_qa/document_qa_engine.py
    CHANGED
    
    | @@ -23,7 +23,13 @@ class DocumentQAEngine: | |
| 23 | 
             
                embeddings_map_from_md5 = {}
         | 
| 24 | 
             
                embeddings_map_to_md5 = {}
         | 
| 25 |  | 
| 26 | 
            -
                def __init__(self, | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 27 | 
             
                    self.embedding_function = embedding_function
         | 
| 28 | 
             
                    self.llm = llm
         | 
| 29 | 
             
                    self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
         | 
| @@ -81,14 +87,14 @@ class DocumentQAEngine: | |
| 81 | 
             
                    return self.embeddings_map_from_md5[md5]
         | 
| 82 |  | 
| 83 | 
             
                def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
         | 
| 84 | 
            -
                                   verbose=False) -> (
         | 
| 85 | 
             
                        Any, str):
         | 
| 86 | 
             
                    # self.load_embeddings(self.embeddings_root_path)
         | 
| 87 |  | 
| 88 | 
             
                    if verbose:
         | 
| 89 | 
             
                        print(query)
         | 
| 90 |  | 
| 91 | 
            -
                    response = self._run_query(doc_id, query, context_size=context_size)
         | 
| 92 | 
             
                    response = response['output_text'] if 'output_text' in response else response
         | 
| 93 |  | 
| 94 | 
             
                    if verbose:
         | 
| @@ -138,9 +144,15 @@ class DocumentQAEngine: | |
| 138 |  | 
| 139 | 
             
                    return parsed_output
         | 
| 140 |  | 
| 141 | 
            -
                def _run_query(self, doc_id, query, context_size=4):
         | 
| 142 | 
             
                    relevant_documents = self._get_context(doc_id, query, context_size)
         | 
| 143 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 144 | 
             
                    # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
         | 
| 145 |  | 
| 146 | 
             
                def _get_context(self, doc_id, query, context_size=4):
         | 
| @@ -150,6 +162,7 @@ class DocumentQAEngine: | |
| 150 | 
             
                    return relevant_documents
         | 
| 151 |  | 
| 152 | 
             
                def get_all_context_by_document(self, doc_id):
         | 
|  | |
| 153 | 
             
                    db = self.embeddings_dict[doc_id]
         | 
| 154 | 
             
                    docs = db.get()
         | 
| 155 | 
             
                    return docs['documents']
         | 
| @@ -161,6 +174,7 @@ class DocumentQAEngine: | |
| 161 | 
             
                    return relevant_documents
         | 
| 162 |  | 
| 163 | 
             
                def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
         | 
|  | |
| 164 | 
             
                    if verbose:
         | 
| 165 | 
             
                        print("File", pdf_file_path)
         | 
| 166 | 
             
                    filename = Path(pdf_file_path).stem
         | 
| @@ -215,12 +229,11 @@ class DocumentQAEngine: | |
| 215 | 
             
                        self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
         | 
| 216 | 
             
                                                                       collection_name=hash)
         | 
| 217 |  | 
| 218 | 
            -
             | 
| 219 | 
             
                    self.embeddings_root_path = None
         | 
| 220 |  | 
| 221 | 
             
                    return hash
         | 
| 222 |  | 
| 223 | 
            -
                def create_embeddings(self, pdfs_dir_path: Path):
         | 
| 224 | 
             
                    input_files = []
         | 
| 225 | 
             
                    for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
         | 
| 226 | 
             
                        for file_ in files:
         | 
| @@ -238,7 +251,8 @@ class DocumentQAEngine: | |
| 238 | 
             
                            print(data_path, "exists. Skipping it ")
         | 
| 239 | 
             
                            continue
         | 
| 240 |  | 
| 241 | 
            -
                        texts, metadata, ids = self.get_text_from_document(input_file, chunk_size= | 
|  | |
| 242 | 
             
                        filename = metadata[0]['filename']
         | 
| 243 |  | 
| 244 | 
             
                        vector_db_document = Chroma.from_texts(texts,
         | 
|  | |
| 23 | 
             
                embeddings_map_from_md5 = {}
         | 
| 24 | 
             
                embeddings_map_to_md5 = {}
         | 
| 25 |  | 
| 26 | 
            +
                def __init__(self,
         | 
| 27 | 
            +
                             llm,
         | 
| 28 | 
            +
                             embedding_function,
         | 
| 29 | 
            +
                             qa_chain_type="stuff",
         | 
| 30 | 
            +
                             embeddings_root_path=None,
         | 
| 31 | 
            +
                             grobid_url=None,
         | 
| 32 | 
            +
                             ):
         | 
| 33 | 
             
                    self.embedding_function = embedding_function
         | 
| 34 | 
             
                    self.llm = llm
         | 
| 35 | 
             
                    self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
         | 
|  | |
| 87 | 
             
                    return self.embeddings_map_from_md5[md5]
         | 
| 88 |  | 
| 89 | 
             
                def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
         | 
| 90 | 
            +
                                   verbose=False, memory=None) -> (
         | 
| 91 | 
             
                        Any, str):
         | 
| 92 | 
             
                    # self.load_embeddings(self.embeddings_root_path)
         | 
| 93 |  | 
| 94 | 
             
                    if verbose:
         | 
| 95 | 
             
                        print(query)
         | 
| 96 |  | 
| 97 | 
            +
                    response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         | 
| 98 | 
             
                    response = response['output_text'] if 'output_text' in response else response
         | 
| 99 |  | 
| 100 | 
             
                    if verbose:
         | 
|  | |
| 144 |  | 
| 145 | 
             
                    return parsed_output
         | 
| 146 |  | 
| 147 | 
            +
                def _run_query(self, doc_id, query, memory=None, context_size=4):
         | 
| 148 | 
             
                    relevant_documents = self._get_context(doc_id, query, context_size)
         | 
| 149 | 
            +
                    if memory:
         | 
| 150 | 
            +
                        return self.chain.run(input_documents=relevant_documents,
         | 
| 151 | 
            +
                                              question=query)
         | 
| 152 | 
            +
                    else:
         | 
| 153 | 
            +
                        return self.chain.run(input_documents=relevant_documents,
         | 
| 154 | 
            +
                                              question=query,
         | 
| 155 | 
            +
                                              memory=memory)
         | 
| 156 | 
             
                    # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
         | 
| 157 |  | 
| 158 | 
             
                def _get_context(self, doc_id, query, context_size=4):
         | 
|  | |
| 162 | 
             
                    return relevant_documents
         | 
| 163 |  | 
| 164 | 
             
                def get_all_context_by_document(self, doc_id):
         | 
| 165 | 
            +
                    """Return the full context from the document"""
         | 
| 166 | 
             
                    db = self.embeddings_dict[doc_id]
         | 
| 167 | 
             
                    docs = db.get()
         | 
| 168 | 
             
                    return docs['documents']
         | 
|  | |
| 174 | 
             
                    return relevant_documents
         | 
| 175 |  | 
| 176 | 
             
                def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
         | 
| 177 | 
            +
                    """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
         | 
| 178 | 
             
                    if verbose:
         | 
| 179 | 
             
                        print("File", pdf_file_path)
         | 
| 180 | 
             
                    filename = Path(pdf_file_path).stem
         | 
|  | |
| 229 | 
             
                        self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
         | 
| 230 | 
             
                                                                       collection_name=hash)
         | 
| 231 |  | 
|  | |
| 232 | 
             
                    self.embeddings_root_path = None
         | 
| 233 |  | 
| 234 | 
             
                    return hash
         | 
| 235 |  | 
| 236 | 
            +
                def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         | 
| 237 | 
             
                    input_files = []
         | 
| 238 | 
             
                    for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
         | 
| 239 | 
             
                        for file_ in files:
         | 
|  | |
| 251 | 
             
                            print(data_path, "exists. Skipping it ")
         | 
| 252 | 
             
                            continue
         | 
| 253 |  | 
| 254 | 
            +
                        texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
         | 
| 255 | 
            +
                                                                           perc_overlap=perc_overlap)
         | 
| 256 | 
             
                        filename = metadata[0]['filename']
         | 
| 257 |  | 
| 258 | 
             
                        vector_db_document = Chroma.from_texts(texts,
         | 
    	
        streamlit_app.py
    CHANGED
    
    | @@ -97,6 +97,7 @@ def init_qa(model, api_key=None): | |
| 97 | 
             
                else:
         | 
| 98 | 
             
                    st.error("The model was not loaded properly. Try reloading. ")
         | 
| 99 | 
             
                    st.stop()
         | 
|  | |
| 100 |  | 
| 101 | 
             
                return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
         | 
| 102 |  | 
|  | |
| 97 | 
             
                else:
         | 
| 98 | 
             
                    st.error("The model was not loaded properly. Try reloading. ")
         | 
| 99 | 
             
                    st.stop()
         | 
| 100 | 
            +
                    return
         | 
| 101 |  | 
| 102 | 
             
                return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
         | 
| 103 |  | 
