# mehdi-pokemon-agent / agents.py
# mehdinathani
# Update agents.py (9c9f9b5, verified)
import asyncio
import json
import logging
import os
import unicodedata
from typing import Dict, Any, Optional

import openai
from poke_env.environment import Battle, Move, Pokemon
from poke_env.player import Player
# Function-calling tool definitions (JSON-Schema "parameters") for the two
# actions an agent can take on its turn. Each tool takes exactly one required
# string argument, so the schema is generated from a small spec table:
# (tool name, tool description, argument name, argument description).
STANDARD_TOOL_SCHEMA = {
    tool_name: {
        "name": tool_name,
        "description": tool_description,
        "parameters": {
            "type": "object",
            "properties": {
                arg_name: {
                    "type": "string",
                    "description": arg_description,
                },
            },
            "required": [arg_name],
        },
    }
    for tool_name, tool_description, arg_name, arg_description in (
        (
            "choose_move",
            "Use this to choose a move during battle.",
            "move_name",
            "Name of the move to use.",
        ),
        (
            "choose_switch",
            "Use this to switch to another Pokemon.",
            "pokemon_name",
            "Name of the Pokémon to switch into.",
        ),
    )
}
def normalize_name(name):
    """Reduce a move or Pokémon name to a Showdown-style identifier.

    The text is lower-cased, accented letters are folded to their ASCII base
    letter, and every character that is not an ASCII letter or digit is
    dropped. LLM-produced names therefore compare equal to the ids they refer
    to, e.g. ``"U-turn"`` -> ``"uturn"``, ``"King's Shield"`` ->
    ``"kingsshield"``, ``"Mr. Mime"`` -> ``"mrmime"``, ``"Flabébé"`` ->
    ``"flabebe"``.

    Args:
        name: The name to normalize. ``None`` or an empty value yields ``""``;
            non-string values are converted with ``str()`` first.

    Returns:
        The normalized identifier string.
    """
    if not name:
        return ""
    # NFKD splits "é" into "e" + a combining accent; the accent is non-ASCII
    # and is filtered out below, leaving the plain base letter.
    decomposed = unicodedata.normalize("NFKD", str(name).lower())
    return "".join(ch for ch in decomposed if ch.isascii() and ch.isalnum())
class LLMAgentBase(Player):
    """Base class for poke_env players that delegate each turn to an LLM.

    Subclasses implement :meth:`_get_llm_decision`. This class renders the
    battle as text, maps the LLM's tool-call style answer back onto an order,
    and falls back to ``choose_random_move`` whenever the answer cannot be
    used.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Function-calling tool definitions, exposed for subclasses that talk
        # to APIs supporting tools.
        self.standard_tools = STANDARD_TOOL_SCHEMA

    def _format_battle_state(self, battle: Battle) -> str:
        """Render the parts of *battle* the LLM needs as a one-line summary.

        Includes our active Pokémon with its usable move ids, the opponent's
        active Pokémon, and the Pokémon we could switch into. The switch list
        lets the LLM make an informed ``choose_switch`` decision. Missing
        active Pokémon are shown as ``None`` instead of raising.
        """
        # (Simplified for now)
        active = battle.active_pokemon
        opponent = battle.opponent_active_pokemon
        moves = [m.id for m in battle.available_moves]
        switches = [p.species for p in battle.available_switches]
        return (
            f"Active Pokémon: {active.species if active is not None else None}, "
            f"Moves: {moves}, "
            f"Opponent Pokémon: {opponent.species if opponent is not None else None}, "
            f"Switches: {switches}"
        )

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Optional[Move]:
        """Return the available move whose normalized id matches *move_name*, else None."""
        target = normalize_name(move_name)
        for move in battle.available_moves:
            if normalize_name(move.id) == target:
                return move
        return None

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Optional[Pokemon]:
        """Return the switchable Pokémon whose normalized species matches *pokemon_name*, else None."""
        target = normalize_name(pokemon_name)
        for pkmn in battle.available_switches:
            if normalize_name(pkmn.species) == target:
                return pkmn
        return None

    async def choose_move(self, battle: Battle) -> str:
        """Pick this turn's order by consulting the LLM.

        Returns the order built by ``create_order`` for the move or switch the
        LLM named. It returns ``choose_random_move(battle)`` instead when the
        LLM reported an error, replied in an unexpected shape, or named an
        option that is not currently available.
        """
        # NOTE(review): the ``-> str`` annotation is kept for compatibility;
        # the value actually returned is whatever create_order /
        # choose_random_move return.
        logger = logging.getLogger(__name__)
        battle_state = self._format_battle_state(battle)
        decision_result = await self._get_llm_decision(battle_state)
        if not isinstance(decision_result, dict):
            decision_result = {}

        error = decision_result.get("error")
        if error:
            logger.warning(
                "LLM decision failed in %s, playing a random move: %s",
                battle.battle_tag,
                error,
            )

        # Defend against malformed replies: both levels must be dicts and the
        # argument values are coerced to strings before name matching.
        decision = decision_result.get("decision")
        if not isinstance(decision, dict):
            decision = {}
        fn = decision.get("name", "")
        args = decision.get("arguments")
        if not isinstance(args, dict):
            args = {}

        if fn == "choose_move":
            move = self._find_move_by_name(battle, str(args.get("move_name") or ""))
            if move is not None:
                return self.create_order(move)
        elif fn == "choose_switch":
            switch = self._find_pokemon_by_name(battle, str(args.get("pokemon_name") or ""))
            if switch is not None:
                return self.create_order(switch)

        if fn:
            logger.debug(
                "LLM chose an unavailable option %r in %s, playing a random move",
                decision,
                battle.battle_tag,
            )
        return self.choose_random_move(battle)

    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
        """Return ``{"decision": {"name": ..., "arguments": {...}}}`` or ``{"error": ...}``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError("This method should be implemented by your agent subclass.")
