File size: 2,230 Bytes
685013e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import dspy
import litellm
import logging
from typing import List, Dict, Any, Optional

# Module-level logger used by SyncCustomGeminiDspyLM below.
logger = logging.getLogger(__name__)

class SyncCustomGeminiDspyLM(dspy.LM):
    """DSPy LM that forwards each call synchronously to ``litellm.completion``.

    The API key given at construction is passed explicitly on every request.
    Keyword arguments given at construction act as defaults, and keyword
    arguments given per call override them.

    Failures are not raised to the caller. A message-preparation or LiteLLM
    error is returned as a single-element list holding an ``"[ERROR: ...]"``
    string, so the return value always has the list-of-completions shape.
    """

    def __init__(self, model: str, api_key: str, **kwargs: Any):
        """Create the LM wrapper.

        Args:
            model: Model identifier, forwarded unchanged to ``litellm.completion``.
            api_key: API key passed to every ``litellm.completion`` call.
            **kwargs: Default completion parameters merged into every call;
                per-call keyword arguments take precedence.
        """
        super().__init__(model)
        self.model = model
        self.api_key = api_key
        # Overwrites any self.kwargs set by dspy.LM.__init__, so only the
        # caller-supplied kwargs are forwarded to LiteLLM.
        # NOTE(review): confirm that dropping base-class defaults is intended.
        self.kwargs = kwargs
        # NOTE(review): overwrites the base class's provider attribute with a
        # plain string label; confirm nothing in dspy expects a richer object.
        self.provider = "custom_sync_gemini_litellm"
        logger.info("SyncCustomGeminiDspyLM initialized for model: %s", self.model)

    def _prepare_litellm_messages(self, dspy_input: Any) -> List[Dict[str, Any]]:
        """Normalize a prompt string or a message list into LiteLLM chat messages.

        Args:
            dspy_input: Either a plain prompt string, sent as one user message,
                or a list of message dicts with ``role``/``content`` keys.

        Returns:
            A new list of ``{"role": ..., "content": ...}`` dicts. A ``"model"``
            role is renamed to ``"assistant"``. A missing or empty role becomes
            ``"user"``, and a missing ``content`` key becomes ``""``. Any other
            keys on the input messages are dropped.

        Raises:
            TypeError: If ``dspy_input`` is neither a ``str`` nor a ``list``.
            AttributeError: If a list item has no ``.get`` method (not dict-like).
        """
        if isinstance(dspy_input, str):
            return [{"role": "user", "content": dspy_input}]

        if isinstance(dspy_input, list):
            prepared = []
            for msg in dspy_input:
                # `or` also covers an explicit None/"" role, not only a missing key.
                role = msg.get("role") or "user"
                if role == "model":
                    # Map the 'model' role to LiteLLM's 'assistant' role for history.
                    role = "assistant"
                prepared.append({"role": role, "content": msg.get("content", "")})
            return prepared

        raise TypeError(f"Unsupported dspy_input type: {type(dspy_input)}")

    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run one synchronous completion request through LiteLLM.

        Args:
            prompt: Plain prompt text. Used when non-empty; otherwise
                ``messages`` is used.
            messages: Chat history as a list of role/content dicts.
            **kwargs: Per-call completion parameters; these override the
                constructor kwargs.

        Returns:
            The non-empty ``message.content`` of each returned choice, in order.
            On a preparation or LiteLLM failure, a single-element list holding
            an ``"[ERROR: ...]"`` string instead.

        Raises:
            ValueError: If both ``prompt`` and ``messages`` are empty or missing.
        """
        if not prompt and not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        # Select by truthiness so it agrees with the check above: an empty
        # prompt string must not shadow a non-empty messages list.
        dspy_input_content = prompt if prompt else messages

        try:
            messages_for_litellm = self._prepare_litellm_messages(dspy_input_content)
        except (TypeError, AttributeError) as e:
            # AttributeError: a list item was not dict-like (no .get).
            logger.error("Message preparation failed: %s", e)
            return [f"[ERROR: Message preparation error - {e}]"]

        # Constructor defaults first, then per-call overrides.
        final_call_kwargs = {**self.kwargs, **kwargs}

        try:
            response_obj = litellm.completion(
                model=self.model,
                messages=messages_for_litellm,
                api_key=self.api_key,
                **final_call_kwargs,
            )
            completions = [
                choice.message.content
                for choice in response_obj.choices
                if choice.message.content
            ]
        except Exception as e:
            # Deliberate best-effort boundary: report the failure as a
            # completion string instead of raising into the DSPy pipeline.
            logger.exception("LiteLLM call failed: %s", e)
            return [f"[ERROR: LiteLLM call failed - {e}]"]

        if not completions:
            # Every choice had empty/None content; surface it instead of
            # silently handing an empty list to the caller.
            logger.warning(
                "LiteLLM returned no non-empty completions for model: %s", self.model
            )
        return completions