Stanislas committed on
Commit 1be049c
Parent: 5055277

Fix precision error

Files changed (1): modeling_chatglm.py (+9 -7)
modeling_chatglm.py CHANGED
@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -183,9 +181,14 @@ class RMSNorm(torch.nn.Module):
         self.eps = eps
 
     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if hidden_states.dtype == torch.bfloat16:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 
         return (self.weight * hidden_states).to(input_dtype)
 
@@ -517,8 +520,7 @@ class GLMBlock(torch.nn.Module):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
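
For reference, here is a minimal, self-contained sketch of the patched RMSNorm forward pass (not part of the commit; the class below only mirrors the forward logic shown in the diff, and the tensor shapes and variable names in the usage lines are made up). It illustrates the two precision paths: bfloat16 inputs are normalized entirely in bfloat16, while other dtypes accumulate the variance in float32 and cast back.

# Sketch only: mirrors the patched forward pass, names and shapes are illustrative.
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        if hidden_states.dtype == torch.bfloat16:
            # bfloat16 path: mean of squares computed in the input dtype,
            # no round trip through float32.
            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
            return self.weight * x_normed
        else:
            # Other dtypes (e.g. float16): accumulate the variance in float32,
            # then cast the result back to the original dtype.
            input_dtype = hidden_states.dtype
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            return (self.weight * hidden_states).to(input_dtype)

norm_bf16 = RMSNorm(8, dtype=torch.bfloat16)
x_bf16 = torch.randn(2, 8).to(torch.bfloat16)
print(norm_bf16(x_bf16).dtype)   # torch.bfloat16, stays in bfloat16 throughout

norm_fp16 = RMSNorm(8, dtype=torch.float16)
x_fp16 = torch.randn(2, 8).to(torch.float16)
print(norm_fp16(x_fp16).dtype)   # torch.float16, variance accumulated in float32

A likely motivation for the branch: bfloat16 shares float32's 8-bit exponent, so squaring activations is far less prone to overflow there, while float16's narrow range still benefits from float32 accumulation.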