FBChatBot / app /embedding.py
VietCat's picture
fix metadata
aa8cb73
import asyncio
from typing import List, Optional

import httpx
import numpy as np
from loguru import logger

from .config import get_settings
from .gemini_client import GeminiClient
from .llm import LLMClient
from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry
class EmbeddingClient:
    """Turn text into embedding vectors and compare them.

    The provider is chosen from the ``embedding_provider`` setting:
    ``'gemini'`` uses ``GeminiClient``; any other value falls back to the
    HuggingFace Space endpoint in ``HF_EMBED_URL``.
    """

    # Endpoint used whenever the provider is not 'gemini'. Kept as a class
    # attribute so it can be overridden (e.g. in tests) without editing code.
    HF_EMBED_URL = "https://vietcat-vietnameseembeddingv2.hf.space/embed"

    def __init__(self):
        """
        Initialize the EmbeddingClient.

        Reads ``embedding_provider`` (default ``'default'``) and
        ``embedding_model`` (default ``'models/embedding-001'``) from the
        application settings. A ``GeminiClient`` is only created when the
        provider is ``'gemini'``.
        """
        self._client = httpx.AsyncClient()
        settings = get_settings()
        self.provider = getattr(settings, 'embedding_provider', 'default')
        self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
        self.gemini_client: Optional[GeminiClient] = GeminiClient() if self.provider == 'gemini' else None
        logger.info(f"[EMBEDDING] Initialized with provider={self.provider}, model={self.model}")

    @timing_decorator_async
    async def create_embedding(self, text: str, task_type: str = "retrieval_query") -> List[float]:
        """
        Create an embedding vector for ``text`` using the configured provider.

        Args:
            text: Input text to embed.
            task_type: Passed to Gemini as the embedding task type; ignored by
                the HuggingFace fallback.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            RuntimeError: If the provider returns nothing or an unexpected type.
            Exception: Any error raised by the underlying client is logged and
                re-raised.
        """
        if self.provider == 'gemini':
            return await self._create_gemini_embedding(text, task_type)
        return await self._create_hf_embedding(text)

    async def _create_gemini_embedding(self, text: str, task_type: str) -> List[float]:
        """Embed ``text`` with Gemini, running the synchronous client off the event loop."""
        if not self.gemini_client:
            raise RuntimeError("GeminiClient is not initialized")
        gemini_client = self.gemini_client
        try:
            # GeminiClient.create_embedding is synchronous; run it in the default
            # executor so it does not block the event loop. get_running_loop()
            # is the correct call from inside a coroutine.
            loop = asyncio.get_running_loop()
            # Always use the model from config, independent of any key/model
            # chosen by RequestLimitManager.
            logger.info(f"[EMBEDDING] Creating embedding with model={self.model}, task_type={task_type}")
            embedding = await loop.run_in_executor(
                None,
                lambda: gemini_client.create_embedding(text, model=self.model, task_type=task_type),
            )
            # Validate the returned type before handing it to callers.
            if not isinstance(embedding, list):
                logger.error(f"[EMBEDDING] Unknown embedding type: {type(embedding)} - value: {embedding}")
                raise RuntimeError(f"Embedding returned unexpected type: {type(embedding)}")
            preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
            logger.info(f"[EMBEDDING] API response: {preview}")
            return embedding
        except Exception as e:
            logger.error(f"[EMBEDDING] Error creating embedding with Gemini: {e}")
            raise

    async def _create_hf_embedding(self, text: str) -> List[float]:
        """Embed ``text`` via the HuggingFace Space endpoint (fallback provider)."""
        payload = {"text": text}
        try:
            response = await call_endpoint_with_retry(self._client, self.HF_EMBED_URL, payload)
            if response is None:
                logger.error("[EMBEDDING] HuggingFace API response is None")
                raise RuntimeError("HuggingFace API response is None")
            data = response.json()
            # Assumes the endpoint replies with {"embedding": [...]}; a missing
            # key surfaces as KeyError and is logged/re-raised below.
            embedding = data["embedding"]
            logger.info(f"[EMBEDDING] HuggingFace API response: {embedding[:10]}...{embedding[-10:]}")
            return embedding
        except Exception as e:
            logger.error(f"[EMBEDDING] Error creating embedding with HuggingFace: {e}")
            raise

    def cosine_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute the cosine similarity between two embeddings.

        Args:
            embedding1: First vector.
            embedding2: Second vector.

        Returns:
            The similarity as a float. Returns 0.0 when either vector has zero
            norm (including empty vectors), or when the computation fails
            (e.g. mismatched lengths), which is logged.
        """
        try:
            a = np.asarray(embedding1)
            b = np.asarray(embedding2)
            denominator = np.linalg.norm(a) * np.linalg.norm(b)
            if denominator == 0:
                # NumPy would return NaN (0/0) with only a warning rather than
                # raising, so guard explicitly and keep the 0.0 fallback.
                return 0.0
            return float(np.dot(a, b) / denominator)
        except Exception as e:
            logger.error(f"[EMBEDDING] Error calculating similarity: {e}")
            return 0.0

    def get_embedding_model(self) -> str:
        """
        Return the embedding model configured for this client.

        Useful for verifying which model is actually in use.
        """
        return self.model

    async def aclose(self) -> None:
        """Close the underlying ``httpx.AsyncClient`` used for the HuggingFace endpoint."""
        await self._client.aclose()