PPO playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
3cc5c1d
| from typing import Callable | |
| from rl_algo_impls.shared.callbacks import Callback | |
| from rl_algo_impls.shared.policy.policy import Policy | |
| from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper | |
class SelfPlayCallback(Callback):
    """Callback that periodically hands a policy snapshot to a SelfPlayWrapper.

    A snapshot is taken once at construction time and again from ``on_step``
    whenever at least ``selfPlayWrapper.save_steps`` timesteps have elapsed
    since the previous snapshot.

    NOTE(review): ``self.timesteps_elapsed`` is read here but never assigned
    in this class; it is presumably maintained by ``Callback`` -- confirm
    against the base class.
    """

    def __init__(
        self,
        policy: Policy,
        policy_factory: Callable[[], Policy],
        selfPlayWrapper: SelfPlayWrapper,
    ) -> None:
        """Store the collaborators and take the initial checkpoint.

        Args:
            policy: Policy whose state is snapshotted on each checkpoint.
            policy_factory: Zero-argument callable producing a fresh Policy
                that the snapshot is loaded into.
            selfPlayWrapper: Wrapper that receives each snapshot and supplies
                the ``save_steps`` checkpoint interval.
        """
        super().__init__()
        self.policy = policy
        self.policy_factory = policy_factory
        self.selfPlayWrapper = selfPlayWrapper
        # Also initializes self.last_checkpoint_step, which on_step relies on.
        self.checkpoint_policy()

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Forward the step to the base class and checkpoint when one is due.

        Args:
            timesteps_elapsed: Step count passed through to ``Callback.on_step``.

        Returns:
            Always ``True``.
        """
        super().on_step(timesteps_elapsed)
        elapsed = self.timesteps_elapsed
        next_checkpoint_at = (
            self.last_checkpoint_step + self.selfPlayWrapper.save_steps
        )
        checkpoint_due = elapsed >= next_checkpoint_at
        if checkpoint_due:
            self.checkpoint_policy()
        return True

    def checkpoint_policy(self):
        """Give the wrapper a factory-built policy loaded from ``self.policy``.

        Also records the current ``timesteps_elapsed`` as the last checkpoint
        step so the next interval is measured from here.
        """
        # presumably load_from returns the freshly built policy carrying the
        # state of self.policy -- verify against Policy.load_from
        snapshot = self.policy_factory().load_from(self.policy)
        self.selfPlayWrapper.checkpoint_policy(snapshot)
        self.last_checkpoint_step = self.timesteps_elapsed