medmekk HF Staff commited on
Commit
8f291c8
·
verified ·
1 Parent(s): 3de0c84

Upload custom kernels

Browse files
build/torch-universal/liger_kernels/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._liger_kernels_20250507091553
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_liger_kernels_20250507091553::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._liger_kernels_20250507091832
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_liger_kernels_20250507091832::{op_name}"
build/torch-universal/liger_kernels/layers.py CHANGED
@@ -16,9 +16,6 @@ class LigerRMSNorm(torch.nn.Module):
16
 
17
  weight: torch.Tensor
18
  variance_epsilon: float
19
- offset: float = 0
20
- casting_mode: str = "llama"
21
- in_place: bool = True
22
 
23
  def forward(self, hidden_states):
24
  """
@@ -34,9 +31,9 @@ class LigerRMSNorm(torch.nn.Module):
34
  hidden_states,
35
  self.weight,
36
  self.variance_epsilon,
37
- self.offset,
38
- self.casting_mode,
39
- self.in_place
40
  )
41
 
42
  __all__ = ["LigerRMSNorm"]
 
16
 
17
  weight: torch.Tensor
18
  variance_epsilon: float
 
 
 
19
 
20
  def forward(self, hidden_states):
21
  """
 
31
  hidden_states,
32
  self.weight,
33
  self.variance_epsilon,
34
+ 0,
35
+ "llama",
36
+ True
37
  )
38
 
39
  __all__ = ["LigerRMSNorm"]
torch-ext/liger_kernels/layers.py CHANGED
@@ -16,9 +16,6 @@ class LigerRMSNorm(torch.nn.Module):
16
 
17
  weight: torch.Tensor
18
  variance_epsilon: float
19
- offset: float = 0
20
- casting_mode: str = "llama"
21
- in_place: bool = True
22
 
23
  def forward(self, hidden_states):
24
  """
@@ -34,9 +31,9 @@ class LigerRMSNorm(torch.nn.Module):
34
  hidden_states,
35
  self.weight,
36
  self.variance_epsilon,
37
- self.offset,
38
- self.casting_mode,
39
- self.in_place
40
  )
41
 
42
  __all__ = ["LigerRMSNorm"]
 
16
 
17
  weight: torch.Tensor
18
  variance_epsilon: float
 
 
 
19
 
20
  def forward(self, hidden_states):
21
  """
 
31
  hidden_states,
32
  self.weight,
33
  self.variance_epsilon,
34
+ 0,
35
+ "llama",
36
+ True
37
  )
38
 
39
  __all__ = ["LigerRMSNorm"]