# FANformer-1B / initialization.py
# Provenance: uploaded by dongyh ("Upload 15 files", commit 55c82d2, verified)
from typing import Optional, Union
import torch.nn as nn
__all__ = ["init_normal"]
def init_normal(
    module: Union[nn.Linear, nn.Embedding],
    std: float,
    init_cutoff_factor: Optional[float] = None,
):
    """Initialize ``module.weight`` from a zero-mean normal distribution.

    Args:
        module: A ``nn.Linear`` or ``nn.Embedding`` whose parameters are
            initialized in place.
        std: Standard deviation of the normal distribution.
        init_cutoff_factor: If given, sample from a truncated normal instead,
            clipped to ``[-init_cutoff_factor * std, init_cutoff_factor * std]``.
    """
    # Weight: plain normal by default, truncated normal when a cutoff is requested.
    if init_cutoff_factor is None:
        nn.init.normal_(module.weight, mean=0.0, std=std)
    else:
        bound = init_cutoff_factor * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)
    # Bias: zero for Linear layers; Embedding has no bias attribute.
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)