Yaswanth123 committed
Commit 685013e · verified · 1 Parent(s): aa03f63

Create dspy_llm_wrapper.py

Files changed (1)
  1. dspy_llm_wrapper.py +54 -0
dspy_llm_wrapper.py ADDED
@@ -0,0 +1,54 @@
+ import dspy
+ import litellm
+ import logging
+ from typing import List, Dict, Any, Optional
+
+ logger = logging.getLogger(__name__)
+
+ class SyncCustomGeminiDspyLM(dspy.LM):
+     def __init__(self, model: str, api_key: str, **kwargs):
+         super().__init__(model)
+         self.model = model
+         self.api_key = api_key
+         self.kwargs = kwargs
+         self.provider = "custom_sync_gemini_litellm"
+         logger.info(f"SyncCustomGeminiDspyLM initialized for model: {self.model}")
+
+     def _prepare_litellm_messages(self, dspy_input: Any) -> List[Dict[str, str]]:
+         if isinstance(dspy_input, str):
+             return [{"role": "user", "content": dspy_input}]
+         elif isinstance(dspy_input, list):
+             # Convert DSPy's 'model' role to LiteLLM's 'assistant' role for history
+             return [
+                 {"role": "assistant" if msg.get("role") == "model" else msg.get("role", "user"), "content": msg.get("content", "")}
+                 for msg in dspy_input
+             ]
+         else:
+             raise TypeError(f"Unsupported dspy_input type: {type(dspy_input)}")
+
+     def __call__(self, prompt: Optional[str] = None, messages: Optional[List[Dict[str, str]]] = None, **kwargs):
+         if not prompt and not messages:
+             raise ValueError("Either 'prompt' or 'messages' must be provided.")
+
+         dspy_input_content = prompt if prompt is not None else messages
+
+         try:
+             messages_for_litellm = self._prepare_litellm_messages(dspy_input_content)
+         except TypeError as e:
+             return [f"[ERROR: Message preparation error - {e}]"]
+
+         final_call_kwargs = self.kwargs.copy()
+         final_call_kwargs.update(kwargs)
+
+         try:
+             response_obj = litellm.completion(
+                 model=self.model,
+                 messages=messages_for_litellm,
+                 api_key=self.api_key,
+                 **final_call_kwargs,
+             )
+             completions = [choice.message.content for choice in response_obj.choices if choice.message.content]
+             return completions
+         except Exception as e:
+             logger.error(f"LiteLLM call failed: {e}", exc_info=True)
+             return [f"[ERROR: LiteLLM call failed - {e}]"]