from typing import List, Dict from .config import get_settings from .gemini_client import GeminiClient from loguru import logger import asyncio import hashlib import time # from .constants import BATCH_STATUS_MESSAGES # from .utils import get_random_message class Reranker: def __init__(self): settings = get_settings() self.provider = getattr(settings, 'rerank_provider', settings.llm_provider) self.model = getattr(settings, 'rerank_model', settings.llm_model) if self.provider == 'gemini': self.client = GeminiClient() # elif self.provider == 'openai': # self.client = OpenAIClient(settings.openai_api_key, model=self.model) # elif self.provider == 'cohere': # self.client = CohereClient(settings.cohere_api_key, model=self.model) else: raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.") # Cải thiện cache với TTL và quản lý memory self._rerank_cache = {} self._cache_ttl = 3600 # 1 giờ self._max_cache_size = 200 # Tăng cache size self._cache_timestamps = {} # Sử dụng max_docs_to_rerank từ config self.max_docs_to_rerank = settings.max_docs_to_rerank def _get_cache_key(self, query: str, docs: List[Dict]) -> str: """Tạo cache key từ query và docs.""" # Tối ưu hóa cache key generation query_normalized = query.lower().strip() doc_ids = [str(doc.get('id', '')) for doc in docs[:15]] # Chỉ cache top 15 docs cache_content = query_normalized + "|".join(sorted(doc_ids)) return hashlib.md5(cache_content.encode()).hexdigest() def _clean_cache(self): """Dọn dẹp cache cũ và quản lý memory.""" current_time = time.time() # Xóa cache entries đã hết hạn expired_keys = [ key for key, timestamp in self._cache_timestamps.items() if current_time - timestamp > self._cache_ttl ] for key in expired_keys: del self._rerank_cache[key] del self._cache_timestamps[key] # Nếu cache vẫn quá lớn, xóa entries cũ nhất if len(self._rerank_cache) > self._max_cache_size: sorted_keys = sorted( self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k] ) # Xóa 20% cache entries cũ nhất keys_to_remove = sorted_keys[:len(sorted_keys) // 5] for key in keys_to_remove: del self._rerank_cache[key] del self._cache_timestamps[key] logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries") def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]: """Lấy kết quả từ cache nếu có và còn hợp lệ.""" if cache_key in self._rerank_cache: current_time = time.time() if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl: cached_result = self._rerank_cache[cache_key][:top_k] logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results") return cached_result else: # Cache đã hết hạn, xóa del self._rerank_cache[cache_key] del self._cache_timestamps[cache_key] return [] def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]): """Lưu kết quả vào cache.""" self._rerank_cache[cache_key] = scored_docs self._cache_timestamps[cache_key] = time.time() # Dọn dẹp cache nếu cần if len(self._rerank_cache) > self._max_cache_size: self._clean_cache() async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]: """ Score nhiều documents cùng lúc bằng một prompt duy nhất. Không cắt bớt nội dung luật. """ if not docs: return [] # Không giới hạn content length, giữ nguyên nội dung luật docs_content = [] for i, doc in enumerate(docs): tieude = (doc.get('tieude') or '').strip() noidung = (doc.get('noidung') or '').strip() content = f"{tieude} {noidung}".strip() docs_content.append(f"{i+1}. {content}") batch_prompt = ( f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n" f"Câu hỏi: {query}\n\n" f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n" f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n" f"Ví dụ: 8,5,7,3,9" ) try: if self.provider == 'gemini': loop = asyncio.get_event_loop() logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs") response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt) logger.info(f"[RERANK] Got batch scores from Gemini: {response}") # Cải thiện parsing scores scores_text = str(response).strip() scores = [] # Xử lý nhiều format response có thể có if ',' in scores_text: score_parts = scores_text.split(',') elif ' ' in scores_text: score_parts = scores_text.split() else: score_parts = scores_text.replace('.', ',').split(',') for score_str in score_parts: try: clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.') if clean_score: score = float(clean_score) score = max(0, min(10, score)) scores.append(score) else: scores.append(0) except (ValueError, TypeError): scores.append(0) while len(scores) < len(docs): scores.append(0) for i, doc in enumerate(docs): doc['rerank_score'] = scores[i] logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}") return docs else: raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.") except Exception as e: logger.error(f"[RERANK] Lỗi khi batch score: {e}") for doc in docs: doc['rerank_score'] = 0 return docs async def _score_doc(self, query: str, doc: Dict) -> Dict: """ Score một document với query. Không cắt bớt nội dung luật. """ tieude = (doc.get('tieude') or '').strip() noidung = (doc.get('noidung') or '').strip() content = f"{tieude} {noidung}".strip() prompt = ( f"Đánh giá mức độ liên quan:\n" f"Luật: {content}\n" f"Hỏi: {query}\n" f"Điểm (0-10):" ) try: if self.provider == 'gemini': loop = asyncio.get_event_loop() logger.info(f"[RERANK] Sending individual prompt to Gemini") score_response = await loop.run_in_executor(None, self.client.generate_text, prompt) logger.info(f"[RERANK] Got individual score from Gemini: {score_response}") score_text = str(score_response).strip() try: clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.') if clean_score: score = float(clean_score) score = max(0, min(10, score)) else: score = 0 except (ValueError, TypeError): score = 0 doc['rerank_score'] = score return doc else: raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.") except Exception as e: logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}") doc['rerank_score'] = 0 return doc async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]: """ Rerank docs theo độ liên quan với query, trả về top_k docs. Sử dụng batch processing và caching để tối ưu hiệu suất. """ logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}") if not docs: return [] # Kiểm tra cache trước cache_key = self._get_cache_key(query, docs) cached_result = self._get_cached_result(cache_key, top_k) if cached_result: return cached_result # Giới hạn số lượng docs để rerank - chỉ rerank top 15 docs có similarity cao nhất max_docs_to_rerank = self.max_docs_to_rerank docs_to_rerank = docs[:max_docs_to_rerank] logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})") # Sử dụng batch processing thay vì individual scoring try: scored = await self._batch_score_docs(query, docs_to_rerank) logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs") except Exception as e: logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}") # Fallback về individual scoring nếu batch processing thất bại scored = [] for doc in docs_to_rerank: try: scored_doc = await self._score_doc(query, doc) scored.append(scored_doc) except Exception as e: logger.error(f"[RERANK] Error scoring individual doc: {e}") doc['rerank_score'] = 0 scored.append(doc) # Sort theo score và trả về top_k scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True) result = scored[:top_k] # Cache kết quả với system mới self._set_cached_result(cache_key, scored) logger.info(f"[RERANK] Top reranked docs: {result}") return result