import random
from typing import Any, Dict, List

from loguru import logger


class QLearningAgent:
    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9,
                 exploration_rate: float = 0.1):
        """Initialize the Q-Learning Agent."""
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = {}  # Q-table: maps serialized state -> {action: Q-value}
        self.setup_logger()

    def setup_logger(self):
        """Configure logging for the agent."""
        logger.add("logs/q_learning_agent.log", rotation="500 MB")

    def get_q_value(self, state: Dict[str, Any], action: str) -> float:
        """Get the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}
        return self.q_table[state_key].get(action, 0.0)  # Default Q-value is 0.0

    def set_q_value(self, state: Dict[str, Any], action: str, value: float):
        """Set the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}
        self.q_table[state_key][action] = value

    def choose_action(self, state: Dict[str, Any], available_actions: List[str]) -> str:
        """Choose an action with an epsilon-greedy policy over the Q-table."""
        if random.random() < self.exploration_rate:
            # Explore: choose a random action
            return random.choice(available_actions)
        # Exploit: choose the action with the highest Q-value
        q_values = [self.get_q_value(state, action) for action in available_actions]
        max_q = max(q_values)
        # If multiple actions share the maximum Q-value, choose randomly among them
        best_actions = [action for action, q in zip(available_actions, q_values) if q == max_q]
        return random.choice(best_actions)

    def update_q_table(self, state: Dict[str, Any], action: str, reward: float,
                       next_state: Dict[str, Any], next_actions: List[str]):
        """Update the Q-table based on the observed reward and next state."""
        current_q = self.get_q_value(state, action)
        max_next_q = max(
            (self.get_q_value(next_state, next_action) for next_action in next_actions),
            default=0.0,
        )
        # Standard Q-learning update: Q(s, a) += lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))
        new_q = current_q + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q)
        self.set_q_value(state, action, new_q)
        logger.info(
            f"Q-table updated for state-action pair: ({self.serialize_state(state)}, {action})"
        )

    def serialize_state(self, state: Dict[str, Any]) -> str:
        """Serialize the state into a string representation for use as a dictionary key."""
        # Convert the state dictionary to a string representation
        return str(state)
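

# Minimal usage sketch. The toy environment below (a two-position "grid" with a
# reward for moving right) is hypothetical and only illustrates the call sequence:
# choose_action to pick an action, then update_q_table with the observed outcome.
if __name__ == "__main__":
    agent = QLearningAgent(learning_rate=0.1, discount_factor=0.9, exploration_rate=0.2)

    state = {"position": 0}          # states are plain dictionaries
    actions = ["left", "right"]      # actions are plain strings

    for episode in range(100):
        action = agent.choose_action(state, actions)
        # Assumed toy dynamics: "right" moves to position 1 and earns a reward of 1.0.
        next_state = {"position": 1} if action == "right" else {"position": 0}
        reward = 1.0 if action == "right" else 0.0
        agent.update_q_table(state, action, reward, next_state, actions)

    print(agent.q_table)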