Aasher commited on
Commit
6f91075
·
1 Parent(s): 59c556d

refactor: Update AxiomAgent to use Optional types and improve model configuration handling

Browse files
Files changed (1) hide show
  1. src/axiom/agent.py +27 -20
src/axiom/agent.py CHANGED
@@ -1,8 +1,11 @@
 
 
 
1
  from openai import AsyncOpenAI
2
  from openai.types.responses import ResponseTextDeltaEvent
3
 
4
  from agents import (
5
- Agent,
6
  OpenAIChatCompletionsModel,
7
  RunConfig,
8
  Runner,
@@ -13,41 +16,44 @@ from agents.mcp import MCPServer
13
  from .config import settings
14
  from .prompts import AXIOM_AGENT_PROMPT
15
 
16
- from typing_extensions import AsyncGenerator
17
-
18
  class AxiomAgent:
19
  def __init__(
20
  self,
21
- model: str | None = None,
22
- tools: list[Tool] | None = None,
23
- mcp_servers: list[MCPServer] | None = None,
24
  ):
25
  self._api_key = settings.GOOGLE_API_KEY
26
  self.base_url = settings.BASE_URL
27
- self.model = model if model else settings.DEFAULT_MODEL
 
 
 
 
 
28
 
29
  self.agent = Agent(
30
  name="Axiom 2.0",
31
  instructions=AXIOM_AGENT_PROMPT,
32
- mcp_servers=mcp_servers if mcp_servers is not None else [],
33
- tools=tools if tools is not None else [],
34
  )
35
 
36
- def _get_model_config(self):
37
 
38
- client = AsyncOpenAI(
39
- api_key=self._api_key,
40
- base_url=self.base_url,
 
41
  )
42
- model = OpenAIChatCompletionsModel(model=self.model, openai_client=client)
43
  return RunConfig(
44
- model=model,
45
- model_provider=client,
46
  tracing_disabled=True,
47
  )
48
 
49
  async def run_agent(self, input: str | list[dict[str, str]]):
50
- config = self._get_model_config()
51
 
52
  result = await Runner.run(
53
  starting_agent=self.agent,
@@ -57,7 +63,7 @@ class AxiomAgent:
57
  return result.final_output
58
 
59
  async def stream_agent(self, input: str | list[dict[str, str]]) -> AsyncGenerator:
60
- config = self._get_model_config()
61
 
62
  result = Runner.run_streamed(
63
  starting_agent=self.agent,
@@ -65,6 +71,7 @@ class AxiomAgent:
65
  run_config=config
66
  )
67
  async for event in result.stream_events():
68
- if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
69
- if token:= event.data.delta or "":
 
70
  yield token
 
1
+ from typing import Optional
2
+ from typing_extensions import AsyncGenerator
3
+
4
  from openai import AsyncOpenAI
5
  from openai.types.responses import ResponseTextDeltaEvent
6
 
7
  from agents import (
8
+ Agent,
9
  OpenAIChatCompletionsModel,
10
  RunConfig,
11
  Runner,
 
16
  from .config import settings
17
  from .prompts import AXIOM_AGENT_PROMPT
18
 
 
 
19
  class AxiomAgent:
20
  def __init__(
21
  self,
22
+ model: Optional[str] = None,
23
+ tools: Optional[list[Tool]] = None,
24
+ mcp_servers: Optional[list[MCPServer]] = None,
25
  ):
26
  self._api_key = settings.GOOGLE_API_KEY
27
  self.base_url = settings.BASE_URL
28
+ self.model_name = model or settings.DEFAULT_MODEL
29
+
30
+ self._client: AsyncOpenAI = AsyncOpenAI(
31
+ api_key=self._api_key,
32
+ base_url=self._base_url,
33
+ )
34
 
35
  self.agent = Agent(
36
  name="Axiom 2.0",
37
  instructions=AXIOM_AGENT_PROMPT,
38
+ mcp_servers=mcp_servers or [],
39
+ tools=tools or [],
40
  )
41
 
42
+ def _get_run_config(self) -> RunConfig:
43
 
44
+ # Create the specific model configuration
45
+ model_instance = OpenAIChatCompletionsModel(
46
+ model=self.model_name,
47
+ openai_client=self._client
48
  )
49
+
50
  return RunConfig(
51
+ model=model_instance,
 
52
  tracing_disabled=True,
53
  )
54
 
55
  async def run_agent(self, input: str | list[dict[str, str]]):
56
+ config = self._get_run_config()
57
 
58
  result = await Runner.run(
59
  starting_agent=self.agent,
 
63
  return result.final_output
64
 
65
  async def stream_agent(self, input: str | list[dict[str, str]]) -> AsyncGenerator:
66
+ config = self._get_run_config()
67
 
68
  result = Runner.run_streamed(
69
  starting_agent=self.agent,
 
71
  run_config=config
72
  )
73
  async for event in result.stream_events():
74
+ if (event.type == "raw_response_event" and
75
+ isinstance(event.data, ResponseTextDeltaEvent)):
76
+ if token := event.data.delta:
77
  yield token