import logging from typing import List, Dict, Any, Optional import weaviate import weaviate.classes.query as wvc_query from concurrent.futures import ThreadPoolExecutor from langchain_core.retrievers import BaseRetriever from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from utils.process_data import infer_field, infer_entity_type from utils.synonym_map import rewrite_query_with_legal_synonyms import prompt_templete logger = logging.getLogger(__name__) class AdvancedLawRetriever(BaseRetriever): client: weaviate.WeaviateClient collection_name: str llm: Any reranker: Any embeddings_model: Any default_k: int = 5 initial_k: int = 15 # Lấy nhiều ứng viên ban đầu hybrid_search_alpha: float = 0.5 doc_type_boost: float = 0.4 class ConfigDict: arbitrary_types_allowed = True # === CÁC HÀM HELPER === def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]: """Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả.""" keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT) keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()]) try: keywords = keyword_chain.invoke({"question": question}) # Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh return [question] + keywords except Exception as e: logger.error(f"Failed to extract keywords: {e}") return [question] def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]: """ CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter. Nó không còn nhiệm vụ suy luận nữa. """ if not filters_dict: return None filter_conditions = [] for key, value in filters_dict.items(): if value is None: continue # Logic xây dựng Filter if key == "entity_type" and isinstance(value, list) and value: filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value)) elif isinstance(value, str): filter_conditions.append(wvc_query.Filter.by_property(key).equal(value)) # Thêm các điều kiện khác nếu cần if not filter_conditions: return None return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0] def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]: # ... (giữ nguyên logic) ... try: collection = self.client.collections.get(self.collection_name) query_vector = self.embeddings_model.embed_query(query) response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True)) docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects] return docs except Exception: return [] # === HÀM CHÍNH _get_relevant_documents === def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: # ================================================================= # PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn # ================================================================= # 0.1. Đảm bảo an toàn cho input safe_query = str(query) logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---") # 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc query_info = self._extract_query_info_with_intent(safe_query) inferred_field = query_info.get("base_filters", {}).get("field") preferred_doc_type = query_info.get("preferred_doc_type") # 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field) if safe_query != rewritten_query: logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'") # 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại search_terms = self._extract_searchable_keywords_with_llm(rewritten_query) logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}") # 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"]) # ================================================================= # PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback # ================================================================= def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]: """Hàm nội bộ để thực hiện tìm kiếm song song.""" docs = [] with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor: futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms] for future in futures: try: docs.extend(future.result()) except Exception as e: logger.error(f"A search task failed: {e}") return docs logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---") retrieved_docs = run_search_tasks(base_weaviate_filter) # Lọc trùng lặp unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)} # Cơ chế Fallback if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None: logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...") fallback_docs = run_search_tasks(None) for doc in fallback_docs: if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict: unique_docs_dict[doc.page_content] = doc candidate_docs_list = list(unique_docs_dict.values()) # ================================================================= # PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả # ================================================================= # 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên final_candidates_for_rerank = candidate_docs_list if preferred_doc_type: logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'") docs_with_scores = [] for doc in candidate_docs_list: score = doc.metadata.get('hybrid_score', 0.5) if doc.metadata.get("loai_van_ban") == preferred_doc_type: score += self.doc_type_boost else: score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên docs_with_scores.append((doc, score)) docs_with_scores.sort(key=lambda x: x[1], reverse=True) final_candidates_for_rerank = [doc for doc, score in docs_with_scores] logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.") if not final_candidates_for_rerank: return [] # 2.2. Cross-Encoder Re-ranking với Structured Context logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...") docs_for_reranking = [] for doc in final_candidates_for_rerank: # Tạo chuỗi context giàu thông tin structured_content = ( f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. " f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. " f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n" f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}" ) docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content}) contents_to_rank = [item["structured_content"] for item in docs_for_reranking] try: # Sử dụng câu hỏi đã được viết lại để có ngữ cảnh tốt nhất ranked_results_info = self.reranker.rank(rewritten_query, contents_to_rank, return_documents=False, top_k=self.default_k * 2) # Lấy nhiều hơn một chút except Exception as e: logger.error(f"Failed to re-rank with custom structured content: {e}. Falling back to default re-ranking.") # Fallback về cách re-rank mặc định nếu có lỗi reranked_docs = self.reranker.compress_documents(final_candidates_for_rerank, rewritten_query) return reranked_docs[:self.default_k] # Lấy lại các Document gốc theo thứ tự đã được re-rank final_reranked_docs = [] for rank_info in ranked_results_info: original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"] original_doc.metadata['rerank_score'] = rank_info['score'] final_reranked_docs.append(original_doc) # 2.3. Log và Trả về kết quả cuối cùng logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---") for i, doc in enumerate(final_reranked_docs[:self.default_k]): score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}" logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}") logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn logger.info("-" * 25) return final_reranked_docs[:self.default_k] def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]: """ Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản. """ info = {"base_filters": {}, "preferred_doc_type": None} query_lower = query.lower() # 1. Suy luận field và entity inferred_field = infer_field(query, None) if inferred_field and inferred_field != "khac": info["base_filters"]["field"] = inferred_field inferred_entities = infer_entity_type(query, inferred_field) if inferred_entities: info["base_filters"]["entity_type"] = inferred_entities # 2. XÁC ĐỊNH Ý ĐỊNH -> ƯU TIÊN LOẠI VĂN BẢN # Nếu câu hỏi về MỨC PHẠT, ưu tiên tuyệt đối NGHỊ ĐỊNH if any(kw in query_lower for kw in ["phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt"]): info["preferred_doc_type"] = "NGHỊ ĐỊNH" logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.") # Nếu câu hỏi về NGUYÊN TẮC CHUNG, QUYỀN, NGHĨA VỤ, ưu tiên LUẬT elif any(kw in query_lower for kw in ["nguyên tắc", "quyền và nghĩa vụ", "cấm", "được phép", "khái niệm", "định nghĩa"]): info["preferred_doc_type"] = "LUẬT" logger.info("Intent detected: General Rule/Definition -> Preferring 'LUẬT'.") logger.info(f"Extracted query info: {info}") return info