Xenova (HF Staff) committed · verified
Commit 55791dc · Parent(s): c92ff1c

Upload ONNX export script

Files changed (1): export.py (+325, -0)
export.py ADDED
@@ -0,0 +1,325 @@
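"""
Standalone ONNX export script for chandar-lab/NeoBERT.

Defines a self-contained NeoBERT encoder (RMSNorm, SwiGLU feed-forward, rotary
position embeddings implemented with real-valued ops only), exports it with
torch.onnx.export using dynamic batch and sequence axes, and validates the
resulting model.onnx against the PyTorch outputs with onnxruntime.
"""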
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutput

from xformers.ops import SwiGLU


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)
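# For reference: the returned tensor has shape (end, dim // 2) and dtype
# complex64. Only its .real and .imag parts are registered as (non-persistent)
# buffers on the model below, so no complex-valued tensors reach the ONNX graph.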
def apply_rotary_emb_real(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pure-real rotary embeddings.

    xq, xk: (B, seq, n_heads, dim)
    freqs_cis: (cos, sin), each of shape (B, seq, dim/2)
    """
    cos, sin = freqs_cis
    # make (B, seq, 1, dim/2) so they broadcast to (B, seq, n_heads, dim/2)
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)

    # split even/odd dims
    xq_even = xq[..., 0::2]
    xq_odd = xq[..., 1::2]
    xk_even = xk[..., 0::2]
    xk_odd = xk[..., 1::2]

    # apply the rotation formula:
    q_rot_even = xq_even * cos - xq_odd * sin
    q_rot_odd = xq_even * sin + xq_odd * cos
    k_rot_even = xk_even * cos - xk_odd * sin
    k_rot_odd = xk_even * sin + xk_odd * cos

    # interleave even/odd back into last dim
    xq_rot = torch.stack([q_rot_even, q_rot_odd], dim=-1).flatten(-2)
    xk_rot = torch.stack([k_rot_even, k_rot_odd], dim=-1).flatten(-2)

    return xq_rot.type_as(xq), xk_rot.type_as(xk)
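# Note: the rotation above is the real-valued expansion of multiplying
# (x_even + i * x_odd) by (cos + i * sin), i.e. the usual complex RoPE
# formulation. It is written with purely real ops here, presumably because the
# ONNX exporter does not handle complex dtypes.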
class NeoBERTConfig(PretrainedConfig):
    model_type = "neobert"

    # All config parameters must have a default value.
    def __init__(
        self,
        hidden_size: int = 768,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-06,
        vocab_size: int = 30522,
        pad_token_id: int = 0,
        max_length: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError("Hidden size must be divisible by the number of heads.")
        self.dim_head = hidden_size // num_attention_heads
        self.intermediate_size = intermediate_size
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.kwargs = kwargs
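# Rough usage sketch (hypothetical, for illustration only): the defaults above
# describe a 768-wide, 28-layer encoder with 12 heads of size 768 // 12 = 64.
#
#   config = NeoBERTConfig()
#   assert config.dim_head == 64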
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(self, config: NeoBERTConfig):
        super().__init__()

        self.config = config

        # Attention
        self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
        self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)

        # Feedforward network
        multiple_of = 8
        intermediate_size = int(2 * config.intermediate_size / 3)
        intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)

        # Layer norms
        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        # Attention
        attn_output, attn_weights = self._att_block(
            self.attention_norm(x), attention_mask, freqs_cis, output_attentions
        )

        # Residual
        x = x + attn_output

        # Feed-forward
        x = x + self.ffn(self.ffn_norm(x))

        return x, attn_weights

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        batch_size, seq_len, _ = x.shape

        xq, xk, xv = self.qkv(x).view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3).chunk(3, dim=-1)

        xq, xk = apply_rotary_emb_real(xq, xk, freqs_cis)

        # Attn block
        attn_weights = None

        # Eager attention if attention weights are needed in the output
        if output_attentions:
            attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            if attention_mask is not None:
                attn_weights = attn_weights * attention_mask
            attn_weights = attn_weights.softmax(-1)
            attn = attn_weights @ xv.permute(0, 2, 1, 3)
            attn = attn.transpose(1, 2)
        # Fall back to SDPA otherwise
        else:
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
                dropout_p=0,
            ).transpose(1, 2)

        return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights
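# Shape trace through _att_block, with B = batch, L = sequence length,
# H = num_attention_heads and Dh = dim_head: qkv(x) -> (B, L, 3 * H * Dh),
# view -> (B, L, H, 3 * Dh), chunk(3, dim=-1) -> three (B, L, H, Dh) tensors;
# attention runs over (B, H, L, Dh) and wo projects the reshaped
# (B, L, H * Dh) back to the hidden size.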
class NeoBERTPreTrainedModel(PreTrainedModel):
    config_class = NeoBERTConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
        elif isinstance(module, nn.Embedding):
            module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)


class NeoBERT(NeoBERTPreTrainedModel):
    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Ensures freqs_cis is moved to the same device as the model. Non-persistent buffers are not saved in the state_dict.
        freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
        self.register_buffer("freqs_cos", freqs_cis.real, persistent=False)
        self.register_buffer("freqs_sin", freqs_cis.imag, persistent=False)

        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(EncoderBlock(config))

        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        # Initialize
        hidden_states, attentions = [], []

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)

        # RoPE
        freqs_cos = (
            self.freqs_cos[position_ids]
            if position_ids is not None
            else self.freqs_cos[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
        )
        freqs_sin = (
            self.freqs_sin[position_ids]
            if position_ids is not None
            else self.freqs_sin[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
        )

        # Embedding
        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds

        # Transformer encoder
        for layer in self.transformer_encoder:
            x, attn = layer(x, attention_mask, (freqs_cos, freqs_sin), output_attentions)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        # Final normalization layer
        x = self.layer_norm(x)

        # Return the output of the last hidden layer
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=attentions if output_attentions else None,
        )
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "chandar-lab/NeoBERT"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = NeoBERT.from_pretrained(model_name)

    # Tokenize input text
    text = [
        "NeoBERT is the most efficient model of its kind!",
        "This is really cool",
    ]
    inputs = tokenizer(text, padding=True, return_tensors="pt")

    # Generate reference embeddings with PyTorch
    with torch.no_grad():
        pytorch_outputs = model(**inputs)

    # Export to ONNX
    torch.onnx.export(
        model,
        (inputs["input_ids"], inputs["attention_mask"]),
        f="model.onnx",
        export_params=True,
        opset_version=20,
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        },
        dynamo=True,
    )

    # Validate the exported graph against the PyTorch outputs
    import onnxruntime as ort

    ort_session = ort.InferenceSession("model.onnx")
    ort_inputs = {
        "input_ids": inputs["input_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)

    assert abs(pytorch_outputs.last_hidden_state.numpy() - ort_outputs[0]).max() < 1e-3
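# Optional follow-up (sketch, with an illustrative output filename): the fp32
# graph can be shrunk further with onnxruntime's dynamic quantization, e.g.
#
#   from onnxruntime.quantization import QuantType, quantize_dynamic
#   quantize_dynamic("model.onnx", "model_quantized.onnx", weight_type=QuantType.QInt8)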