pr-include-rev-in-flake

#1 opened by drbh (HF staff)
build/torch-universal/triton_layer_norm/layers.py CHANGED
@@ -1,24 +1,4 @@
-import torch
-from torch import nn
-
-from .layer_norm import rms_norm_fn
-
-
-class LlamaRMSNorm(nn.Module):
-    weight: torch.Tensor
-    variance_epsilon: float
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return rms_norm_fn(
-            hidden_states,
-            self.weight,
-            bias=None,
-            residual=None,
-            eps=self.variance_epsilon,
-            dropout_p=0.0,
-            prenorm=False,
-            residual_in_fp32=False,
-        )
-
-
-__all__ = ["LlamaRMSNorm"]
+from .layer_norm import RMSNorm
+
+
+__all__ = ["RMSNorm"]
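
Downstream code that instantiated the removed LlamaRMSNorm for plain RMS normalization (no bias, no residual, no dropout) should map directly onto the re-exported RMSNorm. A minimal equivalence sketch follows, assuming RMSNorm exposes the usual Triton layer-norm module interface (a hidden_size/eps constructor that owns its weight parameter) and that the built package imports as triton_layer_norm; those details are assumptions, while the rms_norm_fn call is copied from the removed wrapper. The same replacement is mirrored in torch-ext/triton_layer_norm/layers.py below.

    import torch

    from triton_layer_norm.layers import RMSNorm          # new re-export
    from triton_layer_norm.layer_norm import rms_norm_fn  # kernel the old wrapper called

    # Illustrative shapes only; the Triton kernel expects a CUDA tensor.
    hidden_size, eps = 1024, 1e-6
    x = torch.randn(2, 16, hidden_size, device="cuda", dtype=torch.float16)

    # Assumed constructor (hidden_size plus eps, module owning its weight); not part of this diff.
    norm = RMSNorm(hidden_size, eps=eps).to(device="cuda", dtype=torch.float16)

    # The removed LlamaRMSNorm.forward reduced to exactly this call:
    ref = rms_norm_fn(
        x,
        norm.weight,
        bias=None,
        residual=None,
        eps=eps,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
    )

    torch.testing.assert_close(norm(x), ref)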
 
flake.nix CHANGED
@@ -10,8 +10,5 @@
       self,
       kernel-builder,
     }:
-    kernel-builder.lib.genFlakeOutputs {
-      path = ./.;
-      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-    };
+    kernel-builder.lib.genFlakeOutputs ./.;
 }
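
For context on the two forms above: the attrset variant of genFlakeOutputs passes the flake's own revision (shortRev, falling back to dirtyShortRev for dirty checkouts and to lastModifiedDate otherwise) alongside the source path, presumably so kernel-builder can record which revision a kernel build came from; the bare-path variant passes only ./. and omits that metadata.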
torch-ext/triton_layer_norm/layers.py CHANGED
@@ -1,24 +1,4 @@
-import torch
-from torch import nn
-
-from .layer_norm import rms_norm_fn
-
-
-class LlamaRMSNorm(nn.Module):
-    weight: torch.Tensor
-    variance_epsilon: float
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return rms_norm_fn(
-            hidden_states,
-            self.weight,
-            bias=None,
-            residual=None,
-            eps=self.variance_epsilon,
-            dropout_p=0.0,
-            prenorm=False,
-            residual_in_fp32=False,
-        )
-
-
-__all__ = ["LlamaRMSNorm"]
+from .layer_norm import RMSNorm
+
+
+__all__ = ["RMSNorm"]