refactor llm/embedding flow
- app/config.py +7 -0
- app/embedding.py +28 -2
- app/llm.py +17 -34
- requirements.txt +3 -2
app/config.py
CHANGED

@@ -35,6 +35,13 @@ class Settings(BaseSettings):
     gemini_base_url: str = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1/models/gemini-2.5-flash:generateContent") or ""
     gemini_model: str = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") or ""
 
+    # LLM (chat/completion) provider/model
+    llm_provider: str = os.getenv("LLM_PROVIDER", "gemini") or ""
+    llm_model: str = os.getenv("LLM_MODEL", "gemini-1.5-flash-latest") or ""
+    # Embedding provider/model
+    embedding_provider: str = os.getenv("EMBEDDING_PROVIDER", "gemini") or ""
+    embedding_model: str = os.getenv("EMBEDDING_MODEL", "models/embedding-001") or ""
+
     class Config:
         env_file = ".env"
 
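The new settings are plain env-var lookups with hard-coded fallbacks. A minimal sketch of how they resolve (values shown are the defaults from this diff; set LLM_PROVIDER, LLM_MODEL, EMBEDDING_PROVIDER, or EMBEDDING_MODEL in .env or the environment to override):

from app.config import get_settings

settings = get_settings()
print(settings.llm_provider)        # "gemini" unless LLM_PROVIDER is set
print(settings.llm_model)           # "gemini-1.5-flash-latest" unless LLM_MODEL is set
print(settings.embedding_provider)  # "gemini" unless EMBEDDING_PROVIDER is set
print(settings.embedding_model)     # "models/embedding-001" unless EMBEDDING_MODEL is set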
app/embedding.py
CHANGED

@@ -2,8 +2,10 @@ from typing import List
 import numpy as np
 from loguru import logger
 import httpx
-
+from .config import get_settings
 from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry
+from .llm import LLMClient
+from .gemini_client import GeminiClient
 
 class EmbeddingClient:
     def __init__(self):
@@ -13,14 +15,38 @@ class EmbeddingClient:
         Output: EmbeddingClient instance.
         """
         self._client = httpx.AsyncClient()
+        settings = get_settings()
+        self.provider = getattr(settings, 'embedding_provider', 'default')
+        self.model = getattr(settings, 'embedding_model', 'models/embedding-001')
+        if self.provider == 'gemini':
+            self.gemini_client = GeminiClient(settings.gemini_api_key, model=self.model)
+        else:
+            self.gemini_client = None
 
     @timing_decorator_async
     async def create_embedding(self, text: str) -> List[float]:
         """
-        Create an embedding vector from text using an embedding service (e.g. OpenAI).
+        Create an embedding vector from text using an embedding service (e.g. OpenAI or Gemini).
         Input: text (str)
         Output: list[float] embedding vector.
         """
+        if self.provider == 'gemini' and self.gemini_client:
+            try:
+                # GeminiClient.create_embedding is a sync function; run it in an executor
+                import asyncio
+                loop = asyncio.get_event_loop()
+                embedding = await loop.run_in_executor(None, self.gemini_client.create_embedding, text)
+                # Check the returned data type
+                if isinstance(embedding, list):
+                    preview = f"{embedding[:10]}...{embedding[-10:]}" if len(embedding) > 20 else str(embedding)
+                    logger.info(f"[DEBUG] Embedding API response: {preview}")
+                    return embedding
+                else:
+                    logger.error(f"[DEBUG] Unknown embedding type: {type(embedding)} - value: {embedding}")
+                    raise RuntimeError(f"Embedding returned unexpected type: {type(embedding)}")
+            except Exception as e:
+                logger.error(f"Error creating embedding with Gemini: {e}")
+                raise
         url = "https://vietcat-vietnameseembeddingv2.hf.space/embed"
         payload = {"text": text}
         try:
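app/gemini_client.py itself is not part of this diff. A minimal sketch of what its sync create_embedding might look like on top of the google-generativeai package added in requirements.txt below (the constructor signature matches the call above; the body is an assumption, not the actual file):

import google.generativeai as genai

class GeminiClient:
    def __init__(self, api_key: str, model: str = "models/embedding-001"):
        genai.configure(api_key=api_key)
        self.model = model

    def create_embedding(self, text: str) -> list[float]:
        # embed_content returns a dict whose "embedding" key holds the vector
        result = genai.embed_content(model=self.model, content=text)
        return result["embedding"]

Because the SDK call blocks, the run_in_executor hand-off above keeps the async event loop free while the request is in flight.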
app/llm.py
CHANGED

@@ -4,6 +4,8 @@ import json
 from loguru import logger
 from tenacity import retry, stop_after_attempt, wait_exponential
 import os
+from .gemini_client import GeminiClient
+from .config import get_settings
 
 from .utils import timing_decorator_async, timing_decorator_sync, call_endpoint_with_retry
 
@@ -87,10 +89,8 @@ class LLMClient:
     def _setup_gemini(self, config: Dict[str, Any]):
         """Configure Gemini."""
         self.api_key = config.get("api_key", "")
-        self.
-        self.
-        self.max_tokens = config.get("max_tokens", 1024)
-        self.temperature = config.get("temperature", 0.7)
+        self.model = config.get("model", "gemini-1.5-flash-latest")
+        self.gemini_client = GeminiClient(self.api_key, self.model)
 
     @timing_decorator_async
     async def generate_text(
@@ -205,32 +205,9 @@ class LLMClient:
         raise RuntimeError("HFS API response is None")
 
     async def _generate_gemini(self, prompt: str, **kwargs) -> str:
-
-
-
-        if self.api_key:
-            headers["X-Goog-Api-Key"] = f"{self.api_key}"
-        # Gemini API expects {"contents": [{"parts": [{"text": prompt}]}]}
-        payload = {"contents": [{"parts": [{"text": prompt}]}]}
-        response = await call_endpoint_with_retry(self._client, url, payload, headers=headers)
-        if response is not None and hasattr(response, 'text'):
-            logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {response.text}")
-        else:
-            logger.info(f"[LLM][GEMINI][RAW_RESPONSE] {str(response)}")
-        if response is not None:
-            data = response.json()
-            # Log token usage if present
-            usage = data.get('usage') or data.get('usageMetadata')
-            if usage:
-                logger.info(f"[LLM][GEMINI][USAGE] {usage}")
-            # Gemini returns: {'candidates': [{'content': {'parts': [{'text': '...'}]}}]}
-            try:
-                return data['candidates'][0]['content']['parts'][0]['text']
-            except Exception:
-                return str(data)
-        else:
-            logger.error("Gemini API response is None")
-            raise RuntimeError("Gemini API response is None")
+        import asyncio
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.gemini_client.generate_text, prompt)
 
     @timing_decorator_async
     async def chat(
@@ -498,19 +475,25 @@ if __name__ == "__main__":
 
     async def test_llm():
         # Test with OpenAI
-
+        settings = get_settings()
+        llm_client = create_llm_client(
+            provider=settings.llm_provider,
+            model=settings.llm_model,
+            api_key=settings.gemini_api_key,
+            # ... other config as needed ...
+        )
 
         # Generate text
-        response = await
+        response = await llm_client.generate_text("Xin chào, bạn có khỏe không?")
         print(f"Response: {response}")
 
         # Chat
         messages = [
             {"role": "user", "content": "Bạn có thể giúp tôi không?"}
         ]
-        chat_response = await
+        chat_response = await llm_client.chat(messages)
         print(f"Chat response: {chat_response}")
 
-        await
+        await llm_client.close()
 
     asyncio.run(test_llm())
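Both new call sites hand the synchronous GeminiClient off to run_in_executor so the blocking SDK call does not stall the event loop; generate_text presumably wraps something like genai.GenerativeModel(self.model).generate_content(prompt).text. On Python 3.9+ the same hand-off can be written with asyncio.to_thread, which also avoids calling get_event_loop() inside a coroutine (soft-deprecated since 3.10). A sketch, not part of this diff:

import asyncio

async def _generate_gemini(self, prompt: str, **kwargs) -> str:
    # Offload the blocking Gemini call to a worker thread (Python 3.9+)
    return await asyncio.to_thread(self.gemini_client.generate_text, prompt)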
requirements.txt
CHANGED

@@ -2,7 +2,7 @@ fastapi==0.104.1
 uvicorn==0.24.0
 python-dotenv==1.0.0
 httpx>=0.24.0,<0.25.0
-loguru
+loguru>=0.7.0
 google-auth==2.23.4
 google-auth-oauthlib==1.1.0
 google-auth-httplib2==0.1.1
@@ -11,4 +11,5 @@ supabase==2.0.3
 numpy==1.26.2
 python-multipart==0.0.6
 tenacity==8.2.3
-pydantic-settings
+pydantic-settings
+google-generativeai>=0.3.0