Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import List, Dict, Any, Optional | |
| import weaviate | |
| import weaviate.classes.query as wvc_query | |
| from concurrent.futures import ThreadPoolExecutor | |
| from langchain_core.retrievers import BaseRetriever | |
| from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun | |
| from langchain_core.documents import Document | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from utils.process_data import infer_field, infer_entity_type | |
| from utils.synonym_map import rewrite_query_with_legal_synonyms | |
| import prompt_templete | |
| logger = logging.getLogger(__name__) | |
| class AdvancedLawRetriever(BaseRetriever): | |
| client: weaviate.WeaviateClient | |
| collection_name: str | |
| llm: Any | |
| reranker: Any | |
| embeddings_model: Any | |
| default_k: int = 5 | |
| initial_k: int = 15 # Lấy nhiều ứng viên ban đầu | |
| hybrid_search_alpha: float = 0.5 | |
| doc_type_boost: float = 0.4 | |
| class ConfigDict: | |
| arbitrary_types_allowed = True | |
| # === CÁC HÀM HELPER === | |
| def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]: | |
| """Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả.""" | |
| keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT) | |
| keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()]) | |
| try: | |
| keywords = keyword_chain.invoke({"question": question}) | |
| # Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh | |
| return [question] + keywords | |
| except Exception as e: | |
| logger.error(f"Failed to extract keywords: {e}") | |
| return [question] | |
| def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]: | |
| """ | |
| CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter. | |
| Nó không còn nhiệm vụ suy luận nữa. | |
| """ | |
| if not filters_dict: | |
| return None | |
| filter_conditions = [] | |
| for key, value in filters_dict.items(): | |
| if value is None: | |
| continue | |
| # Logic xây dựng Filter | |
| if key == "entity_type" and isinstance(value, list) and value: | |
| filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value)) | |
| elif isinstance(value, str): | |
| filter_conditions.append(wvc_query.Filter.by_property(key).equal(value)) | |
| # Thêm các điều kiện khác nếu cần | |
| if not filter_conditions: | |
| return None | |
| return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0] | |
| def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]: | |
| # ... (giữ nguyên logic) ... | |
| try: | |
| collection = self.client.collections.get(self.collection_name) | |
| query_vector = self.embeddings_model.embed_query(query) | |
| response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True)) | |
| docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects] | |
| return docs | |
| except Exception: return [] | |
| # === HÀM CHÍNH _get_relevant_documents === | |
| def _get_relevant_documents( | |
| self, query: str, *, run_manager: CallbackManagerForRetrieverRun | |
| ) -> List[Document]: | |
| # ================================================================= | |
| # PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn | |
| # ================================================================= | |
| # 0.1. Đảm bảo an toàn cho input | |
| safe_query = str(query) | |
| logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---") | |
| # 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc | |
| query_info = self._extract_query_info_with_intent(safe_query) | |
| inferred_field = query_info.get("base_filters", {}).get("field") | |
| preferred_doc_type = query_info.get("preferred_doc_type") | |
| # 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển | |
| rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field) | |
| if safe_query != rewritten_query: | |
| logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'") | |
| # 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại | |
| search_terms = self._extract_searchable_keywords_with_llm(rewritten_query) | |
| logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}") | |
| # 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất | |
| base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"]) | |
| # ================================================================= | |
| # PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback | |
| # ================================================================= | |
| def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]: | |
| """Hàm nội bộ để thực hiện tìm kiếm song song.""" | |
| docs = [] | |
| with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor: | |
| futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms] | |
| for future in futures: | |
| try: docs.extend(future.result()) | |
| except Exception as e: logger.error(f"A search task failed: {e}") | |
| return docs | |
| logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---") | |
| retrieved_docs = run_search_tasks(base_weaviate_filter) | |
| # Lọc trùng lặp | |
| unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)} | |
| # Cơ chế Fallback | |
| if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None: | |
| logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...") | |
| fallback_docs = run_search_tasks(None) | |
| for doc in fallback_docs: | |
| if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict: | |
| unique_docs_dict[doc.page_content] = doc | |
| candidate_docs_list = list(unique_docs_dict.values()) | |
| # ================================================================= | |
| # PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả | |
| # ================================================================= | |
| # 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên | |
| final_candidates_for_rerank = candidate_docs_list | |
| if preferred_doc_type: | |
| logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'") | |
| docs_with_scores = [] | |
| for doc in candidate_docs_list: | |
| score = doc.metadata.get('hybrid_score', 0.5) | |
| if doc.metadata.get("loai_van_ban") == preferred_doc_type: | |
| score += self.doc_type_boost | |
| else: | |
| score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên | |
| docs_with_scores.append((doc, score)) | |
| docs_with_scores.sort(key=lambda x: x[1], reverse=True) | |
| final_candidates_for_rerank = [doc for doc, score in docs_with_scores] | |
| logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.") | |
| if not final_candidates_for_rerank: return [] | |
| # 2.2. Cross-Encoder Re-ranking với Structured Context | |
| logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...") | |
| docs_for_reranking = [] | |
| for doc in final_candidates_for_rerank: | |
| # Tạo chuỗi context giàu thông tin | |
| structured_content = ( | |
| f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. " | |
| f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. " | |
| f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n" | |
| f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}" | |
| ) | |
| docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content}) | |
| contents_to_rank = [item["structured_content"] for item in docs_for_reranking] | |
| try: | |
| # Sử dụng câu hỏi đã được viết lại để có ngữ cảnh tốt nhất | |
| ranked_results_info = self.reranker.rank(rewritten_query, contents_to_rank, return_documents=False, top_k=self.default_k * 2) # Lấy nhiều hơn một chút | |
| except Exception as e: | |
| logger.error(f"Failed to re-rank with custom structured content: {e}. Falling back to default re-ranking.") | |
| # Fallback về cách re-rank mặc định nếu có lỗi | |
| reranked_docs = self.reranker.compress_documents(final_candidates_for_rerank, rewritten_query) | |
| return reranked_docs[:self.default_k] | |
| # Lấy lại các Document gốc theo thứ tự đã được re-rank | |
| final_reranked_docs = [] | |
| for rank_info in ranked_results_info: | |
| original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"] | |
| original_doc.metadata['rerank_score'] = rank_info['score'] | |
| final_reranked_docs.append(original_doc) | |
| # 2.3. Log và Trả về kết quả cuối cùng | |
| logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---") | |
| for i, doc in enumerate(final_reranked_docs[:self.default_k]): | |
| score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}" | |
| logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}") | |
| logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn | |
| logger.info("-" * 25) | |
| return final_reranked_docs[:self.default_k] | |
| def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]: | |
| """ | |
| Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản. | |
| """ | |
| info = {"base_filters": {}, "preferred_doc_type": None} | |
| query_lower = query.lower() | |
| # 1. Suy luận field và entity | |
| inferred_field = infer_field(query, None) | |
| if inferred_field and inferred_field != "khac": | |
| info["base_filters"]["field"] = inferred_field | |
| inferred_entities = infer_entity_type(query, inferred_field) | |
| if inferred_entities: | |
| info["base_filters"]["entity_type"] = inferred_entities | |
| # 2. XÁC ĐỊNH Ý ĐỊNH -> ƯU TIÊN LOẠI VĂN BẢN | |
| # Nếu câu hỏi về MỨC PHẠT, ưu tiên tuyệt đối NGHỊ ĐỊNH | |
| if any(kw in query_lower for kw in ["phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt"]): | |
| info["preferred_doc_type"] = "NGHỊ ĐỊNH" | |
| logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.") | |
| # Nếu câu hỏi về NGUYÊN TẮC CHUNG, QUYỀN, NGHĨA VỤ, ưu tiên LUẬT | |
| elif any(kw in query_lower for kw in ["nguyên tắc", "quyền và nghĩa vụ", "cấm", "được phép", "khái niệm", "định nghĩa"]): | |
| info["preferred_doc_type"] = "LUẬT" | |
| logger.info("Intent detected: General Rule/Definition -> Preferring 'LUẬT'.") | |
| logger.info(f"Extracted query info: {info}") | |
| return info | |