File size: 3,529 Bytes
145e1cf
 
 
 
ac660b1
 
 
5804b96
145e1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
00f72c6
145e1cf
 
 
bf99d54
fe9b596
145e1cf
 
 
 
00f72c6
145e1cf
 
 
bf99d54
ac660b1
145e1cf
 
 
 
00f72c6
145e1cf
 
 
bf99d54
ac660b1
145e1cf
 
 
 
00f72c6
145e1cf
 
 
bf99d54
 
145e1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from smolagents import InferenceClientModel
import os

# Configuration des modèles
ORCHESTRATOR_MODEL = "Qwen/Qwen3-235B-A22B-Instruct-2507"
CODE_AGENT_MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct"
VISION_MODEL = "Qwen/Qwen2.5-VL-72B-Instruct"
REASONING_MODEL = "deepseek-ai/DeepSeek-R1-0528"

class ModelManager:
    """Gestionnaire centralisé des modèles"""
    
    def __init__(self, hf_token=None):
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self._models = {}
        self._initialize_models()
    
    def _initialize_models(self):
        """Initialise tous les modèles"""
        try:
            # Modèle orchestrateur principal
            self._models['orchestrator'] = InferenceClientModel(
                model_id=ORCHESTRATOR_MODEL,
                token=self.hf_token,
                max_tokens=4096,
                temperature=0.1,
                timeout=240,
                provider="together"
            )
            
            # Modèle pour le code
            self._models['code_agent'] = InferenceClientModel(
                model_id=CODE_AGENT_MODEL,
                token=self.hf_token,
                max_tokens=4096,
                temperature=0.0,
                timeout=240,
                provider="novita"
            )
            
            # Modèle de vision
            self._models['vision'] = InferenceClientModel(
                model_id=VISION_MODEL,
                token=self.hf_token,
                max_tokens=2048,
                temperature=0.1,
                timeout=240,
                provider="novita"
            )
            
            # Modèle de raisonnement
            self._models['reasoning'] = InferenceClientModel(
                model_id=REASONING_MODEL,
                token=self.hf_token,
                max_tokens=8192,
                temperature=0.2,
                timeout=240,
                provider="fireworks-ai"
            )
            
            print("✅ Tous les modèles ont été initialisés avec succès")
            
        except Exception as e:
            print(f"❌ Erreur lors de l'initialisation des modèles: {e}")
            raise
    
    def get_model(self, model_type: str):
        """Récupère un modèle spécifique"""
        if model_type not in self._models:
            raise ValueError(f"Type de modèle inconnu: {model_type}")
        return self._models[model_type]
    
    def get_orchestrator(self):
        """Récupère le modèle orchestrateur"""
        return self.get_model('orchestrator')
    
    def get_code_agent(self):
        """Récupère le modèle de code"""
        return self.get_model('code_agent')
    
    def get_vision_model(self):
        """Récupère le modèle de vision"""
        return self.get_model('vision')
    
    def get_reasoning_model(self):
        """Récupère le modèle de raisonnement"""
        return self.get_model('reasoning')
    
    def test_models(self):
        """Test rapide de tous les modèles"""
        results = {}
        test_prompt = "Hello, can you confirm you're working?"
        
        for model_name, model in self._models.items():
            try:
                response = model(test_prompt, max_tokens=50)
                results[model_name] = "✅ OK"
                print(f"✅ {model_name}: OK")
            except Exception as e:
                results[model_name] = f"❌ Error: {str(e)[:100]}"
                print(f"❌ {model_name}: {e}")
        
        return results