|
from typing import List, Optional |
|
import numpy as np |
|
from loguru import logger |
|
import httpx |
|
from .config import get_settings |
|
from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry |
|
from .llm import LLMClient |
|
from .gemini_client import GeminiClient |
|
|
|
class EmbeddingClient:
    """Client that produces text embeddings from either Gemini or a HuggingFace space."""

    def __init__(self):
        """
        Initialize the EmbeddingClient.

        Reads the embedding provider and model from settings. When the
        provider is 'gemini', a GeminiClient is constructed eagerly so a
        misconfiguration surfaces at startup rather than on the first request.
        """
        self._client = httpx.AsyncClient()
        settings = get_settings()
        # getattr with defaults keeps older settings objects (without these
        # fields) working.
        self.provider = getattr(settings, 'embedding_provider', 'default')
        self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
        self.gemini_client: Optional[GeminiClient] = GeminiClient() if self.provider == 'gemini' else None

        logger.info(f"[EMBEDDING] Initialized with provider={self.provider}, model={self.model}")

    @timing_decorator_async
    async def create_embedding(self, text: str, task_type: str = "retrieval_query") -> List[float]:
        """
        Create an embedding vector for *text*.

        Routes to Gemini when the configured provider is 'gemini';
        otherwise calls the HuggingFace space endpoint.

        Args:
            text: The text to embed.
            task_type: Gemini embedding task type (ignored by the
                HuggingFace backend).

        Returns:
            The embedding as a list of floats.

        Raises:
            RuntimeError: If the provider is misconfigured or the backend
                returns an unexpected/empty response.
        """
        if self.provider == 'gemini':
            return await self._create_gemini_embedding(text, task_type)
        return await self._create_hf_embedding(text)

    async def _create_gemini_embedding(self, text: str, task_type: str) -> List[float]:
        """Create an embedding via the synchronous Gemini client in a worker thread."""
        if not self.gemini_client:
            raise RuntimeError("GeminiClient is not initialized")
        try:
            # The Gemini client is synchronous; run it in the default executor
            # so the event loop is not blocked. get_running_loop() is the
            # correct call inside a coroutine (get_event_loop() is deprecated
            # in this context since Python 3.10).
            loop = asyncio.get_running_loop()
            gemini_client = self.gemini_client

            logger.info(f"[EMBEDDING] Creating embedding with model={self.model}, task_type={task_type}")
            embedding = await loop.run_in_executor(
                None,
                lambda: gemini_client.create_embedding(text, model=self.model, task_type=task_type),
            )

            if isinstance(embedding, list):
                # Log only a head/tail preview for long vectors.
                preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
                logger.info(f"[EMBEDDING] API response: {preview}")
                return embedding
            logger.error(f"[EMBEDDING] Unknown embedding type: {type(embedding)} - value: {embedding}")
            raise RuntimeError(f"Embedding returned unexpected type: {type(embedding)}")
        except Exception as e:
            logger.error(f"[EMBEDDING] Error creating embedding with Gemini: {e}")
            raise

    async def _create_hf_embedding(self, text: str) -> List[float]:
        """Create an embedding via the HuggingFace space endpoint."""
        url = "https://vietcat-vietnameseembeddingv2.hf.space/embed"
        payload = {"text": text}
        try:
            response = await call_endpoint_with_retry(self._client, url, payload)
            if response is None:
                logger.error("[EMBEDDING] HuggingFace API response is None")
                raise RuntimeError("HuggingFace API response is None")
            data = response.json()
            embedding = data.get("embedding")
            # Validate the payload before slicing it for the log; previously a
            # malformed response raised an opaque KeyError/TypeError here.
            if not isinstance(embedding, list):
                raise RuntimeError(f"HuggingFace API returned unexpected payload: {data!r}")
            logger.info(f"[EMBEDDING] HuggingFace API response: {embedding[:10]}...{embedding[-10:]}")
            return embedding
        except Exception as e:
            logger.error(f"[EMBEDDING] Error creating embedding with HuggingFace: {e}")
            raise

    def cosine_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute the cosine similarity between two embeddings.

        Args:
            embedding1: First vector.
            embedding2: Second vector.

        Returns:
            The cosine similarity, or 0.0 when either vector has zero norm
            (the similarity is undefined) or the computation fails —
            matching this method's best-effort contract.
        """
        try:
            a = np.asarray(embedding1, dtype=float)
            b = np.asarray(embedding2, dtype=float)
            denom = float(np.linalg.norm(a) * np.linalg.norm(b))
            if denom == 0.0:
                # Division by a zero norm previously yielded nan (numpy does
                # not raise here), silently poisoning downstream comparisons.
                return 0.0
            return float(np.dot(a, b) / denom)
        except Exception as e:
            logger.error(f"[EMBEDDING] Error calculating similarity: {e}")
            return 0.0

    def get_embedding_model(self) -> str:
        """
        Return the configured embedding model name.

        Used to verify that the intended model is in effect.
        """
        return self.model