FBChatBot / app /reranker.py
VietCat's picture
fix racing issues when sending message
4032184
from typing import List, Dict
from .config import get_settings
from .gemini_client import GeminiClient
from loguru import logger
import asyncio
import hashlib
import time
# from .constants import BATCH_STATUS_MESSAGES
# from .utils import get_random_message
class Reranker:
def __init__(self):
settings = get_settings()
self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
self.model = getattr(settings, 'rerank_model', settings.llm_model)
if self.provider == 'gemini':
self.client = GeminiClient()
# elif self.provider == 'openai':
# self.client = OpenAIClient(settings.openai_api_key, model=self.model)
# elif self.provider == 'cohere':
# self.client = CohereClient(settings.cohere_api_key, model=self.model)
else:
raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")
# Cải thiện cache với TTL và quản lý memory
self._rerank_cache = {}
self._cache_ttl = 3600 # 1 giờ
self._max_cache_size = 200 # Tăng cache size
self._cache_timestamps = {}
# Sử dụng max_docs_to_rerank từ config
self.max_docs_to_rerank = settings.max_docs_to_rerank
def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
"""Tạo cache key từ query và docs."""
# Tối ưu hóa cache key generation
query_normalized = query.lower().strip()
doc_ids = [str(doc.get('id', '')) for doc in docs[:15]] # Chỉ cache top 15 docs
cache_content = query_normalized + "|".join(sorted(doc_ids))
return hashlib.md5(cache_content.encode()).hexdigest()
def _clean_cache(self):
"""Dọn dẹp cache cũ và quản lý memory."""
current_time = time.time()
# Xóa cache entries đã hết hạn
expired_keys = [
key for key, timestamp in self._cache_timestamps.items()
if current_time - timestamp > self._cache_ttl
]
for key in expired_keys:
del self._rerank_cache[key]
del self._cache_timestamps[key]
# Nếu cache vẫn quá lớn, xóa entries cũ nhất
if len(self._rerank_cache) > self._max_cache_size:
sorted_keys = sorted(
self._cache_timestamps.keys(),
key=lambda k: self._cache_timestamps[k]
)
# Xóa 20% cache entries cũ nhất
keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
for key in keys_to_remove:
del self._rerank_cache[key]
del self._cache_timestamps[key]
logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")
def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]:
"""Lấy kết quả từ cache nếu có và còn hợp lệ."""
if cache_key in self._rerank_cache:
current_time = time.time()
if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
cached_result = self._rerank_cache[cache_key][:top_k]
logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results")
return cached_result
else:
# Cache đã hết hạn, xóa
del self._rerank_cache[cache_key]
del self._cache_timestamps[cache_key]
return []
def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
"""Lưu kết quả vào cache."""
self._rerank_cache[cache_key] = scored_docs
self._cache_timestamps[cache_key] = time.time()
# Dọn dẹp cache nếu cần
if len(self._rerank_cache) > self._max_cache_size:
self._clean_cache()
async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
"""
Score nhiều documents cùng lúc bằng một prompt duy nhất.
Không cắt bớt nội dung luật.
"""
if not docs:
return []
# Không giới hạn content length, giữ nguyên nội dung luật
docs_content = []
for i, doc in enumerate(docs):
tieude = (doc.get('tieude') or '').strip()
noidung = (doc.get('noidung') or '').strip()
content = f"{tieude} {noidung}".strip()
docs_content.append(f"{i+1}. {content}")
batch_prompt = (
f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
f"Câu hỏi: {query}\n\n"
f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
f"Ví dụ: 8,5,7,3,9"
)
try:
if self.provider == 'gemini':
loop = asyncio.get_event_loop()
logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
# Cải thiện parsing scores
scores_text = str(response).strip()
scores = []
# Xử lý nhiều format response có thể có
if ',' in scores_text:
score_parts = scores_text.split(',')
elif ' ' in scores_text:
score_parts = scores_text.split()
else:
score_parts = scores_text.replace('.', ',').split(',')
for score_str in score_parts:
try:
clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.')
if clean_score:
score = float(clean_score)
score = max(0, min(10, score))
scores.append(score)
else:
scores.append(0)
except (ValueError, TypeError):
scores.append(0)
while len(scores) < len(docs):
scores.append(0)
for i, doc in enumerate(docs):
doc['rerank_score'] = scores[i]
logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
return docs
else:
raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.")
except Exception as e:
logger.error(f"[RERANK] Lỗi khi batch score: {e}")
for doc in docs:
doc['rerank_score'] = 0
return docs
async def _score_doc(self, query: str, doc: Dict) -> Dict:
"""
Score một document với query.
Không cắt bớt nội dung luật.
"""
tieude = (doc.get('tieude') or '').strip()
noidung = (doc.get('noidung') or '').strip()
content = f"{tieude} {noidung}".strip()
prompt = (
f"Đánh giá mức độ liên quan:\n"
f"Luật: {content}\n"
f"Hỏi: {query}\n"
f"Điểm (0-10):"
)
try:
if self.provider == 'gemini':
loop = asyncio.get_event_loop()
logger.info(f"[RERANK] Sending individual prompt to Gemini")
score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
score_text = str(score_response).strip()
try:
clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.')
if clean_score:
score = float(clean_score)
score = max(0, min(10, score))
else:
score = 0
except (ValueError, TypeError):
score = 0
doc['rerank_score'] = score
return doc
else:
raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
except Exception as e:
logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
doc['rerank_score'] = 0
return doc
async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
"""
Rerank docs theo độ liên quan với query, trả về top_k docs.
Sử dụng batch processing và caching để tối ưu hiệu suất.
"""
logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
if not docs:
return []
# Kiểm tra cache trước
cache_key = self._get_cache_key(query, docs)
cached_result = self._get_cached_result(cache_key, top_k)
if cached_result:
return cached_result
# Giới hạn số lượng docs để rerank - chỉ rerank top 15 docs có similarity cao nhất
max_docs_to_rerank = self.max_docs_to_rerank
docs_to_rerank = docs[:max_docs_to_rerank]
logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})")
# Sử dụng batch processing thay vì individual scoring
try:
scored = await self._batch_score_docs(query, docs_to_rerank)
logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs")
except Exception as e:
logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}")
# Fallback về individual scoring nếu batch processing thất bại
scored = []
for doc in docs_to_rerank:
try:
scored_doc = await self._score_doc(query, doc)
scored.append(scored_doc)
except Exception as e:
logger.error(f"[RERANK] Error scoring individual doc: {e}")
doc['rerank_score'] = 0
scored.append(doc)
# Sort theo score và trả về top_k
scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
result = scored[:top_k]
# Cache kết quả với system mới
self._set_cached_result(cache_key, scored)
logger.info(f"[RERANK] Top reranked docs: {result}")
return result