PPO playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
b18ddcc
```python
import gym
import numpy as np

from typing import Any, Dict, Tuple, Union

from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper

ObsType = Union[np.ndarray, dict]
ActType = Union[int, float, np.ndarray, dict]


class EpisodicLifeEnv(VecotarableWrapper):
    def __init__(self, env: gym.Env, training: bool = True, noop_act: int = 0) -> None:
        super().__init__(env)
        self.training = training
        self.noop_act = noop_act
        self.life_done_continue = False
        self.lives = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        self.life_done_continue = new_lives != self.lives and not done
        # Only if training should life-end be marked as done
        if self.training and 0 < new_lives < self.lives:
            done = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        # If life_done_continue (but not game over), then a reset should just allow the
        # game to progress to the next life.
        if self.training and self.life_done_continue:
            obs, _, _, _ = self.env.step(self.noop_act)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class FireOnLifeStarttEnv(VecotarableWrapper):
    def __init__(self, env: gym.Env, fire_act: int = 1) -> None:
        super().__init__(env)
        self.fire_act = fire_act
        action_meanings = env.unwrapped.get_action_meanings()
        assert action_meanings[fire_act] == "FIRE"
        assert len(action_meanings) >= 3
        self.lives = 0
        self.fire_on_next_step = True

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        if self.fire_on_next_step:
            action = self.fire_act
            self.fire_on_next_step = False
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        if 0 < new_lives < self.lives and not done:
            self.fire_on_next_step = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(self.fire_act)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        self.fire_on_next_step = False
        return obs


class ClipRewardEnv(VecotarableWrapper):
    def __init__(self, env: gym.Env, training: bool = True) -> None:
        super().__init__(env)
        self.training = training

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        obs, rew, done, info = self.env.step(action)
        if self.training:
            info["unclipped_reward"] = rew
            rew = np.sign(rew)
        return obs, rew, done, info
```
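For context, here is a minimal sketch of how wrappers like these are typically chained around a raw Atari environment. It is not code from the repo: the env id `BreakoutNoFrameskip-v4`, the random-action rollout, and the assumption that each wrapper can wrap a single `gym.Env` under the old gym API (`reset()` returning an observation, `step()` returning a 4-tuple, which is the API the wrapper code above is written against) are all illustrative assumptions. These ALE-based wrappers would not apply to MountainCar-v0 itself.

```python
import gym

# Hypothetical composition; training=True enables the training-only behaviors
# (life loss treated as episode end, reward sign-clipping).
env = gym.make("BreakoutNoFrameskip-v4")
env = EpisodicLifeEnv(env, training=True)   # losing a life looks terminal to the learner
env = FireOnLifeStarttEnv(env)              # presses FIRE so play starts on each new life
env = ClipRewardEnv(env, training=True)     # rewards clipped to {-1, 0, +1} via np.sign

obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()      # stand-in for the PPO policy's action
    obs, reward, done, info = env.step(action)
    # When training=True, info["unclipped_reward"] holds the original reward.
```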