from typing import Optional, Union

import torch.nn as nn

__all__ = ["init_normal"]


def init_normal(
    module: Union[nn.Linear, nn.Embedding],
    std: float,
    init_cutoff_factor: Optional[float] = None,
):
    """Initialize ``module`` in place from a zero-mean normal distribution.

    The weight is filled with samples of mean ``0.0`` and standard deviation
    ``std``. When ``init_cutoff_factor`` is given, a truncated normal is used
    instead, bounded to ``[-init_cutoff_factor * std, init_cutoff_factor * std]``.
    If ``module`` is an ``nn.Linear`` that has a bias, the bias is set to zeros.

    Args:
        module: The linear or embedding layer whose parameters are initialized.
        std: Standard deviation of the normal distribution.
        init_cutoff_factor: Optional truncation bound, expressed as a multiple
            of ``std``. ``None`` selects the untruncated normal distribution.
    """
    weight = module.weight

    # Weights: plain normal unless a truncation bound was requested.
    if init_cutoff_factor is None:
        nn.init.normal_(weight, mean=0.0, std=std)
    else:
        bound = init_cutoff_factor * std
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-bound, b=bound)

    # Biases: only linear layers are considered for bias zeroing.
    if not isinstance(module, nn.Linear):
        return

    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)