"""
LLM Client Abstraction Layer
Supports multiple LLM providers without hardcoding
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
# Import configuration utility
from src.utils.config import get_gemini_api_key, get_openai_api_key


class BaseLLMClient(ABC):
    """Abstract base class for LLM clients"""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Generate response from prompt"""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if LLM service is available"""
        pass
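

# Any new provider plugs in by subclassing BaseLLMClient and implementing
# both abstract methods. A minimal sketch (hypothetical EchoClient, handy
# as an offline stand-in during tests):
#
#     class EchoClient(BaseLLMClient):
#         def generate(self, prompt: str, **kwargs) -> str:
#             return prompt  # echo the prompt back; no API call
#
#         def is_available(self) -> bool:
#             return True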


class GeminiClient(BaseLLMClient):
    """Google Gemini client implementation"""

    def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash"):
        self.api_key = api_key or get_gemini_api_key()
        self.model = model
        if not self.api_key:
            raise ValueError("Gemini API key not provided")
        try:
            import google.generativeai as genai

            genai.configure(api_key=self.api_key)
            self.client = genai.GenerativeModel(self.model)
            print(f"✅ Gemini client initialized with model: {self.model}")
        except ImportError:
            raise ImportError("google-generativeai package not installed")

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate response using Gemini"""
        try:
            # Default to a low temperature (0.1) for consistent output
            generation_config = {
                "temperature": kwargs.get("temperature", 0.1),
                "top_p": kwargs.get("top_p", 0.8),
                "top_k": kwargs.get("top_k", 40),
                "max_output_tokens": kwargs.get("max_output_tokens", 2048),
            }
            response = self.client.generate_content(
                prompt, generation_config=generation_config
            )
            return response.text
        except Exception as e:
            print(f"❌ Gemini generation error: {e}")
            raise

    def is_available(self) -> bool:
        """Check Gemini availability with a minimal live request"""
        try:
            test_response = self.client.generate_content("Hello")
            return bool(test_response.text)
        except Exception:
            return False


class OpenAIClient(BaseLLMClient):
    """OpenAI client implementation"""

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
        self.api_key = api_key or get_openai_api_key()
        self.model = model
        if not self.api_key:
            raise ValueError("OpenAI API key not provided")
        try:
            import openai

            self.client = openai.OpenAI(api_key=self.api_key)
            print(f"✅ OpenAI client initialized with model: {self.model}")
        except ImportError:
            raise ImportError("openai package not installed")

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate response using OpenAI"""
        try:
            # Default to a low temperature (0.1) for consistent output
            openai_kwargs = {
                "temperature": kwargs.get("temperature", 0.1),
                "top_p": kwargs.get("top_p", 1.0),
                "max_tokens": kwargs.get("max_tokens", 2048),
            }
            # Forward only OpenAI-supported parameters, silently dropping
            # Gemini-specific ones such as top_k and max_output_tokens
            openai_kwargs.update(
                {
                    k: v
                    for k, v in kwargs.items()
                    if k
                    in [
                        "temperature",
                        "top_p",
                        "max_tokens",
                        "frequency_penalty",
                        "presence_penalty",
                    ]
                }
            )
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                **openai_kwargs,
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"❌ OpenAI generation error: {e}")
            raise

    def is_available(self) -> bool:
        """Check OpenAI availability with a minimal live request"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5,
            )
            return bool(response.choices[0].message.content)
        except Exception:
            return False


class LLMClientFactory:
    """Factory for creating LLM clients"""

    SUPPORTED_PROVIDERS = {
        "gemini": GeminiClient,
        "openai": OpenAIClient,
    }

    @classmethod
    def create_client(cls, provider: str = "gemini", **kwargs) -> BaseLLMClient:
        """Create LLM client by provider name"""
        if provider not in cls.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported provider: {provider}. "
                f"Supported: {list(cls.SUPPORTED_PROVIDERS.keys())}"
            )
        client_class = cls.SUPPORTED_PROVIDERS[provider]
        return client_class(**kwargs)

    @classmethod
    def get_available_providers(cls) -> list:
        """Get list of available providers"""
        return list(cls.SUPPORTED_PROVIDERS.keys())
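

# Because SUPPORTED_PROVIDERS is a plain class-level dict, a third-party
# provider can also be registered at runtime without editing this module,
# e.g. (EchoClient is the hypothetical stub sketched above):
#
#     LLMClientFactory.SUPPORTED_PROVIDERS["echo"] = EchoClient
#     client = LLMClientFactory.create_client("echo")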

# Usage example
if __name__ == "__main__":
    # Test Gemini client
    try:
        client = LLMClientFactory.create_client("gemini")
        if client.is_available():
            response = client.generate("Hello, how are you?")
            print(f"Response: {response}")
        else:
            print("Gemini not available")
    except Exception as e:
        print(f"Error: {e}")