|
from typing import Optional, Union |
|
|
|
import torch.nn as nn |
|
|
|
__all__ = ["init_normal"] |
|
|
|
|
|
def init_normal(
    module: Union[nn.Linear, nn.Embedding],
    std: float,
    init_cutoff_factor: Optional[float] = None,
):
    """Initialize ``module.weight`` from a zero-mean normal distribution.

    When ``init_cutoff_factor`` is given, samples are drawn from a normal
    distribution truncated to ``[-init_cutoff_factor * std, init_cutoff_factor * std]``;
    otherwise an untruncated normal is used. For ``nn.Linear`` modules with a
    bias, the bias is reset to zero.

    :param module: The linear or embedding layer whose parameters to initialize.
    :param std: Standard deviation of the normal distribution.
    :param init_cutoff_factor: Optional multiple of ``std`` at which to truncate.
    """
    if init_cutoff_factor is None:
        nn.init.normal_(module.weight, mean=0.0, std=std)
    else:
        bound = init_cutoff_factor * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)

    # Embeddings have no bias; only Linear layers (with bias enabled) need this.
    bias = getattr(module, "bias", None)
    if isinstance(module, nn.Linear) and bias is not None:
        nn.init.zeros_(bias)
|
|