# Spaces:
# Runtime error
# Runtime error
import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, Optional

import openai
from poke_env.environment import Battle, Move, Pokemon
from poke_env.player import Player
# Function-calling ("tool") schema describing the two actions an LLM agent may
# take: use a move, or switch to another Pokémon. Each entry follows the
# JSON-Schema-style {"name", "description", "parameters"} shape.
STANDARD_TOOL_SCHEMA = {
    "choose_move": {
        "name": "choose_move",
        "description": "Use this to choose a move during battle.",
        "parameters": {
            "type": "object",
            "properties": {
                "move_name": {
                    "type": "string",
                    "description": "Name of the move to use.",
                },
            },
            "required": ["move_name"],
        },
    },
    "choose_switch": {
        "name": "choose_switch",
        "description": "Use this to switch to another Pokemon.",
        "parameters": {
            "type": "object",
            "properties": {
                "pokemon_name": {
                    "type": "string",
                    "description": "Name of the Pokémon to switch into.",
                },
            },
            "required": ["pokemon_name"],
        },
    },
}
def normalize_name(name: Optional[str]) -> str:
    """Reduce a move or Pokémon name to a comparable id string.

    Lower-cases the name and drops every non-alphanumeric character (spaces,
    hyphens, apostrophes, periods, ...), so free-form LLM output such as
    ``"King's Shield"`` or ``"Mr. Mime"`` compares equal to Showdown-style ids
    like ``"kingsshield"`` / ``"mrmime"``. ``None`` maps to ``""``.
    """
    if name is None:
        return ""
    return "".join(ch for ch in str(name).lower() if ch.isalnum())
class LLMAgentBase(Player):
    """poke_env player that delegates each turn's decision to a language model.

    Subclasses implement :meth:`_get_llm_decision`, which receives a text
    summary of the battle and returns either
    ``{"decision": {"name": ..., "arguments": {...}}}`` or ``{"error": ...}``.
    Any reply that does not map to a currently available move or switch falls
    back to a random legal move.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tool/function-calling schema, exposed for subclasses that use tool APIs.
        self.standard_tools = STANDARD_TOOL_SCHEMA

    def _format_battle_state(self, battle: Battle) -> str:
        """Return a short text summary of the battle for the LLM prompt.

        Simplified for now: the active Pokémon, its available move ids, and the
        Pokémon available to switch in (so a ``choose_switch`` answer can name
        a real target).
        """
        active = battle.active_pokemon
        # Guard against a missing active Pokémon instead of raising AttributeError.
        active_species = active.species if active is not None else None
        moves = [move.id for move in battle.available_moves]
        switches = [pkmn.species for pkmn in battle.available_switches]
        return (
            f"Active Pokémon: {active_species}, Moves: {moves}, "
            f"Available switches: {switches}"
        )

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Optional[Move]:
        """Return the available move whose normalized id matches ``move_name``, else None."""
        normalized_name = normalize_name(move_name)
        for move in battle.available_moves:
            if normalize_name(move.id) == normalized_name:
                return move
        return None

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Optional[Pokemon]:
        """Return the switch-in whose normalized species matches ``pokemon_name``, else None."""
        normalized_name = normalize_name(pokemon_name)
        for pkmn in battle.available_switches:
            if normalize_name(pkmn.species) == normalized_name:
                return pkmn
        return None

    async def choose_move(self, battle: Battle) -> str:
        """Ask the LLM for an action and convert it into a battle order.

        Falls back to ``choose_random_move`` when the LLM reports an error,
        returns an unknown action, or names a move/Pokémon that is not
        currently available. Exceptions raised by ``_get_llm_decision``
        propagate unchanged.
        """
        logger = logging.getLogger(__name__)
        battle_state = self._format_battle_state(battle)
        decision_result = await self._get_llm_decision(battle_state)
        if not isinstance(decision_result, dict):
            decision_result = {}
        if "error" in decision_result:
            logger.warning("LLM decision failed: %s", decision_result["error"])

        # Tolerate malformed payloads: a non-dict decision/arguments means "no decision".
        decision = decision_result.get("decision")
        if not isinstance(decision, dict):
            decision = {}
        fn = decision.get("name", "")
        args = decision.get("arguments")
        if not isinstance(args, dict):
            args = {}

        if fn == "choose_move":
            move = self._find_move_by_name(battle, args.get("move_name", ""))
            if move is not None:
                return self.create_order(move)
        elif fn == "choose_switch":
            switch = self._find_pokemon_by_name(battle, args.get("pokemon_name", ""))
            if switch is not None:
                return self.create_order(switch)

        if fn:
            logger.warning(
                "LLM chose unavailable action %s(%s); playing a random move", fn, args
            )
        return self.choose_random_move(battle)

    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
        """Return the LLM's decision for ``battle_state``; must be overridden."""
        raise NotImplementedError("This method should be implemented by your agent subclass.")
