|
import logging
from collections.abc import Mapping
from typing import List, Dict, Any, Optional

import dspy
import litellm
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
class SyncCustomGeminiDspyLM(dspy.LM):
    """Synchronous DSPy language model that calls a model through LiteLLM.

    ``__call__`` is overridden to send each request straight to
    ``litellm.completion`` with the configured model name and API key.
    Failures are logged and reported back as a single ``"[ERROR: ...]"``
    string in the returned list rather than raised.
    """

    def __init__(self, model: str, api_key: str, **kwargs):
        """Create the LM wrapper.

        Args:
            model: Model identifier passed unchanged to ``litellm.completion``.
            api_key: API key forwarded on every completion request.
            **kwargs: Default request parameters merged into every call;
                per-call keyword arguments take precedence over these.
        """
        super().__init__(model)
        self.model = model
        self.api_key = api_key
        # Overwrites any ``kwargs`` attribute set by ``dspy.LM.__init__``: only
        # the constructor kwargs given here are used as per-call defaults.
        self.kwargs = kwargs
        self.provider = "custom_sync_gemini_litellm"
        logger.info("SyncCustomGeminiDspyLM initialized for model: %s", self.model)

    def _prepare_litellm_messages(self, dspy_input: Any) -> List[Dict[str, Any]]:
        """Convert a prompt string or message list into LiteLLM chat messages.

        Args:
            dspy_input: Either a plain prompt string, which becomes a single
                ``user`` message, or a list of message mappings with ``role``
                and ``content`` keys.

        Returns:
            A new list of ``{"role": ..., "content": ...}`` dicts. The role
            ``"model"`` is mapped to ``"assistant"``; a missing role defaults
            to ``"user"`` and missing content to ``""``. Any other keys on the
            input messages are dropped.

        Raises:
            TypeError: If ``dspy_input`` is neither a string nor a list, or if
                a list element is not a mapping.
        """
        if isinstance(dspy_input, str):
            return [{"role": "user", "content": dspy_input}]
        if not isinstance(dspy_input, list):
            raise TypeError(f"Unsupported dspy_input type: {type(dspy_input)}")

        prepared = []
        for index, msg in enumerate(dspy_input):
            # Reject non-mapping items up front so the caller receives a
            # TypeError (which it turns into an error result) instead of an
            # uncaught AttributeError from ``msg.get``.
            if not isinstance(msg, Mapping):
                raise TypeError(
                    f"Unsupported message type at index {index}: {type(msg)}"
                )
            role = msg.get("role", "user")
            prepared.append(
                {
                    "role": "assistant" if role == "model" else role,
                    "content": msg.get("content", ""),
                }
            )
        return prepared

    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> List[str]:
        """Run one chat completion and return the non-empty response texts.

        Args:
            prompt: Plain-text prompt. Used when non-empty, taking precedence
                over ``messages``.
            messages: Chat messages, used when ``prompt`` is empty or ``None``.
            **kwargs: Per-call request parameters, overriding the constructor
                defaults for this call only.

        Returns:
            The ``content`` of every returned choice whose content is
            non-empty (possibly an empty list). If message preparation or the
            LiteLLM call fails, a single-element list holding an
            ``"[ERROR: ...]"`` description is returned instead of raising.

        Raises:
            ValueError: If neither ``prompt`` nor ``messages`` is non-empty.
        """
        if not prompt and not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        # Select with the same truthiness test as the guard above, so an empty
        # prompt string falls through to ``messages`` instead of being sent as
        # a blank user message.
        dspy_input_content = prompt if prompt else messages

        try:
            messages_for_litellm = self._prepare_litellm_messages(dspy_input_content)
        except TypeError as e:
            logger.error("Message preparation failed: %s", e)
            return [f"[ERROR: Message preparation error - {e}]"]

        final_call_kwargs = {**self.kwargs, **kwargs}

        try:
            response_obj = litellm.completion(
                model=self.model,
                messages=messages_for_litellm,
                api_key=self.api_key,
                **final_call_kwargs,
            )
            completions = [
                choice.message.content
                for choice in response_obj.choices
                if choice.message.content
            ]
        except Exception as e:
            # Deliberate best-effort boundary: the failure is reported in-band
            # as an error string rather than propagated to the caller.
            logger.error("LiteLLM call failed: %s", e, exc_info=True)
            return [f"[ERROR: LiteLLM call failed - {e}]"]

        if not completions:
            logger.warning(
                "LiteLLM returned no non-empty completions for model: %s",
                self.model,
            )
        return completions