class TemplateAgent(LLMAgentBase):
    """LLM agent backed by Mixtral-8x7B-Instruct on the Hugging Face Inference API.

    The model is asked to answer with a single JSON object in the tool-call
    shape :class:`LLMAgentBase` understands. Any failure (network, HTTP,
    unexpected payload, unusable reply) is returned as ``{"error": ...}`` so
    the base class can fall back to a random move.
    """

    # Seconds to wait for the inference endpoint before giving up on a turn.
    request_timeout = 30.0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
        self.headers = {"Content-Type": "application/json"}
        token = os.getenv("HF_TOKEN")
        # Only authenticate when a token is configured, rather than sending
        # the literal header "Bearer None".
        if token:
            self.headers["Authorization"] = f"Bearer {token}"

    def _build_prompt(self, battle_state: str) -> str:
        """Build the instruction prompt wrapped in Mixtral's ``[INST] ... [/INST]`` format."""
        instructions = (
            "You are a Pokémon battle agent. Based on the battle state, "
            "decide which move to use or which Pokémon to switch to.\n"
            "Respond ONLY in JSON format like this:\n"
            '{"decision": {"name": "choose_move", "arguments": {"move_name": "Flamethrower"}}}\n'
            "OR\n"
            '{"decision": {"name": "choose_switch", "arguments": {"pokemon_name": "Pikachu"}}}\n\n'
            f"Battle state:\n{battle_state}"
        )
        return f"[INST] {instructions} [/INST]"

    @staticmethod
    def _coerce_decision(candidate: Any) -> Optional[Dict[str, Any]]:
        """Normalize one parsed JSON value into a decision dict, or return None.

        Accepts three shapes:

        * the full ``{"decision": {...}}`` envelope;
        * a bare ``{"name": ..., "arguments": {...}}`` object;
        * just the arguments object (``{"move_name": ...}`` or
          ``{"pokemon_name": ...}``).

        Replies whose outer braces were cut off can therefore still be used.
        A ``name`` that disagrees with the argument key is rejected.
        """
        if not isinstance(candidate, dict):
            return None
        if isinstance(candidate.get("decision"), dict):
            candidate = candidate["decision"]
        arguments = candidate.get("arguments", candidate)
        if not isinstance(arguments, dict):
            return None
        name = candidate.get("name")
        for fn, key in (("choose_move", "move_name"), ("choose_switch", "pokemon_name")):
            value = arguments.get(key)
            if isinstance(value, str) and value.strip() and name in (fn, None):
                return {"name": fn, "arguments": {key: value.strip()}}
        return None

    @classmethod
    def _parse_decision(cls, output: str) -> Dict[str, Any]:
        """Extract the first usable decision from the model's raw text.

        Every ``{`` in *output* is tried as the start of a JSON value, so
        surrounding prose or markdown fences are tolerated. Returns
        ``{"decision": ...}`` on success, else
        ``{"error": "Unrecognized response"}``.
        """
        decoder = json.JSONDecoder()
        start = output.find("{")
        while start != -1:
            try:
                candidate, end = decoder.raw_decode(output, start)
            except json.JSONDecodeError:
                next_from = start + 1
            else:
                decision = cls._coerce_decision(candidate)
                if decision is not None:
                    return {"decision": decision}
                # Skip the whole well-formed-but-unusable object instead of
                # reinterpreting its nested pieces.
                next_from = end
            start = output.find("{", next_from)
        return {"error": "Unrecognized response"}

    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
        """Ask the hosted model what to do this turn.

        Returns ``{"decision": {"name": ..., "arguments": {...}}}`` on success,
        otherwise ``{"error": <description>}``. It never raises.
        """
        prompt = self._build_prompt(battle_state)
        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": 0.7,
                "max_new_tokens": 256,
                # By default the endpoint echoes the prompt, and the JSON
                # examples in it would be parsed as the model's answer.
                "return_full_text": False,
            },
        }
        try:
            import requests

            # requests.post blocks; run it in the default executor so other
            # coroutines on this event loop keep running while we wait.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=self.request_timeout,
                ),
            )
        except Exception as e:
            # Top-level boundary: report instead of crashing the battle loop.
            return {"error": f"Exception: {str(e)}"}

        if response.status_code != 200:
            return {"error": f"HTTP {response.status_code}: {response.text}"}
        try:
            raw = response.json()
        except ValueError as e:
            return {"error": f"Exception: {str(e)}"}
        if not (
            isinstance(raw, list)
            and raw
            and isinstance(raw[0], dict)
            and isinstance(raw[0].get("generated_text"), str)
        ):
            return {"error": f"Unexpected response payload: {raw!r}"}

        output = raw[0]["generated_text"]
        # Safety net in case the endpoint ignored return_full_text.
        if output.startswith(prompt):
            output = output[len(prompt):]
        return self._parse_decision(output)