mehdinathani commited on
Commit
9c9f9b5
·
verified ·
1 Parent(s): 01cfb4b

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +39 -18
agents.py CHANGED
@@ -88,25 +88,46 @@ class LLMAgentBase(Player):
88
  class TemplateAgent(LLMAgentBase):
89
  def __init__(self, *args, **kwargs):
90
  super().__init__(*args, **kwargs)
91
- openai.api_key = os.getenv("OPENAI_API_KEY") # Store key in secret later
 
 
 
 
92
 
93
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  try:
95
- response = openai.ChatCompletion.create(
96
- model="gpt-3.5-turbo",
97
- messages=[
98
- {"role": "system", "content": "You are a Pokémon battle strategist."},
99
- {"role": "user", "content": f"{battle_state}\n\nChoose move using one of these functions: choose_move(move_name='...') or choose_switch(pokemon_name='...')"}
100
- ],
101
- functions=list(self.standard_tools.values()),
102
- function_call="auto"
103
- )
104
- choice = response["choices"][0]["message"]["function_call"]
105
- return {
106
- "decision": {
107
- "name": choice["name"],
108
- "arguments": eval(choice["arguments"]) # Simple for now
109
- }
110
- }
 
 
111
  except Exception as e:
112
- return {"error": str(e)}
 
88
  class TemplateAgent(LLMAgentBase):
89
  def __init__(self, *args, **kwargs):
90
  super().__init__(*args, **kwargs)
91
+ self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
92
+ self.headers = {
93
+ "Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
94
+ "Content-Type": "application/json"
95
+ }
96
 
97
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
98
+ prompt = (
99
+ "You are a Pokémon battle agent. Based on the battle state, "
100
+ "decide which move to use or which Pokémon to switch to.\n"
101
+ "Respond ONLY in JSON format like this:\n"
102
+ '{"decision": {"name": "choose_move", "arguments": {"move_name": "Flamethrower"}}}\n'
103
+ "OR\n"
104
+ '{"decision": {"name": "choose_switch", "arguments": {"pokemon_name": "Pikachu"}}}\n\n"
105
+ f"Battle state:\n{battle_state}"
106
+ )
107
+
108
+ payload = {
109
+ "inputs": prompt,
110
+ "parameters": {"temperature": 0.7, "max_new_tokens": 256}
111
+ }
112
+
113
  try:
114
+ import requests
115
+ response = requests.post(self.api_url, headers=self.headers, json=payload)
116
+ if response.status_code != 200:
117
+ return {"error": f"HTTP {response.status_code}: {response.text}"}
118
+
119
+ raw = response.json()
120
+ output = raw[0]["generated_text"]
121
+
122
+ # crude parse for now — better to use regex or json.loads w/ validation
123
+ if "choose_move" in output:
124
+ move_name = output.split('"move_name":')[1].split('"')[1]
125
+ return {"decision": {"name": "choose_move", "arguments": {"move_name": move_name}}}
126
+ elif "choose_switch" in output:
127
+ poke_name = output.split('"pokemon_name":')[1].split('"')[1]
128
+ return {"decision": {"name": "choose_switch", "arguments": {"pokemon_name": poke_name}}}
129
+ else:
130
+ return {"error": "Unrecognized response"}
131
+
132
  except Exception as e:
133
+ return {"error": f"Exception: {str(e)}"}