|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
def build_norm(norm_type: str, dim: int, eps: float = 1e-6): |
|
""" |
|
Builds the specified normalization layer based on the norm_type. |
|
|
|
Args: |
|
norm_type (str): The type of normalization layer to build. |
|
Supported types: layernorm, np_layernorm, rmsnorm |
|
dim (int): The dimension of the normalization layer. |
|
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. |
|
|
|
Returns: |
|
The built normalization layer. |
|
|
|
Raises: |
|
NotImplementedError: If an unknown norm_type is provided. |
|
""" |
|
norm_type = norm_type.lower() |
|
|
|
if norm_type == "layernorm": |
|
return nn.LayerNorm(dim, eps=eps, bias=False) |
|
elif norm_type == "np_layernorm": |
|
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) |
|
elif norm_type == "rmsnorm": |
|
return nn.RMSNorm(dim, eps=eps) |
|
else: |
|
raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") |
|
|