pr-include-rev-in-flake
#1
by
drbh
HF staff
- opened
build/torch-universal/triton_layer_norm/layers.py
CHANGED
@@ -1,24 +1,4 @@
|
|
1 |
-
import
|
2 |
-
from torch import nn
|
3 |
|
4 |
-
from .layer_norm import rms_norm_fn
|
5 |
|
6 |
-
|
7 |
-
class LlamaRMSNorm(nn.Module):
|
8 |
-
weight: torch.Tensor
|
9 |
-
variance_epsilon: float
|
10 |
-
|
11 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
12 |
-
return rms_norm_fn(
|
13 |
-
hidden_states,
|
14 |
-
self.weight,
|
15 |
-
bias=None,
|
16 |
-
residual=None,
|
17 |
-
eps=self.variance_epsilon,
|
18 |
-
dropout_p=0.0,
|
19 |
-
prenorm=False,
|
20 |
-
residual_in_fp32=False,
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
__all__ = ["LlamaRMSNorm"]
|
|
|
1 |
+
from .layer_norm import RMSNorm
|
|
|
2 |
|
|
|
3 |
|
4 |
+
__all__ = ["RMSNorm"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flake.nix
CHANGED
@@ -10,8 +10,5 @@
|
|
10 |
self,
|
11 |
kernel-builder,
|
12 |
}:
|
13 |
-
kernel-builder.lib.genFlakeOutputs
|
14 |
-
path = ./.;
|
15 |
-
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
16 |
-
};
|
17 |
}
|
|
|
10 |
self,
|
11 |
kernel-builder,
|
12 |
}:
|
13 |
+
kernel-builder.lib.genFlakeOutputs ./.;
|
|
|
|
|
|
|
14 |
}
|
torch-ext/triton_layer_norm/layers.py
CHANGED
@@ -1,24 +1,4 @@
|
|
1 |
-
import
|
2 |
-
from torch import nn
|
3 |
|
4 |
-
from .layer_norm import rms_norm_fn
|
5 |
|
6 |
-
|
7 |
-
class LlamaRMSNorm(nn.Module):
|
8 |
-
weight: torch.Tensor
|
9 |
-
variance_epsilon: float
|
10 |
-
|
11 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
12 |
-
return rms_norm_fn(
|
13 |
-
hidden_states,
|
14 |
-
self.weight,
|
15 |
-
bias=None,
|
16 |
-
residual=None,
|
17 |
-
eps=self.variance_epsilon,
|
18 |
-
dropout_p=0.0,
|
19 |
-
prenorm=False,
|
20 |
-
residual_in_fp32=False,
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
__all__ = ["LlamaRMSNorm"]
|
|
|
1 |
+
from .layer_norm import RMSNorm
|
|
|
2 |
|
|
|
3 |
|
4 |
+
__all__ = ["RMSNorm"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|