import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Scales the input by the reciprocal RMS over its last dimension, then
    applies a learned per-feature gain (initialized to ones).

    Args:
        dimensions: Size of the last (feature) dimension to normalize.
        eps: Small constant added to the mean square for numerical stability.
        device: Device on which the gain parameter is allocated.
        dtype: Dtype of the gain parameter (defaults to bfloat16).
        norm_in_fp32: If True, perform the normalization in float32 and cast
            back to the input's dtype before applying the gain — more
            numerically stable for low-precision inputs.
    """

    def __init__(
        self,
        dimensions: int,
        eps: float,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        norm_in_fp32: bool = False,
    ):
        super().__init__()
        self.eps = eps
        # Allocate the parameter directly on the target device instead of
        # building it on CPU and copying with .to(device).
        self.weight = nn.Parameter(torch.ones(dimensions, dtype=dtype, device=device))
        self.norm_in_fp32 = norm_in_fp32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x normalized over its last dimension, scaled by the gain."""
        original_dtype = x.dtype
        if self.norm_in_fp32:
            x = x.float()
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back BEFORE multiplying by the gain, so the gain multiply
        # happens in the caller's dtype (matches the original ordering).
        if out.dtype != original_dtype:
            out = out.to(original_dtype)
        return out * self.weight