File size: 2,911 Bytes
a941b5a
 
 
18df1d9
 
 
 
 
 
a941b5a
 
 
 
18df1d9
 
 
 
a941b5a
18df1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968a78b
 
 
 
a941b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18df1d9
a941b5a
 
 
18df1d9
a941b5a
 
 
18df1d9
a941b5a
 
 
 
 
18df1d9
 
a941b5a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# custom_grok.py

from typing import Any, List, Optional, Dict
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
import requests

class GrokChatModel(BaseChatModel):
    """
    Wrapper customizado e robusto para o modelo GROK da xAI,
    com tratamento aprimorado de timeouts e erros de resposta.
    """
    model: str
    api_key: str
    base_url: str

    @property
    def _llm_type(self) -> str:
        return "grok-chat"

    def _default_headers(self):
        return {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
        }

    def _generate(
        self, messages: List[HumanMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any
    ) -> ChatResult:
        last_message = messages[-1].content

        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": last_message}],
            "temperature": 0.7
        }

         # Adiciona max_tokens ao payload se for fornecido
        if "max_tokens" in kwargs:
            payload["max_tokens"] = kwargs["max_tokens"]       

        try:
            response = requests.post(
                self.base_url,
                headers=self._default_headers(),
                json=payload,
                timeout=300  # Adiciona um timeout de 300 segundos (5 minutos)
            )
            # Lança um erro para status HTTP 4xx ou 5xx
            response.raise_for_status()

            result = response.json()
            
            # Validação robusta da resposta da API
            if not result.get("choices") or not isinstance(result["choices"], list) or len(result["choices"]) == 0:
                raise ValueError("Resposta da API do GROK inválida: campo 'choices' ausente ou vazio.")

            message = result["choices"][0].get("message", {})
            content = message.get("content")

            if not content or not content.strip():
                # Isso captura o caso de uma resposta bem-sucedida, mas com conteúdo vazio.
                raise ValueError("Resposta da API do GROK retornou conteúdo vazio.")

            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=content))]
            )

        except requests.exceptions.Timeout:
            raise ValueError("Erro na chamada da API da Grok: Tempo limite excedido (Timeout).")
        except requests.exceptions.RequestException as e:
            # Captura outros erros de conexão (DNS, rede, etc.)
            raise ValueError(f"Erro de conexão com a API da Grok: {e}")

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Retorna um dicionário para identificar o modelo."""
        return {"model": self.model, "base_url": self.base_url}