File size: 1,627 Bytes
18df1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Any, List, Optional

import requests
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

class GrokChatModel(BaseChatModel):
    """LangChain chat model backed by a Grok chat-completions HTTP endpoint.

    Each call POSTs a JSON payload (``model``, ``messages``, ``temperature``)
    to ``base_url`` and reads the reply from
    ``choices[0].message.content`` of the JSON response.

    Attributes are declared as model fields:
        model: Model identifier sent in the request payload.
        api_key: Credential sent in the ``x-api-key`` request header.
        base_url: Full URL of the chat-completions endpoint.
        temperature: Sampling temperature sent with every request.
        timeout: Seconds to wait for the HTTP response; ``None`` waits
            indefinitely.
    """

    model: str
    api_key: str
    base_url: str
    temperature: float = 0.7
    timeout: Optional[float] = 120.0

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this chat model implementation."""
        # Must be a property: a plain method would make `self._llm_type`
        # evaluate to a bound method instead of the string.
        return "grok-chat"

    def _default_headers(self):
        """Build the HTTP headers sent with every request."""
        # NOTE(review): xAI's public API documents `Authorization: Bearer`;
        # confirm that `x-api-key` is what the configured base_url expects.
        return {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
        }

    @staticmethod
    def _to_api_message(message: BaseMessage) -> dict:
        """Convert one LangChain message into a ``{"role", "content"}`` dict."""
        role_by_type = {"human": "user", "ai": "assistant", "system": "system"}
        message_type = getattr(message, "type", "human")
        if message_type == "chat":
            # Generic chat messages carry their own explicit role.
            role = getattr(message, "role", "user")
        else:
            # Unknown message kinds fall back to "user", which is how every
            # message was sent before roles were mapped.
            role = role_by_type.get(message_type, "user")
        return {"role": role, "content": message.content}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send the conversation to the API and wrap the reply.

        Args:
            messages: Conversation so far, oldest first. Every message is
                forwarded with its role so system prompts and earlier turns
                are not lost.
            stop: Optional stop sequences; forwarded as ``stop`` when given.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            A ``ChatResult`` holding a single ``ChatGeneration`` whose message
            is an ``AIMessage`` with the returned text.

        Raises:
            ValueError: If ``messages`` is empty, the API responds with a
                non-200 status, or the response body lacks the expected
                ``choices[0].message.content`` structure.
            requests.RequestException: On connection failures or when the
                request exceeds ``timeout``.
        """
        if not messages:
            raise ValueError("Nenhuma mensagem fornecida para a API da Grok.")

        payload = {
            "model": self.model,
            "messages": [self._to_api_message(message) for message in messages],
            "temperature": self.temperature,
        }
        if stop:
            payload["stop"] = stop

        # An explicit timeout keeps an unresponsive server from blocking
        # the caller forever.
        response = requests.post(
            self.base_url,
            headers=self._default_headers(),
            json=payload,
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise ValueError(f"Erro na chamada da API da Grok: {response.status_code} - {response.text}")

        try:
            content = response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # Covers invalid JSON as well as a body with an unexpected shape.
            raise ValueError(
                f"Resposta inesperada da API da Grok: {response.text}"
            ) from exc

        # Return a ChatResult that wraps a list of ChatGeneration objects.
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    @property
    def _identifying_params(self) -> dict:
        """Return the parameters that distinguish this model configuration."""
        return {
            "model": self.model,
            "base_url": self.base_url,
            "temperature": self.temperature,
        }