VietCat committed
Commit dfd32d8 · 1 Parent(s): 523c69e

refactor llm/embedding flow

Files changed (4)
  1. app/config.py +7 -0
  2. app/embedding.py +28 -2
  3. app/llm.py +17 -34
  4. requirements.txt +3 -2
app/config.py CHANGED
@@ -35,6 +35,13 @@ class Settings(BaseSettings):
     gemini_base_url: str = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1/models/gemini-2.5-flash:generateContent") or ""
     gemini_model: str = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") or ""

+    # LLM (chat/completion) provider/model
+    llm_provider: str = os.getenv("LLM_PROVIDER", "gemini") or ""
+    llm_model: str = os.getenv("LLM_MODEL", "gemini-1.5-flash-latest") or ""
+    # Embedding provider/model
+    embedding_provider: str = os.getenv("EMBEDDING_PROVIDER", "gemini") or ""
+    embedding_model: str = os.getenv("EMBEDDING_MODEL", "models/embedding-001") or ""
+
     class Config:
         env_file = ".env"

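The new fields are read from the environment via os.getenv, so they can be set in the .env file that Config.env_file points to. A minimal sketch (variable names taken from the getenv calls above; GEMINI_API_KEY is an assumption inferred from the settings.gemini_api_key reference elsewhere in this commit):

# .env — sketch, values are placeholders
LLM_PROVIDER=gemini
LLM_MODEL=gemini-1.5-flash-latest
EMBEDDING_PROVIDER=gemini
EMBEDDING_MODEL=models/embedding-001
GEMINI_API_KEY=<your-api-key>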
app/embedding.py CHANGED
@@ -2,8 +2,10 @@ from typing import List
 import numpy as np
 from loguru import logger
 import httpx
-
+from .config import get_settings
 from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry
+from .llm import LLMClient
+from .gemini_client import GeminiClient

 class EmbeddingClient:
     def __init__(self):
@@ -13,14 +15,38 @@ class EmbeddingClient:
         Output: EmbeddingClient instance.
         """
         self._client = httpx.AsyncClient()
+        settings = get_settings()
+        self.provider = getattr(settings, 'embedding_provider', 'default')
+        self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
+        if self.provider == 'gemini':
+            self.gemini_client = GeminiClient(settings.gemini_api_key, model=self.model)
+        else:
+            self.gemini_client = None

     @timing_decorator_async
     async def create_embedding(self, text: str) -> List[float]:
         """
-        Create an embedding vector from text using an embedding service (e.g. OpenAI).
+        Create an embedding vector from text using an embedding service (e.g. OpenAI or Gemini).
         Input: text (str)
         Output: list[float] embedding vector.
         """
+        if self.provider == 'gemini' and self.gemini_client:
+            try:
+                # GeminiClient.create_embedding is synchronous; run it in an executor
+                import asyncio
+                loop = asyncio.get_event_loop()
+                embedding = await loop.run_in_executor(None, self.gemini_client.create_embedding, text)
+                # Check the type of the returned value
+                if isinstance(embedding, list):
+                    preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
+                    logger.info(f"[DEBUG] Embedding API response: {preview}")
+                    return embedding
+                else:
+                    logger.error(f"[DEBUG] Unknown embedding type: {type(embedding)} - value: {embedding}")
+                    raise RuntimeError(f"Embedding returned unexpected type: {type(embedding)}")
+            except Exception as e:
+                logger.error(f"Error creating embedding with Gemini: {e}")
+                raise
         url = "https://vietcat-vietnameseembeddingv2.hf.space/embed"
         payload = {"text": text}
         try:
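EmbeddingClient now delegates to GeminiClient from app/gemini_client.py, which is not included in this diff. A minimal sketch of the embedding side, assuming it wraps the google-generativeai package added to requirements.txt below (the genai.embed_content call is an assumption, not code from this commit):

# app/gemini_client.py — hypothetical sketch, not part of this diff
import google.generativeai as genai

class GeminiClient:
    def __init__(self, api_key: str, model: str = "models/embedding-001"):
        genai.configure(api_key=api_key)  # authenticate once per client
        self.model = model

    def create_embedding(self, text: str) -> list:
        # Blocking SDK call; EmbeddingClient.create_embedding runs it in an executor
        result = genai.embed_content(model=self.model, content=text)
        return result["embedding"]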
app/llm.py CHANGED
@@ -4,6 +4,8 @@ import json
 from loguru import logger
 from tenacity import retry, stop_after_attempt, wait_exponential
 import os
+from .gemini_client import GeminiClient
+from .config import get_settings

 from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry

@@ -87,10 +89,8 @@ class LLMClient:
     def _setup_gemini(self, config: Dict[str, Any]):
         """Configure Gemini."""
         self.api_key = config.get("api_key", "")
-        self.base_url = config.get("base_url", "")
-        self.model = config.get("model", "")
-        self.max_tokens = config.get("max_tokens", 1024)
-        self.temperature = config.get("temperature", 0.7)
+        self.model = config.get("model", "gemini-1.5-flash-latest")
+        self.gemini_client = GeminiClient(self.api_key, self.model)

     @timing_decorator_async
     async def generate_text(
@@ -205,32 +205,9 @@ class LLMClient:
             raise RuntimeError("HFS API response is None")

     async def _generate_gemini(self, prompt: str, **kwargs) -> str:
-        """Call the Gemini API to generate text from a prompt."""
-        url = self.base_url
-        headers = {"Content-Type": "application/json"}
-        if self.api_key:
-            headers["X-Goog-Api-Key"] = f"{self.api_key}"
-        # Gemini API expects {"contents": [{"parts": [{"text": prompt}]}]}
-        payload = {"contents": [{"parts": [{"text": prompt}]}]}
-        response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
-        if response is not None and hasattr(response, 'text'):
-            logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {response.text}")
-        else:
-            logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {str(response)}")
-        if response is not None:
-            data = response.json()
-            # Log token usage if present
-            usage = data.get('usage') or data.get('usageMetadata')
-            if usage:
-                logger.info(f"[LLM][GEMINI][USAGE] {usage}")
-            # Gemini returns: {'candidates': [{'content': {'parts': [{'text': '...'}]}}]}
-            try:
-                return data['candidates'][0]['content']['parts'][0]['text']
-            except Exception:
-                return str(data)
-        else:
-            logger.error("Gemini API response is None")
-            raise RuntimeError("Gemini API response is None")
+        import asyncio
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.gemini_client.generate_text, prompt)

     @timing_decorator_async
     async def chat(
@@ -498,19 +475,25 @@ if __name__ == "__main__":

     async def test_llm():
         # Test with OpenAI
-        llm = create_llm_client("openai", model="gpt-3.5-turbo")
+        settings = get_settings()
+        llm_client = create_llm_client(
+            provider=settings.llm_provider,
+            model=settings.llm_model,
+            api_key=settings.gemini_api_key,
+            # ... other config if needed ...
+        )

         # Generate text
-        response = await llm.generate_text("Xin chào, bạn có khỏe không?")
+        response = await llm_client.generate_text("Xin chào, bạn có khỏe không?")
         print(f"Response: {response}")

         # Chat
         messages = [
             {"role": "user", "content": "Bạn có thể giúp tôi không?"}
         ]
-        chat_response = await llm.chat(messages)
+        chat_response = await llm_client.chat(messages)
         print(f"Chat response: {chat_response}")

-        await llm.close()
+        await llm_client.close()

     asyncio.run(test_llm())
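Both call sites push the synchronous GeminiClient methods through loop.run_in_executor, which keeps the event loop free while the SDK blocks on network I/O. A self-contained illustration of the pattern, with blocking_call as a hypothetical stand-in for GeminiClient.generate_text:

import asyncio

def blocking_call(prompt: str) -> str:
    # Stand-in for a blocking SDK call
    return prompt.upper()

async def main():
    # get_running_loop() is the modern spelling; the diff uses get_event_loop()
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, blocking_call, "xin chào")
    print(result)  # XIN CHÀO

asyncio.run(main())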
requirements.txt CHANGED
@@ -2,7 +2,7 @@ fastapi==0.104.1
 uvicorn==0.24.0
 python-dotenv==1.0.0
 httpx>=0.24.0,<0.25.0
-loguru==0.7.2
+loguru>=0.7.0
 google-auth==2.23.4
 google-auth-oauthlib==1.1.0
 google-auth-httplib2==0.1.1
@@ -11,4 +11,5 @@ supabase==2.0.3
 numpy==1.26.2
 python-multipart==0.0.6
 tenacity==8.2.3
-pydantic-settings
+pydantic-settings
+google-generativeai>=0.3.0