|
import asyncio
import hashlib
import re
import time
from typing import List, Dict

from loguru import logger

from .config import get_settings
from .gemini_client import GeminiClient
|
|
|
|
|
|
|
class Reranker:
    """Rerank retrieved law documents by relevance to a query via an LLM.

    Documents are scored in a single batch prompt when possible (with a
    per-document fallback path), and scored results are memoized in a small
    TTL-bounded in-memory cache keyed by the normalized query plus doc ids.
    """

    def __init__(self):
        settings = get_settings()
        # Use a dedicated rerank provider/model when configured; otherwise
        # fall back to the general LLM settings.
        self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
        self.model = getattr(settings, 'rerank_model', settings.llm_model)
        if self.provider == 'gemini':
            self.client = GeminiClient()
        else:
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")

        # cache_key -> fully scored doc list; timestamps drive TTL eviction.
        self._rerank_cache = {}
        self._cache_ttl = 3600  # seconds a cache entry stays valid
        self._max_cache_size = 200  # soft cap before oldest entries are evicted
        self._cache_timestamps = {}

        self.max_docs_to_rerank = settings.max_docs_to_rerank

    def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
        """Build a cache key from the normalized query and the doc ids.

        Only the first 15 doc ids are hashed; md5 is used here as a fast,
        non-cryptographic fingerprint.
        """
        query_normalized = query.lower().strip()
        doc_ids = [str(doc.get('id', '')) for doc in docs[:15]]
        cache_content = query_normalized + "|".join(sorted(doc_ids))
        return hashlib.md5(cache_content.encode()).hexdigest()

    def _clean_cache(self):
        """Drop expired entries, then evict the oldest 20% when over capacity."""
        current_time = time.time()

        expired_keys = [
            key for key, timestamp in self._cache_timestamps.items()
            if current_time - timestamp > self._cache_ttl
        ]
        for key in expired_keys:
            del self._rerank_cache[key]
            del self._cache_timestamps[key]

        if len(self._rerank_cache) > self._max_cache_size:
            # Oldest first; removing a fifth at a time amortizes cleanup cost.
            sorted_keys = sorted(
                self._cache_timestamps.keys(),
                key=lambda k: self._cache_timestamps[k]
            )
            keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
            for key in keys_to_remove:
                del self._rerank_cache[key]
                del self._cache_timestamps[key]

            logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")

    def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]:
        """Return up to top_k cached docs for cache_key, or [] on miss/expiry."""
        if cache_key in self._rerank_cache:
            current_time = time.time()
            if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
                cached_result = self._rerank_cache[cache_key][:top_k]
                logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results")
                return cached_result
            else:
                # Entry expired: drop it eagerly.
                del self._rerank_cache[cache_key]
                del self._cache_timestamps[cache_key]
        return []

    def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
        """Store scored docs under cache_key; trigger cleanup when over capacity."""
        self._rerank_cache[cache_key] = scored_docs
        self._cache_timestamps[cache_key] = time.time()

        if len(self._rerank_cache) > self._max_cache_size:
            self._clean_cache()

    @staticmethod
    def _parse_scores(text: str, expected: int) -> List[float]:
        """Parse an LLM score reply into exactly `expected` floats in [0, 10].

        The model is asked for comma-separated scores; whitespace-separated
        replies are tolerated. The first number found in each token is used,
        so decimals like "8.5" survive intact (the previous implementation
        split a lone "8.5" into 8 and 5). Unparsable tokens score 0, and the
        list is padded/truncated to `expected`.
        """
        parts = text.split(',') if ',' in text else text.split()
        scores: List[float] = []
        for part in parts:
            match = re.search(r'\d+(?:\.\d+)?', part)
            if match:
                scores.append(max(0.0, min(10.0, float(match.group()))))
            else:
                scores.append(0.0)
        scores = scores[:expected]
        while len(scores) < expected:
            scores.append(0.0)
        return scores

    async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
        """Score all docs against the query with a single LLM prompt.

        Mutates each doc in place by setting 'rerank_score' (0-10); the law
        text is never truncated. On any failure every doc gets score 0 so
        retrieval still returns results.
        """
        if not docs:
            return []

        docs_content = []
        for i, doc in enumerate(docs):
            tieude = (doc.get('tieude') or '').strip()
            noidung = (doc.get('noidung') or '').strip()
            content = f"{tieude} {noidung}".strip()
            docs_content.append(f"{i+1}. {content}")

        batch_prompt = (
            f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
            f"Câu hỏi: {query}\n\n"
            f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
            f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
            f"Ví dụ: 8,5,7,3,9"
        )

        try:
            if self.provider == 'gemini':
                # The Gemini client is synchronous; run it off the event loop.
                # get_running_loop() is the non-deprecated form inside a coroutine.
                loop = asyncio.get_running_loop()
                logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
                response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
                logger.info(f"[RERANK] Got batch scores from Gemini: {response}")

                scores = self._parse_scores(str(response).strip(), len(docs))
                for i, doc in enumerate(docs):
                    doc['rerank_score'] = scores[i]

                logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
                return docs
            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.")
        except Exception as e:
            logger.error(f"[RERANK] Lỗi khi batch score: {e}")
            for doc in docs:
                doc['rerank_score'] = 0
            return docs

    async def _score_doc(self, query: str, doc: Dict) -> Dict:
        """Score a single doc against the query (fallback to the batch path).

        Mutates the doc in place by setting 'rerank_score' (0-10); the law
        text is never truncated. Returns the doc with score 0 on failure.
        """
        tieude = (doc.get('tieude') or '').strip()
        noidung = (doc.get('noidung') or '').strip()
        content = f"{tieude} {noidung}".strip()
        prompt = (
            f"Đánh giá mức độ liên quan:\n"
            f"Luật: {content}\n"
            f"Hỏi: {query}\n"
            f"Điểm (0-10):"
        )
        try:
            if self.provider == 'gemini':
                loop = asyncio.get_running_loop()
                logger.info(f"[RERANK] Sending individual prompt to Gemini")
                score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
                logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
                # Take the first number in the reply so answers like "8.5/10"
                # parse as 8.5 instead of a concatenated digit mash ("8510").
                match = re.search(r'\d+(?:\.\d+)?', str(score_response))
                score = max(0.0, min(10.0, float(match.group()))) if match else 0.0
                doc['rerank_score'] = score
                return doc
            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
        except Exception as e:
            logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
            doc['rerank_score'] = 0
            return doc

    async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank docs by relevance to query and return the top_k.

        Uses the batch scorer with a per-doc fallback, limits work to the
        first self.max_docs_to_rerank candidates, and caches the full scored
        list so later calls with a larger top_k can still hit the cache.
        """
        logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")

        if not docs:
            return []

        cache_key = self._get_cache_key(query, docs)
        cached_result = self._get_cached_result(cache_key, top_k)
        if cached_result:
            return cached_result

        max_docs_to_rerank = self.max_docs_to_rerank
        docs_to_rerank = docs[:max_docs_to_rerank]
        logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})")

        try:
            scored = await self._batch_score_docs(query, docs_to_rerank)
            logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs")
        except Exception as e:
            logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}")
            scored = []
            for doc in docs_to_rerank:
                try:
                    scored_doc = await self._score_doc(query, doc)
                    scored.append(scored_doc)
                except Exception as e:
                    logger.error(f"[RERANK] Error scoring individual doc: {e}")
                    doc['rerank_score'] = 0
                    scored.append(doc)

        # Highest score first; docs missing a score sort as 0.
        scored = sorted(scored, key=lambda x: x.get('rerank_score', 0), reverse=True)
        result = scored[:top_k]

        # Cache the full scored list, not just the top_k slice.
        self._set_cached_result(cache_key, scored)

        logger.info(f"[RERANK] Top reranked docs: {result}")
        return result