class TemplateAgent(LLMAgentBase):
    """LLM agent backed by Mixtral-8x7B-Instruct on the Hugging Face Inference API.

    The API token is read from the ``HF_TOKEN`` environment variable when the
    agent is constructed.
    """

    # Seconds to wait for one inference request before giving up, so a stalled
    # HTTP call cannot hang the battle indefinitely.
    request_timeout = 30

    # Decision name -> the single argument key that decision must carry.
    _ARGUMENT_KEYS = {"choose_move": "move_name", "choose_switch": "pokemon_name"}

    # Lenient fallback for replies that are not valid JSON but still contain a
    # quoted "move_name": "..." or "pokemon_name": "..." pair.
    _FALLBACK_PATTERN = re.compile(r'"(move_name|pokemon_name)"\s*:\s*"([^"]+)"')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
        self.headers = {
            "Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
            "Content-Type": "application/json",
        }

    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
        """Query the model and turn its reply into a decision dict.

        Returns ``{"decision": {"name": ..., "arguments": {...}}}`` on success,
        or ``{"error": <message>}`` when the request fails or the reply cannot
        be parsed. Failures are reported rather than raised so the caller can
        fall back to a random move.
        """
        prompt = (
            # Mixtral-Instruct expects the instruction wrapped in [INST] ... [/INST];
            # the raw text-generation endpoint does not apply a chat template.
            "[INST] "
            "You are a Pokémon battle agent. Based on the battle state, "
            "decide which move to use or which Pokémon to switch to.\n"
            "Respond ONLY in JSON format like this:\n"
            '{"decision": {"name": "choose_move", "arguments": {"move_name": "Flamethrower"}}}\n'
            "OR\n"
            '{"decision": {"name": "choose_switch", "arguments": {"pokemon_name": "Pikachu"}}}\n\n'
            f"Battle state:\n{battle_state} [/INST]"
        )
        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": 0.7,
                "max_new_tokens": 256,
                # Without this, generated_text echoes the prompt, and the JSON
                # examples in the prompt would be parsed instead of the answer.
                "return_full_text": False,
            },
        }
        try:
            import requests

            loop = asyncio.get_running_loop()
            # requests is blocking; run it in the default thread pool so other
            # coroutines on this event loop (presumably including poke_env's
            # websocket handling) keep running during the request.
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=self.request_timeout,
                ),
            )
            if response.status_code != 200:
                return {"error": f"HTTP {response.status_code}: {response.text}"}
            output = response.json()[0]["generated_text"]
            # Defensive: drop an echoed prompt if the endpoint ignored return_full_text.
            if output.startswith(prompt):
                output = output[len(prompt):]
            return self._parse_decision(output)
        except Exception as e:
            # Deliberate best-effort boundary: any failure becomes an error dict.
            return {"error": f"Exception: {str(e)}"}

    @classmethod
    def _parse_decision(cls, text: str) -> Dict[str, Any]:
        """Extract the first well-formed decision from free-form model output.

        Tries each ``{`` in ``text`` as the start of a JSON object and accepts
        the first one shaped like
        ``{"decision": {"name": <known action>, "arguments": {<key>: <str>}}}``.
        If none is found, falls back to the first quoted ``"move_name"`` /
        ``"pokemon_name"`` pair. Returns ``{"error": "Unrecognized response"}``
        when neither succeeds.
        """
        decoder = json.JSONDecoder()
        for start, char in enumerate(text):
            if char != "{":
                continue
            try:
                obj, _ = decoder.raw_decode(text, start)
            except ValueError:
                continue
            decision = obj.get("decision") if isinstance(obj, dict) else None
            if not isinstance(decision, dict):
                continue
            name = decision.get("name")
            arguments = decision.get("arguments")
            arg_key = cls._ARGUMENT_KEYS.get(name)
            if arg_key is None or not isinstance(arguments, dict):
                continue
            value = arguments.get(arg_key)
            if isinstance(value, str) and value.strip():
                return {"decision": {"name": name, "arguments": {arg_key: value.strip()}}}

        match = cls._FALLBACK_PATTERN.search(text)
        if match:
            arg_key, value = match.groups()
            name = "choose_move" if arg_key == "move_name" else "choose_switch"
            return {"decision": {"name": name, "arguments": {arg_key: value}}}
        return {"error": "Unrecognized response"}