import re
from typing import List, Optional

from google.generativeai.client import configure
from google.generativeai.embedding import embed_content
from google.generativeai.generative_models import GenerativeModel
from loguru import logger

from .request_limit_manager import RequestLimitManager


class GeminiClient:
    def __init__(self):
        self.limit_manager = RequestLimitManager("gemini")
        self._cached_model = None
        self._cached_key = None
        self._cached_model_instance = None

    def _get_model_instance(self, key: str, model: str):
        """
        Cache the model instance so it is not recreated on every request.
        """
        if (self._cached_key == key and
                self._cached_model == model and
                self._cached_model_instance is not None):
            return self._cached_model_instance

        configure(api_key=key)
        self._cached_model_instance = GenerativeModel(model)
        self._cached_key = key
        self._cached_model = model

        logger.info(f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}")
        return self._cached_model_instance

    def _clear_cache_if_needed(self, new_key: str, new_model: str):
        """
        Only clear the cache when the key or model has actually changed.
        """
        if self._cached_key != new_key or self._cached_model != new_model:
            # Mask keys in the log, matching the masking used in _get_model_instance.
            old_key = f"{self._cached_key[:5]}...{self._cached_key[-5:]}" if self._cached_key else None
            logger.info(
                f"[GEMINI] Clearing cache due to key/model change: "
                f"key {old_key}->{new_key[:5]}...{new_key[-5:]}, model {self._cached_model}->{new_model}"
            )
            self._cached_model_instance = None
            self._cached_key = None
            self._cached_model = None

    def generate_text(self, prompt: str, **kwargs) -> str:
        last_error = None
        max_retries = 3

        for attempt in range(max_retries):
            key = model = None
            try:
                key, model = self.limit_manager.get_current_key_model()
                _model = self._get_model_instance(key, model)

                response = _model.generate_content(prompt, **kwargs)
                self.limit_manager.log_request(key, model, success=True)

                if hasattr(response, 'usage_metadata'):
                    logger.info(f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}")

                if hasattr(response, 'text'):
                    logger.info(f"[GEMINI][TEXT_RESPONSE] {response.text}")
                    return response.text
                elif hasattr(response, 'candidates') and response.candidates:
                    logger.info(f"[GEMINI][CANDIDATES_RESPONSE] {response.candidates[0].content.parts[0].text}")
                    return response.candidates[0].content.parts[0].text

                logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
                return str(response)

            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m:
                        retry_delay = int(m.group(1))

                    # Guard against key/model being unset if get_current_key_model() itself failed.
                    if key and model:
                        self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)

                    logger.warning(f"[GEMINI] Rate limit hit, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    logger.error(f"[GEMINI] Error generating text: {e}")
                    last_error = e
                    break

        raise last_error or RuntimeError("No available Gemini API key/model")

    def count_tokens(self, prompt: str) -> int:
        try:
            key, model = self.limit_manager.get_current_key_model()
            _model = self._get_model_instance(key, model)
            return _model.count_tokens(prompt).total_tokens
        except Exception as e:
            # Fall back to 0 rather than propagating the error to callers.
            logger.error(f"[GEMINI] Error counting tokens: {e}")
            return 0

    def create_embedding(self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query") -> List[float]:
        last_error = None
        max_retries = 3

        for attempt in range(max_retries):
            key = use_model = None
            try:
                key, default_model = self.limit_manager.get_current_key_model()

                # Prefer an explicitly requested model; otherwise use the manager's default.
                use_model = model if model and model.strip() else default_model
                if not use_model:
                    raise ValueError("No model specified for embedding")

                logger.info(f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}")

                configure(api_key=key)
                response = embed_content(
                    model=use_model,
                    content=text,
                    task_type=task_type
                )

                self.limit_manager.log_request(key, use_model, success=True)
                logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
                return response['embedding']

            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m_retry:
                        retry_delay = int(m_retry.group(1))

                    # Guard against key/use_model being unset if the failure happened before they were assigned.
                    if key and use_model:
                        self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)

                    logger.warning(f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    logger.error(f"[GEMINI] Error creating embedding: {e}")
                    last_error = e
                    break

        raise last_error or RuntimeError("No available Gemini API key/model")
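

# Illustrative usage sketch (not part of the client itself). It assumes the
# RequestLimitManager("gemini") instance can supply at least one valid API
# key/model pair; how those keys are configured is outside this module.
if __name__ == "__main__":
    client = GeminiClient()

    prompt = "Explain in one sentence why caching a model instance reduces overhead."

    # Text generation: rotates to a new key/model automatically on rate limits.
    print(client.generate_text(prompt))

    # Token counting for the same prompt (returns 0 if counting fails).
    print(client.count_tokens(prompt))

    # Embedding creation; task_type defaults to "retrieval_query".
    vector = client.create_embedding("semantic search query")
    print(f"embedding length: {len(vector)}")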