| from collections.abc import Iterable | |
| import torch.nn as nn | |
| from torch.utils.checkpoint import checkpoint, checkpoint_sequential | |
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    """Tag ``model`` and all of its submodules for gradient checkpointing.

    ``model.apply`` is used, so the callback runs on every submodule as well
    as on ``model`` itself. Each visited module gets three attributes:

    * ``grad_checkpointing`` -- always ``True``; ``auto_grad_checkpoint``
      reads this to decide whether to checkpoint.
    * ``fp32_attention`` -- the ``use_fp32_attention`` flag. It is not read
      anywhere in this visible code; presumably attention layers elsewhere
      consume it (TODO confirm).
    * ``grad_checkpointing_step`` -- ``gc_step``; used by
      ``auto_grad_checkpoint`` as the segment count for sequential containers.

    Args:
        model: The ``nn.Module`` to tag.
        use_fp32_attention: Value stored as ``fp32_attention``.
        gc_step: Value stored as ``grad_checkpointing_step``.

    Raises:
        AssertionError: if ``model`` is not an ``nn.Module`` (only while
            assertions are enabled).
    """
    assert isinstance(model, nn.Module)

    # Attribute names and values, applied in this order to every module.
    flags = {
        "grad_checkpointing": True,
        "fp32_attention": use_fp32_attention,
        "grad_checkpointing_step": gc_step,
    }

    def _tag(submodule):
        for attr_name, attr_value in flags.items():
            setattr(submodule, attr_name, attr_value)

    model.apply(_tag)
def auto_grad_checkpoint(module, *args, **kwargs):
    """Call ``module``, using activation checkpointing if it is tagged for it.

    Modules whose ``grad_checkpointing`` attribute is truthy (e.g. after
    ``set_grad_checkpoint``) run through ``torch.utils.checkpoint``. Every
    other module is called directly as ``module(*args, **kwargs)``.

    Args:
        module: The module to run. Iterable containers (e.g. ``nn.Sequential``
            or ``nn.ModuleList``) are split into segments with
            ``checkpoint_sequential``. Other modules are wrapped in
            ``checkpoint``.
        *args: Positional inputs. The sequential path accepts at most one,
            because ``checkpoint_sequential`` takes a single input.
        **kwargs: On the checkpointed paths these are forwarded to
            ``checkpoint`` / ``checkpoint_sequential``. On the plain path they
            are forwarded to ``module``.

    Returns:
        The output of ``module`` (or of the checkpoint wrapper around it).

    Raises:
        TypeError: if a checkpointed iterable container receives more than
            one positional input.
    """
    if not getattr(module, "grad_checkpointing", False):
        return module(*args, **kwargs)

    if not isinstance(module, Iterable):
        return checkpoint(module, *args, **kwargs)

    num_layers = len(module)
    if num_layers == 0:
        # Nothing to checkpoint, and checkpoint_sequential cannot segment an
        # empty container. Run it as-is instead of failing on module[0].
        return module(*args, **kwargs)

    if len(args) > 1:
        # checkpoint_sequential takes a single input. In recent torch its next
        # positional parameter is ``use_reentrant``, so an extra input would
        # be silently misread rather than rejected.
        raise TypeError(
            "auto_grad_checkpoint: sequential checkpointing accepts a single "
            f"positional input, got {len(args)}"
        )

    # Prefer the first child's step (as before). Fall back to the container's
    # own tag, then to 1, instead of raising AttributeError.
    gc_step = getattr(
        module[0],
        "grad_checkpointing_step",
        getattr(module, "grad_checkpointing_step", 1),
    )
    # More segments than layers gives checkpoint_sequential a segment size of
    # 0, and it crashes. Cap the count at one segment per layer.
    segments = min(gc_step, num_layers)
    return checkpoint_sequential(module, segments, *args, **kwargs)
