# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
import typing as tp

import flashy
import omegaconf
import torch
from torch import nn

from .. import optim
from ..optim import fsdp
from ..utils import checkpoint
from ..utils.autocast import TorchAutocast
from ..utils.best_state import BestStateDictManager
from ..utils.deadlock import DeadlockDetect
from ..utils.profiler import Profiler
from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng


class StandardSolver(ABC, flashy.BaseSolver):
    """Standard solver for AudioCraft.

    The standard solver implements a base training loop with the following stages:
    train, valid, evaluate and generate, which are all expected to be defined for
    solvers in AudioCraft. It also provides a convenient default management of Dora history replay,
    checkpoint management across epochs, and logging configuration.

    AudioCraft solvers must inherit from StandardSolver and define the methods
    associated with each stage as well as the `show`, `build_model` and `build_dataloaders` methods.
    """
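
    # Illustrative only (not part of the upstream API): a hypothetical minimal subclass
    # could wire the required hooks roughly as follows, assuming `my_train_loader` and
    # `my_valid_loader` exist in the user's project:
    #
    #     class MyToySolver(StandardSolver):
    #         def build_model(self):
    #             self.model = torch.nn.Linear(16, 16).to(self.device)
    #             self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
    #             self.register_stateful('model', 'optimizer')
    #
    #         def build_dataloaders(self):
    #             self.dataloaders = {'train': my_train_loader, 'valid': my_valid_loader}
    #
    #         def show(self):
    #             self.logger.info(self.model)
    #
    #         def run_step(self, idx, batch, metrics):
    #             ...  # forward/backward pass, return a dict of metrics
    #
    #         def evaluate(self):
    #             return {}
    #
    #         def generate(self):
    #             return {}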

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__()
        self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
        self.cfg = cfg
        self.device = cfg.device
        self.model: nn.Module
        self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
        self._fsdp_modules: tp.List[fsdp.FSDP] = []
        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
        self.ema: tp.Optional[optim.ModuleDictEMA] = None
        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
        self._log_updates = self.cfg.logging.get('log_updates', 10)
        if self.cfg.logging.log_tensorboard:
            self.init_tensorboard(**self.cfg.get('tensorboard'))
        if self.cfg.logging.log_wandb and self:
            self.init_wandb(**self.cfg.get('wandb'))
        # keep a copy of the best performing state for stateful objects
        # used for evaluation and generation stages
        dtype_best: tp.Optional[torch.dtype] = None
        if self.cfg.fsdp.use:
            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        elif self.cfg.autocast:
            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
        # Hacky support for keeping a copy of the full best state in rank0.
        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
        self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
        self._new_best_state: bool = False  # should save a new checkpoint
        # instantiate datasets and appropriate number of updates per epoch
        self.build_dataloaders()
        if self.cfg.execute_only is None:
            assert 'train' in self.dataloaders, "The train dataset split must be provided."
            assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
        self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
        if self.cfg.optim.updates_per_epoch:
            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
        # instantiate model & exponential moving average on the model
        self.build_model()
        self.logger.info("Model hash: %s", model_hash(self.model))
        assert 'model' in self.stateful.sources, \
            "Please register the model to stateful with self.register_stateful('model') in build_model."
        self.profiler = Profiler(self.model, **self.cfg.profiler)
        self.initialize_ema()
        self.register_stateful('ema')
        assert self.ema is None or 'ema' in self.stateful.sources, \
            "Please register the ema to stateful with self.register_stateful('ema') in build_model."
        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
        # basic statistics on the trained model
        model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
        # one copy of grad, one copy of momentum, one copy of denominator and model weights.
        # and 4 bytes for each float!
        mem_usage = model_size * 4 * 4 / 1000
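        # Rough back-of-the-envelope example only: a 300 M parameter model gives
        # 300 * 4 bytes * 4 copies / 1000 = 4.8 GB of baseline memory for weights,
        # gradients and optimizer statistics, ignoring activations.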
        self.logger.info("Model size: %.2f M params", model_size)
        self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)

    @property
    def autocast(self):
        """Convenient autocast (or not) using the solver configuration."""
        return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
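
    # Typical use inside a child solver's `run_step`, shown only as a hypothetical sketch:
    #
    #     with self.autocast:
    #         loss = self.model(batch).mean()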

    def _get_state_source(self, name) -> flashy.state.StateDictSource:
        # Internal utility to get a state source from the solver
        return self.stateful.sources[name]

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Metric name used to identify the best state. This metric should be stored in the metrics
        used on the stage for best state identification (most likely, `valid`). If None, then
        no best state is saved.
        """
        return None
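
    # For instance, a subclass tracking cross-entropy on the valid stage could
    # (hypothetically) override this as:
    #
    #     @property
    #     def best_metric_name(self) -> tp.Optional[str]:
    #         return 'ce'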

    def register_best_state(self, *args: str):
        """Register state sources in `BestStateDictManager` to keep their best states along with their
        latest states. The best state will be used at evaluation stages instead of the latest states.

        Shortcut around the `BestStateDictManager.register` method. You can pass any number of
        attributes, including nested attributes, and those will be included into the checkpoints
        and automatically restored when `BaseSolver.restore` is called.
        """
        for name in args:
            state_source = self._get_state_source(name)
            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
            self.best_state.register(name, state_source)
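
    # Typically called from a subclass's `build_model`, e.g. (illustrative only):
    #
    #     self.register_stateful('model', 'optimizer')
    #     self.register_best_state('model')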

    def register_ema(self, *args: str):
        """Register state sources for exponential moving average.

        The registered sources are used to instantiate a ModuleDictEMA instance.
        The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
        and swapped with the original state sources with the self.swap_ema_state() method.

        Usage:
            self.register_ema('model')
        """
        assert self.ema is None, "Cannot register state source to already instantiated EMA."
        for name in args:
            self._ema_sources[name] = getattr(self, name)

    def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
        model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
        if isinstance(model, fsdp.FSDP):
            self._fsdp_modules.append(model)
        return model
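
    # A subclass would typically wrap its model in `build_model` when `fsdp.use` is enabled,
    # e.g. (illustrative sketch, not prescribed by the base class):
    #
    #     if self.cfg.fsdp.use:
    #         self.model = self.wrap_with_fsdp(self.model)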

    def update_best_state_from_stage(self, stage_name: str = 'valid'):
        """Update latest best state based on pending metrics of a given stage. This method relies
        on the `BestStateDictManager.update` method to update the best state_dict with latest weights
        if the registered states happen to match the best performing setup.
        """
        if self.best_metric_name is None:
            # when no best metric is defined, the last state is always the best
            self._new_best_state = True
            self.logger.info("Updating best state with current state.")
        else:
            assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
            assert self.best_metric_name in self._pending_metrics[stage_name], \
                f"Best metric not found in {stage_name} metrics. Cannot register best state."
            current_score = self._pending_metrics[stage_name][self.best_metric_name]
            all_best_metric_scores = [
                past_metrics[stage_name][self.best_metric_name]
                for past_metrics in self.history
            ]
            all_best_metric_scores.append(current_score)
            best_score = min(all_best_metric_scores)
            self._new_best_state = current_score == best_score
            if self._new_best_state:
                old_best = min(all_best_metric_scores[:-1] + [float('inf')])
                self.logger.info(
                    f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")

        if self._new_best_state:
            if self.cfg.fsdp.use:
                # this will give an empty state dict on all ranks but rank 0,
                # which will have a copy in memory of the full model.
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    for name in self.best_state.states.keys():
                        state_source = self._get_state_source(name)
                        self.best_state.update(name, state_source)
                    # we save to a different dict.
                    self.fsdp_best_state.update(self.best_state.state_dict())
                # We cannot efficiently load fsdp_best_state when using FSDP,
                # so we have to do a second pass, with the local shards.
            for name in self.best_state.states.keys():
                state_source = self._get_state_source(name)
                self.best_state.update(name, state_source)

    def _load_new_state_dict(self, state_dict: dict) -> dict:
        old_states = {}
        for name, new_state in state_dict.items():
            state_source = self._get_state_source(name)
            old_states[name] = copy_state(state_source.state_dict())
            state_source.load_state_dict(new_state)
        return old_states

    @contextmanager
    def swap_best_state(self):
        self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
        old_states = self._load_new_state_dict(self.best_state.state_dict())
        try:
            yield
        finally:
            self.logger.debug("Swapping back from best to original state")
            for name, old_state in old_states.items():
                state_source = self._get_state_source(name)
                state_source.load_state_dict(old_state)

    @contextmanager
    def swap_ema_state(self):
        if self.ema is None:
            yield
        else:
            ema_state_dict = self.ema.state_dict()['state']
            self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
            old_states = self._load_new_state_dict(ema_state_dict)
            try:
                yield
            finally:
                self.logger.debug("Swapping back from EMA state to original state")
                for name, old_state in old_states.items():
                    state_source = self._get_state_source(name)
                    state_source.load_state_dict(old_state)
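
    # Both helpers are context managers; `run_epoch` below uses them roughly like this
    # (restating the pattern for clarity):
    #
    #     with self.swap_ema_state():
    #         self.run_stage('valid', self.valid)        # validate with EMA weights
    #     with self.swap_best_state():
    #         self.run_stage('evaluate', self.evaluate)  # evaluate with the best weights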

    @property
    def is_training(self):
        return self.current_stage == 'train'

    def log_model_summary(self, model: nn.Module):
        """Log model summary, architecture and size of the model."""
        self.logger.info(model)
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
        self.logger.info("Size: %.1f MB", mb)

    @abstractmethod
    def build_model(self):
        """Method to implement to initialize model."""
        ...

    def initialize_ema(self):
        """Initialize exponential moving average with the registered sources.
        EMA object is created if the optim.ema.model.decay value is non-null.
        """
        from .builders import get_ema
        self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
        if self.ema is None:
            self.logger.info('No EMA on the model.')
        else:
            assert self.cfg.optim.ema.updates > 0
            self.logger.info(
                f'Initializing EMA on the model with decay = {self.ema.decay}'
                f' every {self.cfg.optim.ema.updates} updates'
            )

    @abstractmethod
    def build_dataloaders(self):
        """Method to implement to initialize dataloaders."""
        ...

    @abstractmethod
    def show(self):
        """Method to log any information without running the job."""
        ...

    @property
    def log_updates(self):
        # convenient access to log updates
        return self._log_updates

    def checkpoint_path(self, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(**kwargs)

    def epoch_checkpoint_path(self, epoch: int, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)

    def checkpoint_path_with_name(self, name: str, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)

    def save_checkpoints(self):
        """Save checkpoint, optionally keeping a copy for a given epoch."""
        is_sharded = self.cfg.fsdp.use
        if not flashy.distrib.is_rank_zero() and not is_sharded:
            return
        self.logger.info("Model hash: %s", model_hash(self.model))
        state = self.state_dict()
        epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here

        # save minimal state_dict as new checkpoint every X epoch
        if self.cfg.checkpoint.save_every:
            if epoch % self.cfg.checkpoint.save_every == 0:
                minimal_state = state
                if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
                    minimal_state = {
                        name: source for name, source in state.items()
                        if name in self.cfg.checkpoint.keep_every_states
                    }
                epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
                checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)

        # save checkpoint as latest checkpoint
        if self.cfg.checkpoint.save_last:
            last_checkpoint_path = self.checkpoint_path()
            checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)

        # flush any stale checkpoint to reduce disk footprint
        checkpoint.flush_stale_checkpoints(self.checkpoint_path())

    def load_from_pretrained(self, name: str) -> dict:
        raise NotImplementedError("Solver does not provide a way to load pretrained models.")

    def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
        """Load last checkpoint or the one specified in continue_from.

        Args:
            load_best (bool): Whether to load from best state dict or not.
                Best state dict is always used when not loading the current xp.
            ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
        Returns:
            state (dict, optional): The loaded state dictionary.
        """
        # load checkpoints from xp folder or cfg.continue_from
        is_sharded = self.cfg.fsdp.use
        load_from_path: tp.Optional[Path] = None
        checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None

        if load_best:
            self.logger.info("Trying to load state_dict from best state.")

        state: tp.Optional[dict] = None
        rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
        current_checkpoint_path = self.checkpoint_path()
        _pretrained_prefix = '//pretrained/'
        continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
        if rank0_checkpoint_path.exists():
            self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
            load_from_path = current_checkpoint_path
            checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
            checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
        elif self.cfg.continue_from and not continue_pretrained:
            self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
            # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
            if load_from_path is None:
                self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
                raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
            checkpoint_source = checkpoint.CheckpointSource.OTHER

        if load_from_path is not None:
            state = checkpoint.load_checkpoint(load_from_path, is_sharded)
        elif continue_pretrained:
            self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
            state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
            checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
            load_best = True

        # checkpoints are not from the current xp, we only retrieve the best state
        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
            assert state is not None
            self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
            load_best = True
            state = {key: state[key] for key in self._continue_best_source_keys if key in state}
            # loaded checkpoints are FSDP checkpoints: we're reading the best state
            # from FSDP and we drop the regular best_state
            if 'fsdp_best_state' in state and state['fsdp_best_state']:
                state.pop('best_state', None)
                self.logger.info("... Loaded checkpoint has FSDP best state")
            # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
            # then we're initializing FSDP best state with the regular best state
            elif self.cfg.fsdp.use:
                if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
                    # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
                    state['fsdp_best_state'] = state.pop('best_state')
                    self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")

        if state is not None:
            if load_best:
                self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
                for key in set(ignore_state_keys):
                    if key in state:
                        state.pop(key)
                has_best_state = 'best_state' in state or 'fsdp_best_state' in state
                assert has_best_state, ("Trying to load best state but neither 'best_state' "
                                        "nor 'fsdp_best_state' found in checkpoints.")
            self.load_state_dict(state)

        # for FSDP, let's make extra sure nothing bad happened with out of sync
        # checkpoints across workers.
        epoch = float(self.epoch)
        avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
        if avg_epoch != epoch:
            raise RuntimeError(
                f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
                f"but average of epochs is {avg_epoch}, at least one gpu must have a "
                "different epoch number.")

        # on load_best, properly reinitialize state_dict, best states and ema
        # otherwise we load from the current xp and don't alter anything
        if load_best:
            self.logger.info("Loading state_dict from best state.")
            if not self.cfg.fsdp.use and self.fsdp_best_state:
                # loading from an FSDP checkpoint but with FSDP deactivated
                self.logger.info("... Loading from FSDP best state dict.")
                self.best_state.load_state_dict(self.fsdp_best_state)

            # if load_best, we permanently override the regular state_dict with the best state
            if self.cfg.fsdp.use:
                self.logger.info("FSDP is used, loading from FSDP best state.")
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    # this might be really fragile but okay for now.
                    self.load_state_dict(self.fsdp_best_state)
            else:
                # we permanently swap the stateful objects to their best state
                self._load_new_state_dict(self.best_state.state_dict())

            # the EMA modules should also be instantiated with best state.
            # the easiest way to do so is to reinitialize a new EMA with best state loaded.
            if self.ema is not None:
                self.logger.info("Re-initializing EMA from best state")
                self.initialize_ema()

            if self.cfg.fsdp.use:
                self.logger.info("Re-initializing best state after using FSDP best state.")
                for name in self.best_state.states.keys():
                    state_source = self._get_state_source(name)
                    self.best_state.update(name, state_source)

        return state

    def restore(self, load_best: bool = False, replay_metrics: bool = False,
                ignore_state_keys: tp.List[str] = []) -> bool:
        """Restore the status of a solver for a given xp.

        Args:
            load_best (bool): if `True`, load the best state from the checkpoint.
            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
        """
        self.logger.info("Restoring weights and history.")
        restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
        self.logger.info("Model hash: %s", model_hash(self.model))

        if replay_metrics and len(self.history) > 0:
            self.logger.info("Replaying past metrics...")
            for epoch, stages in enumerate(self.history):
                for stage_name, metrics in stages.items():
                    # We manually log the metrics summary to the result logger
                    # as we don't want to add them to the pending metrics
                    self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
                                                    formatter=self.get_formatter(stage_name))
        return restored_checkpoints is not None

    def commit(self, save_checkpoints: bool = True):
        """Commit metrics to dora and save checkpoints at the end of an epoch."""
        # we override commit to introduce more complex checkpoint saving behaviors
        self.history.append(self._pending_metrics)  # This will increase self.epoch
        if save_checkpoints:
            self.save_checkpoints()
        self._start_epoch()
        if flashy.distrib.is_rank_zero():
            self.xp.link.update_history(self.history)

    def run_epoch(self):
        """Run a single epoch with all stages.

        Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
        Children solvers can extend this method with custom behavior, e.g.:

            def run_epoch(self):
                ...  # custom code
                super().run_epoch()
                ...  # custom code

        """
        self.run_stage('train', self.train)
        with torch.no_grad():
            with self.swap_ema_state():
                self.run_stage('valid', self.valid)
                # the best state is updated with EMA states if available
                self.update_best_state_from_stage('valid')
            with self.swap_best_state():
                if self.should_run_stage('evaluate'):
                    self.run_stage('evaluate', self.evaluate)
                if self.should_run_stage('generate'):
                    self.run_stage('generate', with_rank_rng()(self.generate))

    def run(self):
        """Training loop."""
        assert len(self.state_dict()) > 0
        self.restore(replay_metrics=True)  # load checkpoint and replay history
        self.log_hyperparams(dict_from_config(self.cfg))
        for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
            if self.should_stop_training():
                return
            self.run_epoch()
            # Commit will send the metrics to Dora and save checkpoints by default.
            self.commit()

    def should_stop_training(self) -> bool:
        """Check whether we should stop training or not."""
        return self.epoch > self.cfg.optim.epochs

    def should_run_stage(self, stage_name) -> bool:
        """Check whether we want to run the specified stages."""
        stage_every = self.cfg[stage_name].get('every', None)
        is_last_epoch = self.epoch == self.cfg.optim.epochs
        is_epoch_every = (stage_every and self.epoch % stage_every == 0)
        return is_last_epoch or is_epoch_every
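
    # Worked example (numbers are illustrative, not defaults): with `cfg.optim.epochs = 20`
    # and `cfg.evaluate.every = 5`, the evaluate stage runs at epochs 5, 10, 15 and 20
    # (the last epoch always runs the stage).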

    @abstractmethod
    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
        """Perform one training or valid step on a given batch."""
        ...

    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
        """Common logic for train and valid stages."""
        self.model.train(self.is_training)

        loader = self.dataloaders[dataset_split]
        # get a different order for distributed training, otherwise this will get ignored
        if flashy.distrib.world_size() > 1 \
                and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
            loader.sampler.set_epoch(self.epoch)
        updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
        if self.cfg.benchmark_no_load:
            self.logger.warning("Fake loading for benchmarking: re-using first batch")
            batch = next(iter(loader))
            loader = [batch] * updates_per_epoch  # type: ignore
        lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
        average = flashy.averager()  # epoch wise average
        instant_average = flashy.averager()  # average between two logging
        metrics: dict = {}

        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
            for idx, batch in enumerate(lp):
                self.deadlock_detect.update('batch')
                if idx >= updates_per_epoch:
                    break
                metrics = {}
                metrics = self.run_step(idx, batch, metrics)
                self.deadlock_detect.update('step')
                # run EMA step
                if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
                    self.logger.debug("EMA model step")
                    self.ema.step()
                self.deadlock_detect.update('ema')
                self.profiler.step()
                instant_metrics = instant_average(metrics)
                if lp.update(**instant_metrics):
                    instant_average = flashy.averager()  # reset averager between two logging
                metrics = average(metrics)  # epoch wise average
                self.deadlock_detect.update('end_batch')

        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
        return metrics

    def train(self):
        """Train stage."""
        return self.common_train_valid('train')

    def valid(self):
        """Valid stage."""
        return self.common_train_valid('valid')

    @abstractmethod
    def evaluate(self):
        """Evaluate stage."""
        ...

    @abstractmethod
    def generate(self):
        """Generate stage."""
        ...

    def run_one_stage(self, stage_name: str):
        """Run only the specified stage.
        This method is useful to only generate samples from a trained experiment
        or rerun the validation or evaluation stages.
        """
        fn = {
            'generate': with_rank_rng()(self.generate),
            'evaluate': self.evaluate,
            'valid': self.valid,
        }
        if stage_name not in fn:
            raise ValueError(f'Trying to run stage {stage_name} is not supported.')
        assert len(self.state_dict()) > 0
        self._start_epoch()
        with torch.no_grad(), self.swap_best_state():
            self.run_stage(stage_name, fn[stage_name])
        if not self.cfg.execute_inplace:
            self.commit(save_checkpoints=False)

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper parameters, deactivating EMA and FSDP, and loading the best state:
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, e.g. 'float16'.
            device (str or None): potential device, as a string, e.g. 'cuda'.
            batch_size (int or None): optional batch size override for the dataset config.
            override_cfg (dict or omegaconf.DictConfig or None): additional config overrides
                merged on top of the XP configuration.
        """
        from audiocraft import train
        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
        our_override_cfg['autocast'] = autocast
        if dtype is not None:
            our_override_cfg['dtype'] = dtype
        if device is not None:
            our_override_cfg['device'] = device
        if batch_size is not None:
            our_override_cfg['dataset'] = {'batch_size': batch_size}
        if override_cfg is None:
            override_cfg = {}
        override_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=override_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver
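
    # Example usage (hypothetical XP signature and subclass name shown; requires the
    # corresponding Dora XP to be available locally):
    #
    #     solver = SomeConcreteSolver.get_eval_solver_from_sig(
    #         '0123abcd', device='cuda', dtype='float16', batch_size=8)
    #     # solver.model now holds the best state, with EMA and FSDP disabled.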