Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import comfy.model_management | |
| import numbers | |
| import logging | |
# Probe once, at import time, for PyTorch's native RMSNorm support.
# RMSNorm stays None when the native class is missing; rms_norm_torch is the
# native functional kernel, or None when it is unavailable.
RMSNorm = None

try:
    rms_norm_torch = torch.nn.functional.rms_norm
    RMSNorm = torch.nn.RMSNorm
except AttributeError:
    # Only a missing attribute means "this PyTorch has no native RMSNorm".
    # Anything else (e.g. KeyboardInterrupt) is allowed to propagate instead
    # of being silently swallowed by a bare except.
    rms_norm_torch = None
    logging.warning("Please update pytorch to use native RMSNorm")
def rms_norm(x, weight=None, eps=1e-6):
    """Apply root-mean-square normalization to ``x``.

    The native kernel held in the module-level ``rms_norm_torch`` is used
    when it is available and the call is not happening under TorchScript
    tracing or scripting. Otherwise the normalization is computed manually.

    Args:
        x: Input tensor.
        weight: Optional scale tensor. It is cast to ``x``'s dtype and device
            with ``comfy.model_management.cast_to`` before use. Its shape
            selects the trailing dimensions of ``x`` that are normalized.
            When ``None``, only the last dimension is normalized and no
            scaling is applied.
        eps: Value added to the mean square before the reciprocal square
            root is taken.

    Returns:
        The normalized (and, if ``weight`` is given, scaled) tensor.
    """
    use_native = rms_norm_torch is not None and not (
        torch.jit.is_tracing() or torch.jit.is_scripting()
    )
    if use_native:
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        scale = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
        return rms_norm_torch(x, weight.shape, weight=scale, eps=eps)

    # Manual fallback. The native call above normalizes over every trailing
    # dimension named by weight.shape, so reduce over the same dimensions here
    # to keep both paths in agreement for multi-dimensional weights. With no
    # weight, or a 0-D / 1-D weight, this is the plain last-dimension reduction.
    weight_rank = 1 if weight is None else len(weight.shape)
    reduce_dims = tuple(range(-weight_rank, 0)) if weight_rank > 1 else -1
    r = x * torch.rsqrt(torch.mean(x**2, dim=reduce_dims, keepdim=True) + eps)
    if weight is None:
        return r
    return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
    class RMSNorm(torch.nn.Module):
        """Fallback RMS normalization layer.

        Defined only when the module-level ``RMSNorm`` is still ``None``,
        i.e. when no native implementation was picked up earlier.

        Args:
            normalized_shape: An integer or an iterable of integers; stored
                as a tuple in ``self.normalized_shape``.
            eps: Epsilon forwarded to ``rms_norm`` on every call.
            elementwise_affine: When true, a learnable ``weight`` parameter
                of shape ``normalized_shape`` is created; otherwise
                ``weight`` is registered as ``None``.
            device: Device on which the weight tensor is allocated.
            dtype: Dtype of the weight tensor.
        """

        def __init__(
            self,
            normalized_shape,
            eps=1e-6,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            super().__init__()

            # Accept a bare integer as shorthand for a one-element shape.
            shape = (
                (normalized_shape,)
                if isinstance(normalized_shape, numbers.Integral)
                else normalized_shape
            )
            self.normalized_shape = tuple(shape)
            self.eps = eps
            self.elementwise_affine = elementwise_affine

            if not self.elementwise_affine:
                self.register_parameter("weight", None)
            else:
                # torch.empty leaves the values uninitialized; presumably the
                # real weights are loaded afterwards -- confirm with callers.
                uninitialized = torch.empty(
                    self.normalized_shape, device=device, dtype=dtype
                )
                self.weight = torch.nn.Parameter(uninitialized)

            # This layer never has a bias; the attribute is always None.
            self.bias = None

        def forward(self, x):
            """Normalize ``x`` via ``rms_norm`` using this layer's weight and eps."""
            return rms_norm(x, self.weight, self.eps)