gagan3012 committed
Commit b1fc84a · verified · 1 Parent(s): 9eaa1e0

Upload folder using huggingface_hub

Files changed (2)
  1. flake.lock +168 -0
  2. torch-ext/batch_invariant/__init__.py +168 -0
flake.lock ADDED
@@ -0,0 +1,168 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "flake-utils_2": {
+      "inputs": {
+        "systems": "systems_2"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1756316789,
+        "narHash": "sha256-DJvw0l+PXeFq963L3sbqAQKjIwGPae+yWpZHraFES28=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "57ea72ac74c89331005c47bb082b28cef653bed8",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
+        "nixpkgs": [
+          "kernel-builder",
+          "hf-nix",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1757060761,
+        "narHash": "sha256-aKGP9jgV6N8aRF7jR3OnYSBmOa6C6u4ULRpvcThgFck=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "08fcbf386981dc0fb7e47679d3ba86d77a33721b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1755963616,
+        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "ref": "nixos-unstable-small",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
torch-ext/batch_invariant/__init__.py CHANGED
@@ -1,4 +1,6 @@
 import torch
+import torch.nn as nn
+import math
 from ._ops import ops
 
 
@@ -118,3 +120,169 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype = None):
     for d in dim:
         n_elems *= input.shape[d]
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
+
+class BatchInvariantAttention(nn.Module):
+    """
+    Batch invariant multi-head attention implementation.
+    Compatible with transformers library integration.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        # Linear projections
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
+        past_key_value=None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: torch.Tensor = None,
+        **kwargs,
+    ):
+        batch_size, seq_len, _ = hidden_states.size()
+
+        # Project to Q, K, V using batch invariant matrix multiplication
+        query_states = self._batch_invariant_linear(hidden_states, self.q_proj.weight)
+        key_states = self._batch_invariant_linear(hidden_states, self.k_proj.weight)
+        value_states = self._batch_invariant_linear(hidden_states, self.v_proj.weight)
+
+        # Reshape for multi-head attention
+        query_states = query_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+
+        # Compute attention scores
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        # Apply attention mask if provided
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # Apply softmax using batch invariant log_softmax
+        attn_weights_log = log_softmax(attn_weights, dim=-1)
+        attn_weights = torch.exp(attn_weights_log)
+
+        # Apply attention to values
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        # Reshape and apply output projection
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
+        attn_output = self._batch_invariant_linear(attn_output, self.o_proj.weight)
+
+        outputs = (attn_output,)
+        if output_attentions:
+            outputs += (attn_weights,)
+        if use_cache:
+            outputs += (past_key_value,)
+
+        return outputs
+
+    def _batch_invariant_linear(
+        self, input_tensor: torch.Tensor, weight: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply linear transformation using batch invariant matrix multiplication"""
+        original_shape = input_tensor.shape
+        input_2d = input_tensor.view(-1, original_shape[-1])
+        output_2d = matmul_persistent(input_2d, weight.t())
+        return output_2d.view(*original_shape[:-1], -1)
+
+
+class BatchInvariantMLP(nn.Module):
+    """
+    Batch invariant MLP implementation.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = (
+            nn.SiLU()
+        )  # or whatever activation function is specified in config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Use batch invariant matrix multiplication for projections
+        gate = self._batch_invariant_linear(x, self.gate_proj.weight)
+        up = self._batch_invariant_linear(x, self.up_proj.weight)
+
+        # Apply activation
+        intermediate = self.act_fn(gate) * up
+
+        # Down projection
+        output = self._batch_invariant_linear(intermediate, self.down_proj.weight)
+        return output
+
+    def _batch_invariant_linear(
+        self, input_tensor: torch.Tensor, weight: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply linear transformation using batch invariant matrix multiplication"""
+        original_shape = input_tensor.shape
+        input_2d = input_tensor.view(-1, original_shape[-1])
+        output_2d = matmul_persistent(input_2d, weight.t())
+        return output_2d.view(*original_shape[:-1], -1)
+
+
+class BatchInvariantRMSNorm(nn.Module):
+    """
+    Batch invariant RMS normalization implementation.
+    """
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+
+        # Compute mean square using batch invariant mean
+        variance = mean_dim(hidden_states.pow(2), dim=-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+
+
+# Export the layer classes
+__all__ += ["BatchInvariantAttention", "BatchInvariantMLP", "BatchInvariantRMSNorm"]
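
A minimal usage sketch of the three new layers, not part of the commit: it assumes the built extension is importable as `batch_invariant`, that a CUDA device is available for the module's batch-invariant kernels (`matmul_persistent`, `log_softmax`, `mean_dim`), and it uses a `SimpleNamespace` as a hypothetical stand-in for a transformers-style config.

# Hypothetical usage sketch; `batch_invariant` import path and config object are assumptions.
from types import SimpleNamespace

import torch

from batch_invariant import (
    BatchInvariantAttention,
    BatchInvariantMLP,
    BatchInvariantRMSNorm,
)

# Stand-in config with the attributes the layers read (not a real transformers config).
config = SimpleNamespace(hidden_size=256, num_attention_heads=8, intermediate_size=1024)

attn = BatchInvariantAttention(config).cuda()
mlp = BatchInvariantMLP(config).cuda()
norm = BatchInvariantRMSNorm(config.hidden_size).cuda()

x = torch.randn(4, 16, config.hidden_size, device="cuda")

# One pre-norm block: attention then MLP, each with a residual connection.
(attn_out,) = attn(norm(x))  # forward returns a tuple; first element is the output
h = x + attn_out
h = h + mlp(norm(h))
print(h.shape)  # torch.Size([4, 16, 256])

# Batch-invariance check on the MLP path: row 0 processed alone should match
# row 0 processed inside the batch of 4, since the projections go through the
# batch invariant matmul kernel.
out_batched = mlp(norm(x))
out_solo = mlp(norm(x[:1]))
print(torch.equal(out_batched[:1], out_solo))  # expected True if the kernel is batch invariant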