Chenghao-Qiu committed on
Commit 2134860 · verified · 1 Parent(s): 47180e3

Upload folder using huggingface_hub
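For reference, an upload like the one in this commit is typically done with the huggingface_hub client. The sketch below is illustrative only; the local folder path and repo id are placeholders, not values taken from this commit.

from huggingface_hub import HfApi

api = HfApi()
# Push every file in the local training output directory (hypothetical path)
# to a model repo in a single commit, mirroring "Upload folder using huggingface_hub".
api.upload_folder(
    folder_path="./hyena_finetune_output",   # hypothetical local folder
    repo_id="user/hyena-finetuned",          # hypothetical repo id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)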

Files changed (49)
  1. checkpoint-1000/config.json +37 -0
  2. checkpoint-1000/configuration_hyena.py +88 -0
  3. checkpoint-1000/modeling_hyena.py +574 -0
  4. checkpoint-1000/optimizer.pt +3 -0
  5. checkpoint-1000/pytorch_model.bin +3 -0
  6. checkpoint-1000/rng_state.pth +3 -0
  7. checkpoint-1000/scaler.pt +3 -0
  8. checkpoint-1000/scheduler.pt +3 -0
  9. checkpoint-1000/special_tokens_map.json +51 -0
  10. checkpoint-1000/tokenization_hyena.py +117 -0
  11. checkpoint-1000/tokenizer_config.json +72 -0
  12. checkpoint-1000/trainer_state.json +141 -0
  13. checkpoint-1000/training_args.bin +3 -0
  14. checkpoint-600/config.json +37 -0
  15. checkpoint-600/configuration_hyena.py +88 -0
  16. checkpoint-600/modeling_hyena.py +574 -0
  17. checkpoint-600/optimizer.pt +3 -0
  18. checkpoint-600/pytorch_model.bin +3 -0
  19. checkpoint-600/rng_state.pth +3 -0
  20. checkpoint-600/scaler.pt +3 -0
  21. checkpoint-600/scheduler.pt +3 -0
  22. checkpoint-600/special_tokens_map.json +51 -0
  23. checkpoint-600/tokenization_hyena.py +117 -0
  24. checkpoint-600/tokenizer_config.json +72 -0
  25. checkpoint-600/trainer_state.json +91 -0
  26. checkpoint-600/training_args.bin +3 -0
  27. checkpoint-800/config.json +37 -0
  28. checkpoint-800/configuration_hyena.py +88 -0
  29. checkpoint-800/modeling_hyena.py +574 -0
  30. checkpoint-800/optimizer.pt +3 -0
  31. checkpoint-800/pytorch_model.bin +3 -0
  32. checkpoint-800/rng_state.pth +3 -0
  33. checkpoint-800/scaler.pt +3 -0
  34. checkpoint-800/scheduler.pt +3 -0
  35. checkpoint-800/special_tokens_map.json +51 -0
  36. checkpoint-800/tokenization_hyena.py +117 -0
  37. checkpoint-800/tokenizer_config.json +72 -0
  38. checkpoint-800/trainer_state.json +116 -0
  39. checkpoint-800/training_args.bin +3 -0
  40. config.json +37 -0
  41. configuration_hyena.py +88 -0
  42. modeling_hyena.py +574 -0
  43. optimizer_state_dict.pth +3 -0
  44. pytorch_model.bin +3 -0
  45. special_tokens_map.json +51 -0
  46. tokenization_hyena.py +117 -0
  47. tokenizer_config.json +72 -0
  48. trainer_state.json +156 -0
  49. training_args.bin +3 -0
checkpoint-1000/config.json ADDED
@@ -0,0 +1,37 @@
+ {
+ "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
+ "activation_freq": 10,
+ "architectures": [
+ "HyenaDNAForSequenceClassification"
+ ],
+ "auto_map": {
+ "AutoConfig": "configuration_hyena.HyenaConfig",
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
+ },
+ "d_inner": 1024,
+ "d_model": 256,
+ "emb_dim": 5,
+ "embed_dropout": 0.1,
+ "filter_order": 64,
+ "hyena_dropout": 0.0,
+ "hyena_filter_dropout": 0.0,
+ "hyena_order": 2,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "max_seq_len": 32770,
+ "model_type": "hyenadna",
+ "n_layer": 4,
+ "num_inner_mlps": 2,
+ "pad_token_id": 4,
+ "pad_vocab_size_multiple": 8,
+ "problem_type": "single_label_classification",
+ "short_filter_order": 3,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "train_freq": true,
+ "transformers_version": "4.26.1",
+ "use_bias": true,
+ "vocab_size": 12
+ }
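Because config.json routes the custom classes through auto_map, a checkpoint laid out like this is normally loaded with trust_remote_code=True. A minimal sketch (the checkpoint path below is a placeholder, not part of this commit):

from transformers import AutoTokenizer, AutoModelForSequenceClassification

ckpt = "path/to/checkpoint-1000"  # placeholder: local folder or Hub repo id
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("ACGTACGT", return_tensors="pt")
logits = model(**inputs).logits  # shape: (batch, num_labels)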
checkpoint-1000/configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
+ from transformers import PretrainedConfig
+ import json
+
+
+ class HyenaConfig(PretrainedConfig):
+ model_type = "hyenadna"
+ def __init__(
+ self,
+ vocab_size=12,
+ d_model=256,
+ d_inner=None,
+ use_bias=True,
+ train_freq=True,
+ max_seq_len=1024,
+ emb_dim=3,
+ n_layer=12,
+ num_inner_mlps=2,
+ hyena_order=2,
+ short_filter_order=3,
+ filter_order=64,
+ activation_freq=1,
+ embed_dropout=0.1,
+ hyena_dropout=0.0,
+ hyena_filter_dropout=0.0,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ pad_vocab_size_multiple=8,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ if d_inner is None:
+ self.d_inner = 4 * d_model
+ else:
+ self.d_inner = d_inner
+ self.use_bias = use_bias
+ self.train_freq = train_freq
+ self.max_seq_len = max_seq_len
+ self.emb_dim = emb_dim
+ self.n_layer = n_layer
+ self.hyena_order = hyena_order
+ self.filter_order = filter_order
+ self.short_filter_order = short_filter_order
+ self.activation_freq = activation_freq
+ self.num_inner_mlps = num_inner_mlps
+ self.embed_dropout = embed_dropout
+ self.hyena_dropout = hyena_dropout
+ self.hyena_filter_dropout = hyena_filter_dropout
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
+ super().__init__(**kwargs)
+
+ @classmethod
+ def from_original_config(cls, config_path, **kwargs):
+ with open(config_path, "r") as f:
+ config = json.load(f)
+
+ vocab_size = config["vocab_size"]
+ d_model = config["d_model"]
+ d_inner = config["d_inner"]
+ max_seq_len = config["layer"]["l_max"]
+ emb_dim = config["layer"]["emb_dim"]
+ filter_order = config["layer"]["filter_order"]
+ if "local_order" in config["layer"]:
+ short_filter_order = config["layer"]["local_order"]
+ elif "short_filter_order" in config["layer"]:
+ short_filter_order = config["layer"]["short_filter_order"]
+ else:
+ short_filter_order = 3
+ n_layer = config["n_layer"]
+ activation_freq = config["layer"]["w"]
+ embed_dropout = config["embed_dropout"]
+ pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
+ return cls(vocab_size=vocab_size,
+ d_model=d_model,
+ d_inner=d_inner,
+ max_seq_len=max_seq_len,
+ emb_dim=emb_dim,
+ filter_order=filter_order,
+ short_filter_order=short_filter_order,
+ n_layer=n_layer,
+ activation_freq=activation_freq,
+ embed_dropout=embed_dropout,
+ pad_vocab_size_multiple=pad_vocab_size_multiple,
+ tie_word_embeddings=False,
+ **kwargs
+ )
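As a small illustration of the constructor's d_inner handling read directly from the code above: when d_inner is omitted, it falls back to 4 * d_model.

config = HyenaConfig(d_model=256)                # d_inner not given
assert config.d_inner == 4 * 256                 # defaults to 1024
config = HyenaConfig(d_model=256, d_inner=1024)  # matches the uploaded config.json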
checkpoint-1000/modeling_hyena.py ADDED
@@ -0,0 +1,574 @@
+ # -*- coding: utf-8 -*-
+ """HyenaDNA custom code port to Hugging Face Hub"""
+
+ import math
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from .configuration_hyena import HyenaConfig
+ from transformers import PreTrainedModel
+ from typing import Optional, Tuple, Union
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
+
+
+ def fftconv(u, k, D):
+ """
+ We apply a convolution through the fourier domain (from the Convolution Theorem)
+
+ """
+ seqlen = u.shape[-1]
+ fft_size = 2 * seqlen
+
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
+
+ if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
+
+ out = y + u * D.unsqueeze(-1)
+ return out.to(dtype=u.dtype)
+
+
+ @torch.jit.script
+ def mul_sum(q, y):
+ return (q * y).sum(dim=1)
+
+
+ class HyenaSin(nn.Module):
+ """The Sin activation function for the Hyena Filter function."""
+ def __init__(self, config):
+ super().__init__()
+ self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)
+
+ def forward(self, x):
+ return torch.sin(self.freq * x)
+
+
+ class HyenaPositionalEmbedding(nn.Module):
+ def __init__(self, config):
+ """Complex exponential positional embeddings for Hyena filters."""
+ super().__init__()
+
+ self.seq_len = config.max_seq_len
+ # The time embedding fed to the filteres is normalized so that t_f = 1
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
+
+ if config.emb_dim > 1:
+ bands = (config.emb_dim - 1) // 2
+ # To compute the right embeddings we use the "proper" linspace
+ t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
+ w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
+
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
+
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
+
+ self.register_buffer("z", z)
+ self.register_buffer("t", t)
+
+ def forward(self, L):
+ return self.z[:, :L], self.t[:, :L]
+
+
+ class HyenaExponentialModulation(nn.Module):
+ """The window function applied to the output of the (MLP) filter function."""
+ def __init__(
+ self,
+ d_model,
+ fast_decay_pct=0.3,
+ slow_decay_pct=1.5,
+ target=1e-2,
+ modulate: bool=True,
+ shift: float = 0.05,
+ **kwargs
+ ):
+ super().__init__()
+ self.modulate = modulate
+ self.shift = shift
+ max_decay = math.log(target) / fast_decay_pct
+ min_decay = math.log(target) / slow_decay_pct
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
+ self.register_buffer("deltas", deltas)
+
+ def forward(self, t, x):
+ if self.modulate:
+ decay = torch.exp(-t * self.deltas.abs())
+ x = x * (decay + self.shift)
+ return x
+
+
+ class HyenaFilter(nn.Module):
+ def __init__(
+ self,
+ config,
+ **kwargs
+ ):
+ """
+ Implicit long filter with modulation.
+
+ Args:
+ d_model: number of channels in the input
+ emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
+ order: width of the FFN
+ num_inner_mlps: number of inner linear layers inside filter MLP
+
+ Note:
+ filter_dropout is not implemented
+ """
+ super().__init__()
+
+ self.d_model = config.d_model * (config.hyena_order - 1)
+ self.use_bias = config.use_bias
+ self.bias = nn.Parameter(torch.randn(self.d_model))
+ self.dropout = nn.Dropout(config.hyena_filter_dropout)
+
+ act = HyenaSin(config)
+ self.emb_dim = config.emb_dim
+ assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
+ self.seq_len = config.max_seq_len
+
+ self.pos_emb = HyenaPositionalEmbedding(config)
+
+ self.implicit_filter = nn.Sequential(
+ nn.Linear(self.emb_dim, config.filter_order),
+ act,
+ )
+ for i in range(config.num_inner_mlps):
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
+ self.implicit_filter.append(act)
+
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
+
+ self.modulation = HyenaExponentialModulation(config.d_model)
+
+ self.normalized = False
+
+ def filter(self, L, *args, **kwargs):
+ z, t = self.pos_emb(L)
+ h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
+ h = self.modulation(t, h)
+ return h
+
+ def forward(self, x, L, k=None, bias=None, *args, **kwargs):
+ if k is None: k = self.filter(L)
+
+ # Ensure compatibility with filters that return a tuple
+ k = k[0] if type(k) is tuple else k
+
+ y = fftconv(x, k, bias)
+ return y
+
+
+ class HyenaOperator(nn.Module):
+ def __init__(
+ self,
+ config,
+ **filter_args,
+ ):
+ r"""
+ Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
+
+ Args:
+ d_model (int): Dimension of the input and output embeddings (width of the layer)
+ l_max: (int): Maximum input sequence length. Defaults to None
+ order: (int): Depth of the Hyena recurrence. Defaults to 2
+ dropout: (float): Dropout probability. Defaults to 0.0
+ filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
+ """
+ super().__init__()
+
+ self.d_model = config.d_model
+ self.l_max = config.max_seq_len
+ self.order = config.hyena_order
+ inner_width = config.d_model * (self.order + 1)
+ self.dropout = nn.Dropout(config.hyena_dropout)
+ self.in_proj = nn.Linear(self.d_model, inner_width)
+ self.out_proj = nn.Linear(self.d_model, self.d_model)
+
+ self.short_filter = nn.Conv1d(
+ inner_width,
+ inner_width,
+ config.short_filter_order,
+ padding=2,
+ groups=inner_width
+ )
+ self.filter_fn = HyenaFilter(config)
+
+ def forward(self, u):
+ l = u.size(-2)
+ l_filter = min(l, self.l_max)
+ u = self.in_proj(u).transpose(1, 2)
+
+ uc = self.short_filter(u)[...,:l_filter]
+ *x, v = uc.split(self.d_model, dim=1)
+
+ k = self.filter_fn.filter(l_filter)[0]
+ k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
+ bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)
+
+ for o, x_i in enumerate(reversed(x[1:])):
+ v = self.dropout(v * x_i)
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
+
+ y = (v * x[0]).transpose(1, 2)
+
+ y = self.out_proj(y)
+ return y
+
+ class HyenaMlp(nn.Module):
+
+ def __init__(self, config):
+ """
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
+ """
+ super().__init__()
+ in_features = config.d_model
+ hidden_features = config.d_inner
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.fc2 = nn.Linear(hidden_features, config.d_model)
+
+ def forward(self, x):
+ y = self.fc1(x)
+ y = F.gelu(y, approximate="tanh")
+ y = self.fc2(y)
+ return y
+
+ class HyenaBlock(nn.Module):
+
+ def __init__(self, config):
+ """
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
+ For prenorm=True, this Block has a slightly different structure compared to a regular
+ prenorm Transformer block.
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
+ [Ref: https://arxiv.org/abs/2002.04745]
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
+ the hidden_states (output of the MLP) and the residual.
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+ The residual needs to be provided (except for the very first block).
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
+ This is for performance reason: for post-norm architecture, returning the input allows us
+ to fuse the backward of nn.Linear with the residual connection.
+ """
+ super().__init__()
+ self.mixer = HyenaOperator(config)
+ self.norm1 = nn.LayerNorm(config.d_model)
+ self.mlp = HyenaMlp(config)
+ self.norm2 = nn.LayerNorm(config.d_model)
+
+ def forward(self, hidden_states):
+ r"""Pass the input through the encoder layer.
+ Args:
+ hidden_states: the sequence to the encoder layer (required).
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
+ before applying the query projection. Useful for e.g., ViT where we only care
+ about the CLS token in the last layer.
+ """
+ residual = hidden_states
+ residual = residual.to(torch.float32)
+ hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+ hidden_states = self.mixer(hyena_normed)
+ # Tested above here and all is equivalent. That means the mixer is fine!!!
+ residual = hidden_states + residual
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+ residual = residual.to(torch.float32)
+
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states + residual
+
+
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+
+
+ class HyenaEmbeddings(nn.Module):
+
+ def __init__(self, config, padding_idx=None):
+ """
+ If max_position_embeddings <= 0, there's no position embeddings
+ If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
+ the project up to embed_dim
+ """
+ super().__init__()
+ vocab_size = config.vocab_size
+ if vocab_size % config.pad_vocab_size_multiple != 0:
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
+ self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
+
+ def forward(self, input_ids):
+ """
+ input_ids: (batch, seqlen)
+ """
+ embeddings = self.word_embeddings(input_ids)
+ return embeddings
+
+ class HyenaLMBackbone(nn.Module):
+
+ def __init__(self, config) -> None:
+ super().__init__()
+ # note max_position_embeddings is 0 for Hyena, and therefore isn't used
+ self.embeddings = HyenaEmbeddings(config)
+ self.dropout = nn.Dropout(config.embed_dropout)
+
+ self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])
+
+ self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.gradient_checkpointing = False
+
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
+ all_hidden_states = []
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embeddings(input_ids)
+ if output_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ for layer in self.layers:
+ if self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
+ else:
+ hidden_states = layer(hidden_states)
+ if output_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
+ if output_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ return hidden_states, all_hidden_states
+
+
+ class HyenaDNAPreTrainedModel(PreTrainedModel):
+ config_class = HyenaConfig
+ base_model_prefix = "hyena"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["HyenaBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
+
+ def _init_weights(self, module, initializer_range=0.02):
+ if isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, std=initializer_range)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ nn.init.normal_(module.weight, std=initializer_range)
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in self.named_parameters():
+ if name in ["out_proj.weight", "fc2.weight"]:
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
+ # If using GLU activation for now, we scale the std by 2
+ elif name in ["output_linear.0.weight"]:
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
+
+
+ class HyenaDNAModel(HyenaDNAPreTrainedModel):
+ def __init__(self, config, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+
+ self.backbone = HyenaLMBackbone(config)
+ self.config = config
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
+ if return_dict:
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states if output_hidden_states else None)
+ elif output_hidden_states:
+ return hidden_states, all_hidden_states
+ else:
+ return hidden_states
+
+
+ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ self.hyena = HyenaDNAModel(config)
+ vocab_size = config.vocab_size
+ if vocab_size % config.pad_vocab_size_multiple != 0:
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
+ self.vocab_size = vocab_size
+ self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.hyena.backbone.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.hyena.backbone.embeddings.word_embeddings = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.hyena = decoder
+
+ def get_decoder(self):
+ return self.hyena
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.hyena(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+ class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
+ self.hyena = HyenaDNAModel(config)
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.hyena.backbone.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.hyena.backbone.embeddings.word_embeddings = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.hyena(
+ input_ids,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+ logits.device
+ )
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = nn.MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = nn.BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=pooled_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ )
checkpoint-1000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ef7e54ba1b263d9091a3b54061ee508313c2e29a1b4fd6c4a456699aab93ff5
+ size 26304517
checkpoint-1000/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a980d905b63d647cd0ba855443b706a1ccbbe69153ae04b1954150e780ece36d
+ size 16300157
checkpoint-1000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d5773d772a88e414bc446a704f98e2303c1d836815da2d0e7d245a164163118d
+ size 14575
checkpoint-1000/scaler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46b7af93ae41ac3bdc5a501bc85a1111baad8d4df217a9776b7435044a9d320a
+ size 557
checkpoint-1000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ebe2cff767413b62b8cb3e5d55cf4f3e637c678e963eb9059d55891b8da69e4
+ size 627
checkpoint-1000/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+ {
+ "bos_token": {
+ "content": "[BOS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "cls_token": {
+ "content": "[CLS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "[SEP]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "mask_token": {
+ "content": "[MASK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "sep_token": {
+ "content": "[SEP]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "[UNK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
checkpoint-1000/tokenization_hyena.py ADDED
@@ -0,0 +1,117 @@
+ from transformers import PreTrainedTokenizer, AddedToken
+ from typing import List, Optional, Union, Dict, Sequence, Tuple
+ from pathlib import Path
+ import json
+ import os
+
+
+ class HyenaDNATokenizer(PreTrainedTokenizer):
+ model_input_names = ["input_ids"]
+
+ def __init__(self,
+ model_max_length: int,
+ bos_token="[BOS]",
+ eos_token="[SEP]",
+ sep_token="[SEP]",
+ cls_token="[CLS]",
+ pad_token="[PAD]",
+ mask_token="[MASK]",
+ unk_token="[UNK]",
+ **kwargs):
+ """Character tokenizer for Hugging Face transformers.
+ Args:
+ characters (Sequence[str]): List of desired characters. Any character which
+ is not included in this list will be replaced by a special token called
+ [UNK] with id=6. Following are list of all of the special tokens with
+ their corresponding ids:
+ "[CLS]": 0
+ "[SEP]": 1
+ "[BOS]": 2
+ "[MASK]": 3
+ "[PAD]": 4
+ "[RESERVED]": 5
+ "[UNK]": 6
+ an id (starting at 7) will be assigned to each character.
+ model_max_length (int): Model maximum sequence length.
+ """
+ self.characters = ('A', 'C', 'G', 'T', 'N')
+ self.model_max_length = model_max_length
+
+ self._vocab_str_to_int = {
+ "[CLS]": 0,
+ "[SEP]": 1,
+ "[BOS]": 2,
+ "[MASK]": 3,
+ "[PAD]": 4,
+ "[RESERVED]": 5,
+ "[UNK]": 6,
+ **{ch: i + 7 for i, ch in enumerate(self.characters)},
+ }
+ self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
+ add_prefix_space = kwargs.pop("add_prefix_space", False)
+ padding_side = kwargs.pop("padding_side", "left")
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ unk_token=unk_token,
+ add_prefix_space=add_prefix_space,
+ model_max_length=model_max_length,
+ padding_side=padding_side,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self._vocab_str_to_int)
+
+ def _tokenize(self, text: str) -> List[str]:
+ return list(text)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
+
+ def _convert_id_to_token(self, index: int) -> str:
+ return self._vocab_int_to_str[index]
+
+ def convert_tokens_to_string(self, tokens):
+ return "".join(tokens)
+
+ def get_special_tokens_mask(
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None,
+ already_has_special_tokens: bool = False,
+ ) -> List[int]:
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0,
+ token_ids_1=token_ids_1,
+ already_has_special_tokens=True,
+ )
+
+ result = ([0] * len(token_ids_0)) + [1]
+ if token_ids_1 is not None:
+ result += ([0] * len(token_ids_1)) + [1]
+ return result
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ sep = [self.sep_token_id]
+ # cls = [self.cls_token_id]
+ result = token_ids_0 + sep
+ if token_ids_1 is not None:
+ result += token_ids_1 + sep
+ return result
+
+ def get_vocab(self) -> Dict[str, int]:
+ return self._vocab_str_to_int
+
+ # HyenaDNA has a fixed vocabulary with no vocab file
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
+ return ()
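The class above is a plain character tokenizer over A/C/G/T/N with fixed ids; a small usage sketch, with the expected ids read off the _vocab_str_to_int table (a [SEP] id is appended by build_inputs_with_special_tokens):

tok = HyenaDNATokenizer(model_max_length=32770)
enc = tok("ACGTN")
print(enc["input_ids"])  # expected per the vocab table: [7, 8, 9, 10, 11, 1]
print(tok.decode(enc["input_ids"], skip_special_tokens=True))  # "ACGTN"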
checkpoint-1000/tokenizer_config.json ADDED
@@ -0,0 +1,72 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "[CLS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "[SEP]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "[BOS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "3": {
+ "content": "[MASK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "4": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "6": {
+ "content": "[UNK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "auto_map": {
+ "AutoTokenizer": [
+ "tokenization_hyena.HyenaDNATokenizer",
+ null
+ ]
+ },
+ "bos_token": "[BOS]",
+ "clean_up_tokenization_spaces": true,
+ "cls_token": "[CLS]",
+ "eos_token": "[SEP]",
+ "mask_token": "[MASK]",
+ "model_max_length": 256,
+ "name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
+ "pad_token": "[PAD]",
+ "padding_side": "right",
+ "sep_token": "[SEP]",
+ "special_tokens_map_file": "/home/hlv8980/.cache/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/special_tokens_map.json",
+ "tokenizer_class": "HyenaDNATokenizer",
+ "unk_token": "[UNK]"
+ }
checkpoint-1000/trainer_state.json ADDED
@@ -0,0 +1,141 @@
+ {
+ "best_metric": 0.39216598868370056,
+ "best_model_checkpoint": "/scratch/hlv8980/Attack_Benchmark/models/hyena/tf4/origin/checkpoint-600",
+ "epoch": 3.3670033670033668,
+ "global_step": 1000,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.34,
+ "learning_rate": 2.8760984182776802e-05,
+ "loss": 0.5992,
+ "step": 100
+ },
+ {
+ "epoch": 0.67,
+ "learning_rate": 2.615114235500879e-05,
+ "loss": 0.4813,
+ "step": 200
+ },
+ {
+ "epoch": 0.67,
+ "eval_accuracy": 0.774,
+ "eval_f1": 0.7713328260834371,
+ "eval_loss": 0.48207539319992065,
+ "eval_matthews_correlation": 0.5579338694412199,
+ "eval_precision": 0.785067107786007,
+ "eval_recall": 0.772997299729973,
+ "eval_runtime": 0.1057,
+ "eval_samples_per_second": 9462.679,
+ "eval_steps_per_second": 151.403,
+ "step": 200
+ },
+ {
+ "epoch": 1.01,
+ "learning_rate": 2.3514938488576452e-05,
+ "loss": 0.4431,
+ "step": 300
+ },
+ {
+ "epoch": 1.35,
+ "learning_rate": 2.087873462214411e-05,
+ "loss": 0.377,
+ "step": 400
+ },
+ {
+ "epoch": 1.35,
+ "eval_accuracy": 0.816,
+ "eval_f1": 0.8159933757615274,
+ "eval_loss": 0.427643358707428,
+ "eval_matthews_correlation": 0.6320128653971173,
+ "eval_precision": 0.815991263965056,
+ "eval_recall": 0.816021602160216,
+ "eval_runtime": 0.1039,
+ "eval_samples_per_second": 9625.637,
+ "eval_steps_per_second": 154.01,
+ "step": 400
+ },
+ {
+ "epoch": 1.68,
+ "learning_rate": 1.82688927943761e-05,
+ "loss": 0.3443,
+ "step": 500
+ },
+ {
+ "epoch": 2.02,
+ "learning_rate": 1.563268892794376e-05,
+ "loss": 0.33,
+ "step": 600
+ },
+ {
+ "epoch": 2.02,
+ "eval_accuracy": 0.824,
+ "eval_f1": 0.8239746523499383,
+ "eval_loss": 0.39216598868370056,
+ "eval_matthews_correlation": 0.6479558982194922,
+ "eval_precision": 0.8239935027265344,
+ "eval_recall": 0.8239623962396239,
+ "eval_runtime": 0.1031,
+ "eval_samples_per_second": 9696.512,
+ "eval_steps_per_second": 155.144,
+ "step": 600
+ },
+ {
+ "epoch": 2.36,
+ "learning_rate": 1.2996485061511423e-05,
+ "loss": 0.227,
+ "step": 700
+ },
+ {
+ "epoch": 2.69,
+ "learning_rate": 1.0360281195079087e-05,
+ "loss": 0.2219,
+ "step": 800
+ },
+ {
+ "epoch": 2.69,
+ "eval_accuracy": 0.838,
+ "eval_f1": 0.8379766686402841,
+ "eval_loss": 0.4026987850666046,
+ "eval_matthews_correlation": 0.6767651028795362,
+ "eval_precision": 0.8385613769517563,
+ "eval_recall": 0.8382038203820381,
+ "eval_runtime": 0.1026,
+ "eval_samples_per_second": 9746.534,
+ "eval_steps_per_second": 155.945,
+ "step": 800
+ },
+ {
+ "epoch": 3.03,
+ "learning_rate": 7.724077328646749e-06,
+ "loss": 0.2121,
+ "step": 900
+ },
+ {
+ "epoch": 3.37,
+ "learning_rate": 5.087873462214412e-06,
+ "loss": 0.1388,
+ "step": 1000
+ },
+ {
+ "epoch": 3.37,
+ "eval_accuracy": 0.857,
+ "eval_f1": 0.8566558306493891,
+ "eval_loss": 0.393052339553833,
+ "eval_matthews_correlation": 0.7159331394438886,
+ "eval_precision": 0.859342750257998,
+ "eval_recall": 0.8565956595659566,
+ "eval_runtime": 0.1036,
+ "eval_samples_per_second": 9656.574,
+ "eval_steps_per_second": 154.505,
+ "step": 1000
+ }
+ ],
+ "max_steps": 1188,
+ "num_train_epochs": 4,
+ "total_flos": 128187317035008.0,
+ "trial_name": null,
+ "trial_params": null
+ }
checkpoint-1000/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f4e3a85efd6ca1a2228fc0bf6f5ca43150a2981352327acbf61aa5be7e43d49
+ size 3707
checkpoint-600/config.json ADDED
@@ -0,0 +1,37 @@
+ {
+ "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
+ "activation_freq": 10,
+ "architectures": [
+ "HyenaDNAForSequenceClassification"
+ ],
+ "auto_map": {
+ "AutoConfig": "configuration_hyena.HyenaConfig",
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
+ },
+ "d_inner": 1024,
+ "d_model": 256,
+ "emb_dim": 5,
+ "embed_dropout": 0.1,
+ "filter_order": 64,
+ "hyena_dropout": 0.0,
+ "hyena_filter_dropout": 0.0,
+ "hyena_order": 2,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "max_seq_len": 32770,
+ "model_type": "hyenadna",
+ "n_layer": 4,
+ "num_inner_mlps": 2,
+ "pad_token_id": 4,
+ "pad_vocab_size_multiple": 8,
+ "problem_type": "single_label_classification",
+ "short_filter_order": 3,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "train_freq": true,
+ "transformers_version": "4.26.1",
+ "use_bias": true,
+ "vocab_size": 12
+ }
checkpoint-600/configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
+ from transformers import PretrainedConfig
+ import json
+
+
+ class HyenaConfig(PretrainedConfig):
+ model_type = "hyenadna"
+ def __init__(
+ self,
+ vocab_size=12,
+ d_model=256,
+ d_inner=None,
+ use_bias=True,
+ train_freq=True,
+ max_seq_len=1024,
+ emb_dim=3,
+ n_layer=12,
+ num_inner_mlps=2,
+ hyena_order=2,
+ short_filter_order=3,
+ filter_order=64,
+ activation_freq=1,
+ embed_dropout=0.1,
+ hyena_dropout=0.0,
+ hyena_filter_dropout=0.0,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ pad_vocab_size_multiple=8,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ if d_inner is None:
+ self.d_inner = 4 * d_model
+ else:
+ self.d_inner = d_inner
+ self.use_bias = use_bias
+ self.train_freq = train_freq
+ self.max_seq_len = max_seq_len
+ self.emb_dim = emb_dim
+ self.n_layer = n_layer
+ self.hyena_order = hyena_order
+ self.filter_order = filter_order
+ self.short_filter_order = short_filter_order
+ self.activation_freq = activation_freq
+ self.num_inner_mlps = num_inner_mlps
+ self.embed_dropout = embed_dropout
+ self.hyena_dropout = hyena_dropout
+ self.hyena_filter_dropout = hyena_filter_dropout
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
+ super().__init__(**kwargs)
+
+ @classmethod
+ def from_original_config(cls, config_path, **kwargs):
+ with open(config_path, "r") as f:
+ config = json.load(f)
+
+ vocab_size = config["vocab_size"]
+ d_model = config["d_model"]
+ d_inner = config["d_inner"]
+ max_seq_len = config["layer"]["l_max"]
+ emb_dim = config["layer"]["emb_dim"]
+ filter_order = config["layer"]["filter_order"]
+ if "local_order" in config["layer"]:
+ short_filter_order = config["layer"]["local_order"]
+ elif "short_filter_order" in config["layer"]:
+ short_filter_order = config["layer"]["short_filter_order"]
+ else:
+ short_filter_order = 3
+ n_layer = config["n_layer"]
+ activation_freq = config["layer"]["w"]
+ embed_dropout = config["embed_dropout"]
+ pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
+ return cls(vocab_size=vocab_size,
+ d_model=d_model,
+ d_inner=d_inner,
+ max_seq_len=max_seq_len,
+ emb_dim=emb_dim,
+ filter_order=filter_order,
+ short_filter_order=short_filter_order,
+ n_layer=n_layer,
+ activation_freq=activation_freq,
+ embed_dropout=embed_dropout,
+ pad_vocab_size_multiple=pad_vocab_size_multiple,
+ tie_word_embeddings=False,
+ **kwargs
+ )
checkpoint-600/modeling_hyena.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HyenaDNA custom code port to Hugging Face Hub"""
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from .configuration_hyena import HyenaConfig
9
+ from transformers import PreTrainedModel
10
+ from typing import Optional, Tuple, Union
11
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
12
+
13
+
14
+ def fftconv(u, k, D):
15
+ """
16
+ We apply a convolution through the fourier domain (from the Convolution Theorem)
17
+
18
+ """
19
+ seqlen = u.shape[-1]
20
+ fft_size = 2 * seqlen
21
+
22
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
24
+
25
+ if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
27
+
28
+ out = y + u * D.unsqueeze(-1)
29
+ return out.to(dtype=u.dtype)
30
+
31
+
32
+ @torch.jit.script
33
+ def mul_sum(q, y):
34
+ return (q * y).sum(dim=1)
35
+
36
+
37
+ class HyenaSin(nn.Module):
38
+ """The Sin activation function for the Hyena Filter function."""
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)
42
+
43
+ def forward(self, x):
44
+ return torch.sin(self.freq * x)
45
+
46
+
47
+ class HyenaPositionalEmbedding(nn.Module):
48
+ def __init__(self, config):
49
+ """Complex exponential positional embeddings for Hyena filters."""
50
+ super().__init__()
51
+
52
+ self.seq_len = config.max_seq_len
53
+ # The time embedding fed to the filteres is normalized so that t_f = 1
54
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
55
+
56
+ if config.emb_dim > 1:
57
+ bands = (config.emb_dim - 1) // 2
58
+ # To compute the right embeddings we use the "proper" linspace
59
+ t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
60
+ w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
+
62
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+
64
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
+
66
+ self.register_buffer("z", z)
67
+ self.register_buffer("t", t)
68
+
69
+ def forward(self, L):
70
+ return self.z[:, :L], self.t[:, :L]
71
+
72
+
73
+ class HyenaExponentialModulation(nn.Module):
74
+ """The window function applied to the output of the (MLP) filter function."""
75
+ def __init__(
76
+ self,
77
+ d_model,
78
+ fast_decay_pct=0.3,
79
+ slow_decay_pct=1.5,
80
+ target=1e-2,
81
+ modulate: bool=True,
82
+ shift: float = 0.05,
83
+ **kwargs
84
+ ):
85
+ super().__init__()
86
+ self.modulate = modulate
87
+ self.shift = shift
88
+ max_decay = math.log(target) / fast_decay_pct
89
+ min_decay = math.log(target) / slow_decay_pct
90
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
91
+ self.register_buffer("deltas", deltas)
92
+
93
+ def forward(self, t, x):
94
+ if self.modulate:
95
+ decay = torch.exp(-t * self.deltas.abs())
96
+ x = x * (decay + self.shift)
97
+ return x
98
+
99
+
100
+ class HyenaFilter(nn.Module):
101
+ def __init__(
102
+ self,
103
+ config,
104
+ **kwargs
105
+ ):
106
+ """
107
+ Implicit long filter with modulation.
108
+
109
+ Args:
110
+ d_model: number of channels in the input
111
+ emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
112
+ order: width of the FFN
113
+ num_inner_mlps: number of inner linear layers inside filter MLP
114
+
115
+ Note:
116
+ filter_dropout is not implemented
117
+ """
118
+ super().__init__()
119
+
120
+ self.d_model = config.d_model * (config.hyena_order - 1)
121
+ self.use_bias = config.use_bias
122
+ self.bias = nn.Parameter(torch.randn(self.d_model))
123
+ self.dropout = nn.Dropout(config.hyena_filter_dropout)
124
+
125
+ act = HyenaSin(config)
126
+ self.emb_dim = config.emb_dim
127
+ assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
128
+ self.seq_len = config.max_seq_len
129
+
130
+ self.pos_emb = HyenaPositionalEmbedding(config)
131
+
132
+ self.implicit_filter = nn.Sequential(
133
+ nn.Linear(self.emb_dim, config.filter_order),
134
+ act,
135
+ )
136
+ for i in range(config.num_inner_mlps):
137
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
138
+ self.implicit_filter.append(act)
139
+
140
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
141
+
142
+ self.modulation = HyenaExponentialModulation(config.d_model)
143
+
144
+ self.normalized = False
145
+
146
+ def filter(self, L, *args, **kwargs):
147
+ z, t = self.pos_emb(L)
148
+ h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
149
+ h = self.modulation(t, h)
150
+ return h
151
+
152
+ def forward(self, x, L, k=None, bias=None, *args, **kwargs):
153
+ if k is None: k = self.filter(L)
154
+
155
+ # Ensure compatibility with filters that return a tuple
156
+ k = k[0] if type(k) is tuple else k
157
+
158
+ y = fftconv(x, k, bias)
159
+ return y
160
+
161
+
162
+ class HyenaOperator(nn.Module):
163
+ def __init__(
164
+ self,
165
+ config,
166
+ **filter_args,
167
+ ):
168
+ r"""
169
+ Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
170
+
171
+ Args:
172
+ d_model (int): Dimension of the input and output embeddings (width of the layer)
173
+ l_max: (int): Maximum input sequence length. Defaults to None
174
+ order: (int): Depth of the Hyena recurrence. Defaults to 2
175
+ dropout: (float): Dropout probability. Defaults to 0.0
176
+ filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
177
+ """
178
+ super().__init__()
179
+
180
+ self.d_model = config.d_model
181
+ self.l_max = config.max_seq_len
182
+ self.order = config.hyena_order
183
+ inner_width = config.d_model * (self.order + 1)
184
+ self.dropout = nn.Dropout(config.hyena_dropout)
185
+ self.in_proj = nn.Linear(self.d_model, inner_width)
186
+ self.out_proj = nn.Linear(self.d_model, self.d_model)
187
+
188
+ self.short_filter = nn.Conv1d(
189
+ inner_width,
190
+ inner_width,
191
+ config.short_filter_order,
192
+ padding=2,
193
+ groups=inner_width
194
+ )
195
+ self.filter_fn = HyenaFilter(config)
196
+
197
+ def forward(self, u):
198
+ l = u.size(-2)
199
+ l_filter = min(l, self.l_max)
200
+ u = self.in_proj(u).transpose(1, 2)
201
+
202
+ uc = self.short_filter(u)[...,:l_filter]
203
+ *x, v = uc.split(self.d_model, dim=1)
204
+
205
+ k = self.filter_fn.filter(l_filter)[0]
206
+ k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
207
+ bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)
208
+
209
+ for o, x_i in enumerate(reversed(x[1:])):
210
+ v = self.dropout(v * x_i)
211
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
212
+
213
+ y = (v * x[0]).transpose(1, 2)
214
+
215
+ y = self.out_proj(y)
216
+ return y
217
+
218
+ class HyenaMlp(nn.Module):
219
+
220
+ def __init__(self, config):
221
+ """
222
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
223
+ """
224
+ super().__init__()
225
+ in_features = config.d_model
226
+ hidden_features = config.d_inner
227
+ self.fc1 = nn.Linear(in_features, hidden_features)
228
+ self.fc2 = nn.Linear(hidden_features, config.d_model)
229
+
230
+ def forward(self, x):
231
+ y = self.fc1(x)
232
+ y = F.gelu(y, approximate="tanh")
233
+ y = self.fc2(y)
234
+ return y
235
+
236
+ class HyenaBlock(nn.Module):
237
+
238
+ def __init__(self, config):
239
+ """
240
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
241
+ For prenorm=True, this Block has a slightly different structure compared to a regular
242
+ prenorm Transformer block.
243
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
244
+ [Ref: https://arxiv.org/abs/2002.04745]
245
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
246
+ the hidden_states (output of the MLP) and the residual.
247
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
248
+ The residual needs to be provided (except for the very first block).
249
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
250
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
251
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
252
+ This is for performance reason: for post-norm architecture, returning the input allows us
253
+ to fuse the backward of nn.Linear with the residual connection.
254
+ """
255
+ super().__init__()
256
+ self.mixer = HyenaOperator(config)
257
+ self.norm1 = nn.LayerNorm(config.d_model)
258
+ self.mlp = HyenaMlp(config)
259
+ self.norm2 = nn.LayerNorm(config.d_model)
260
+
261
+ def forward(self, hidden_states):
262
+ r"""Pass the input through the encoder layer.
263
+ Args:
264
+ hidden_states: the sequence to the encoder layer (required).
265
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
266
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
267
+ before applying the query projection. Useful for e.g., ViT where we only care
268
+ about the CLS token in the last layer.
269
+ """
270
+ residual = hidden_states
271
+ residual = residual.to(torch.float32)
272
+ hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
273
+ hidden_states = self.mixer(hyena_normed)
274
+ # Residual connection around the Hyena mixer output
275
+ residual = hidden_states + residual
276
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
277
+ residual = residual.to(torch.float32)
278
+
279
+ hidden_states = self.mlp(hidden_states)
280
+ return hidden_states + residual
281
+
282
+
283
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
284
+
285
+
286
+ class HyenaEmbeddings(nn.Module):
287
+
288
+ def __init__(self, config, padding_idx=None):
289
+ """
290
+ If max_position_embeddings <= 0, there are no position embeddings
291
+ If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
292
+ then project up to embed_dim
293
+ """
294
+ super().__init__()
295
+ vocab_size = config.vocab_size
296
+ if vocab_size % config.pad_vocab_size_multiple != 0:
297
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
298
+ self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
299
+
300
+ def forward(self, input_ids):
301
+ """
302
+ input_ids: (batch, seqlen)
303
+ """
304
+ embeddings = self.word_embeddings(input_ids)
305
+ return embeddings
306
+
307
+ class HyenaLMBackbone(nn.Module):
308
+
309
+ def __init__(self, config) -> None:
310
+ super().__init__()
311
+ # note max_position_embeddings is 0 for Hyena, and therefore isn't used
312
+ self.embeddings = HyenaEmbeddings(config)
313
+ self.dropout = nn.Dropout(config.embed_dropout)
314
+
315
+ self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])
316
+
317
+ self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
321
+ all_hidden_states = []
322
+ if inputs_embeds is not None:
323
+ hidden_states = inputs_embeds
324
+ else:
325
+ hidden_states = self.embeddings(input_ids)
326
+ if output_hidden_states:
327
+ all_hidden_states.append(hidden_states)
328
+
329
+ for layer in self.layers:
330
+ if self.gradient_checkpointing and self.training:
331
+ hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
332
+ else:
333
+ hidden_states = layer(hidden_states)
334
+ if output_hidden_states:
335
+ all_hidden_states.append(hidden_states)
336
+
337
+ hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
338
+ if output_hidden_states:
339
+ all_hidden_states.append(hidden_states)
340
+
341
+ return hidden_states, all_hidden_states
342
+
343
+
344
+ class HyenaDNAPreTrainedModel(PreTrainedModel):
345
+ config_class = HyenaConfig
346
+ base_model_prefix = "hyena"
347
+ supports_gradient_checkpointing = True
348
+ _no_split_modules = ["HyenaBlock"]
349
+ _skip_keys_device_placement = "past_key_values"
350
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
351
+
352
+ def _init_weights(self, module, initializer_range=0.02):
353
+ if isinstance(module, nn.Linear):
354
+ nn.init.normal_(module.weight, std=initializer_range)
355
+ if module.bias is not None:
356
+ nn.init.zeros_(module.bias)
357
+ elif isinstance(module, nn.Embedding):
358
+ nn.init.normal_(module.weight, std=initializer_range)
359
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
360
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
361
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
362
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
363
+ #
364
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
365
+ for name, p in self.named_parameters():
366
+ if name in ["out_proj.weight", "fc2.weight"]:
367
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
368
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
369
+ # If using GLU activation for now, we scale the std by 2
370
+ elif name in ["output_linear.0.weight"]:
371
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
372
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
373
+
374
+
375
+ class HyenaDNAModel(HyenaDNAPreTrainedModel):
376
+ def __init__(self, config, **kwargs) -> None:
377
+ super().__init__(config, **kwargs)
378
+
379
+ self.backbone = HyenaLMBackbone(config)
380
+ self.config = config
381
+
382
+ # Initialize weights and apply final processing
383
+ self.post_init()
384
+
385
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
386
+ output_hidden_states = (
387
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
388
+ )
389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
390
+
391
+ hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
392
+ if return_dict:
393
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states,
394
+ hidden_states=all_hidden_states if output_hidden_states else None)
395
+ elif output_hidden_states:
396
+ return hidden_states, all_hidden_states
397
+ else:
398
+ return hidden_states
399
+
400
+
401
+ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
402
+
403
+ def __init__(self, config, **kwargs):
404
+ super().__init__(config, **kwargs)
405
+ self.hyena = HyenaDNAModel(config)
406
+ vocab_size = config.vocab_size
407
+ if vocab_size % config.pad_vocab_size_multiple != 0:
408
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
409
+ self.vocab_size = vocab_size
410
+ self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
411
+
412
+ # Initialize weights and apply final processing
413
+ self.post_init()
414
+
415
+ def get_input_embeddings(self):
416
+ return self.hyena.backbone.embeddings.word_embeddings
417
+
418
+ def set_input_embeddings(self, value):
419
+ self.hyena.backbone.embeddings.word_embeddings = value
420
+
421
+ def get_output_embeddings(self):
422
+ return self.lm_head
423
+
424
+ def set_output_embeddings(self, new_embeddings):
425
+ self.lm_head = new_embeddings
426
+
427
+ def set_decoder(self, decoder):
428
+ self.hyena = decoder
429
+
430
+ def get_decoder(self):
431
+ return self.hyena
432
+
433
+ def forward(
434
+ self,
435
+ input_ids: torch.LongTensor = None,
436
+ inputs_embeds: Optional[torch.FloatTensor] = None,
437
+ labels: Optional[torch.LongTensor] = None,
438
+ output_hidden_states: Optional[bool] = None,
439
+ return_dict: Optional[bool] = None,
440
+ ) -> Union[Tuple, CausalLMOutput]:
441
+
442
+ output_hidden_states = (
443
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
444
+ )
445
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
446
+
447
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
448
+ outputs = self.hyena(
449
+ input_ids=input_ids,
450
+ inputs_embeds=inputs_embeds,
451
+ output_hidden_states=output_hidden_states,
452
+ return_dict=return_dict,
453
+ )
454
+
455
+ hidden_states = outputs[0]
456
+ logits = self.lm_head(hidden_states)
457
+ logits = logits.float()
458
+
459
+ loss = None
460
+ if labels is not None:
461
+ # Shift so that tokens < n predict n
462
+ shift_logits = logits[..., :-1, :].contiguous()
463
+ shift_labels = labels[..., 1:].contiguous()
464
+ # Flatten the tokens
465
+ loss_fct = nn.CrossEntropyLoss()
466
+ shift_logits = shift_logits.view(-1, self.vocab_size)
467
+ shift_labels = shift_labels.view(-1)
468
+ # Enable model parallelism
469
+ shift_labels = shift_labels.to(shift_logits.device)
470
+ loss = loss_fct(shift_logits, shift_labels)
471
+
472
+ if not return_dict:
473
+ output = (logits,) + outputs[1:]
474
+ return (loss,) + output if loss is not None else output
475
+
476
+ return CausalLMOutput(
477
+ loss=loss,
478
+ logits=logits,
479
+ hidden_states=outputs.hidden_states,
480
+ )
481
+
482
+
483
+ class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
484
+ def __init__(self, config, **kwargs):
485
+ super().__init__(config, **kwargs)
486
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
487
+ self.hyena = HyenaDNAModel(config)
488
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.hyena.backbone.embeddings.word_embeddings
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.hyena.backbone.embeddings.word_embeddings = value
498
+
499
+ def forward(
500
+ self,
501
+ input_ids: torch.LongTensor = None,
502
+ inputs_embeds: Optional[torch.FloatTensor] = None,
503
+ labels: Optional[torch.LongTensor] = None,
504
+ output_hidden_states: Optional[bool] = None,
505
+ return_dict: Optional[bool] = None,
506
+ ) -> Union[Tuple, SequenceClassifierOutput]:
507
+ r"""
508
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
509
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
510
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
511
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
512
+ """
513
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
514
+
515
+ transformer_outputs = self.hyena(
516
+ input_ids,
517
+ inputs_embeds=inputs_embeds,
518
+ output_hidden_states=output_hidden_states,
519
+ return_dict=return_dict,
520
+ )
521
+ hidden_states = transformer_outputs[0]
522
+ logits = self.score(hidden_states)
523
+
524
+ if input_ids is not None:
525
+ batch_size = input_ids.shape[0]
526
+ else:
527
+ batch_size = inputs_embeds.shape[0]
528
+
529
+ if self.config.pad_token_id is None and batch_size != 1:
530
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
531
+ if self.config.pad_token_id is None:
532
+ sequence_lengths = -1
533
+ else:
534
+ if input_ids is not None:
535
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
536
+ logits.device
537
+ )
538
+ else:
539
+ sequence_lengths = -1
540
+
541
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
542
+
543
+ loss = None
544
+ if labels is not None:
545
+ labels = labels.to(logits.device)
546
+ if self.config.problem_type is None:
547
+ if self.num_labels == 1:
548
+ self.config.problem_type = "regression"
549
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
550
+ self.config.problem_type = "single_label_classification"
551
+ else:
552
+ self.config.problem_type = "multi_label_classification"
553
+
554
+ if self.config.problem_type == "regression":
555
+ loss_fct = nn.MSELoss()
556
+ if self.num_labels == 1:
557
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
558
+ else:
559
+ loss = loss_fct(pooled_logits, labels)
560
+ elif self.config.problem_type == "single_label_classification":
561
+ loss_fct = nn.CrossEntropyLoss()
562
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
563
+ elif self.config.problem_type == "multi_label_classification":
564
+ loss_fct = nn.BCEWithLogitsLoss()
565
+ loss = loss_fct(pooled_logits, labels)
566
+ if not return_dict:
567
+ output = (pooled_logits,) + transformer_outputs[1:]
568
+ return ((loss,) + output) if loss is not None else output
569
+
570
+ return SequenceClassifierOutput(
571
+ loss=loss,
572
+ logits=pooled_logits,
573
+ hidden_states=transformer_outputs.hidden_states,
574
+ )
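
Note: the classification head above pools the logit at the last non-padding token (GPT-2 style), so a fine-tuned checkpoint from this repo can be used directly through the Auto classes. A minimal usage sketch, assuming a local copy of one of the checkpoint folders; trust_remote_code=True is required because the model classes live in modeling_hyena.py rather than in transformers:

# Minimal sketch: run the sequence classifier shipped with this checkpoint.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

ckpt = "./checkpoint-600"  # local path to one of the checkpoint folders above
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(ckpt, trust_remote_code=True)
model.eval()

ids = tokenizer("ACGTACGTACGT", return_tensors="pt")["input_ids"]
with torch.no_grad():
    logits = model(input_ids=ids).logits  # shape (1, num_labels)
print(logits.argmax(-1).item())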
checkpoint-600/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:275d48d3b7464b79d184fe1244342cca6f3d8e813f5d8f03461334b719d4c105
3
+ size 26304517
checkpoint-600/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dbaff97807d7961eb4208709c13b67c9b7fb92e5f7418a96202fef0ae7e5dd5
3
+ size 16300157
checkpoint-600/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cde37b1d59ec8275aa116b67c6503ca997bd9a90886b106af365cbb8dfef4db9
3
+ size 14575
checkpoint-600/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d21d167def712bc3f600520b89e654492b844aa18ab3c37e7c0d3d698a2a65b5
3
+ size 557
checkpoint-600/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af037f192434655ed417b8694a7888adca54958a4cd3e44898830064a0a7c1ee
3
+ size 627
checkpoint-600/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "[BOS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "[CLS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "[SEP]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "[MASK]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "[SEP]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
checkpoint-600/tokenization_hyena.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedTokenizer, AddedToken
2
+ from typing import List, Optional, Union, Dict, Sequence, Tuple
3
+ from pathlib import Path
4
+ import json
5
+ import os
6
+
7
+
8
+ class HyenaDNATokenizer(PreTrainedTokenizer):
9
+ model_input_names = ["input_ids"]
10
+
11
+ def __init__(self,
12
+ model_max_length: int,
13
+ bos_token="[BOS]",
14
+ eos_token="[SEP]",
15
+ sep_token="[SEP]",
16
+ cls_token="[CLS]",
17
+ pad_token="[PAD]",
18
+ mask_token="[MASK]",
19
+ unk_token="[UNK]",
20
+ **kwargs):
21
+ """Character tokenizer for Hugging Face transformers.
22
+ Args:
23
+ characters (Sequence[str]): List of desired characters. Any character which
24
+ is not included in this list will be replaced by a special token called
25
+ [UNK] with id=6. The following is a list of all of the special tokens with
26
+ their corresponding ids:
27
+ "[CLS]": 0
28
+ "[SEP]": 1
29
+ "[BOS]": 2
30
+ "[MASK]": 3
31
+ "[PAD]": 4
32
+ "[RESERVED]": 5
33
+ "[UNK]": 6
34
+ an id (starting at 7) will be assigned to each character.
35
+ model_max_length (int): Model maximum sequence length.
36
+ """
37
+ self.characters = ('A', 'C', 'G', 'T', 'N')
38
+ self.model_max_length = model_max_length
39
+
40
+ self._vocab_str_to_int = {
41
+ "[CLS]": 0,
42
+ "[SEP]": 1,
43
+ "[BOS]": 2,
44
+ "[MASK]": 3,
45
+ "[PAD]": 4,
46
+ "[RESERVED]": 5,
47
+ "[UNK]": 6,
48
+ **{ch: i + 7 for i, ch in enumerate(self.characters)},
49
+ }
50
+ self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
51
+ add_prefix_space = kwargs.pop("add_prefix_space", False)
52
+ padding_side = kwargs.pop("padding_side", "left")
53
+
54
+ super().__init__(
55
+ bos_token=bos_token,
56
+ eos_token=eos_token,
57
+ sep_token=sep_token,
58
+ cls_token=cls_token,
59
+ pad_token=pad_token,
60
+ mask_token=mask_token,
61
+ unk_token=unk_token,
62
+ add_prefix_space=add_prefix_space,
63
+ model_max_length=model_max_length,
64
+ padding_side=padding_side,
65
+ **kwargs,
66
+ )
67
+
68
+ @property
69
+ def vocab_size(self) -> int:
70
+ return len(self._vocab_str_to_int)
71
+
72
+ def _tokenize(self, text: str) -> List[str]:
73
+ return list(text)
74
+
75
+ def _convert_token_to_id(self, token: str) -> int:
76
+ return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
77
+
78
+ def _convert_id_to_token(self, index: int) -> str:
79
+ return self._vocab_int_to_str[index]
80
+
81
+ def convert_tokens_to_string(self, tokens):
82
+ return "".join(tokens)
83
+
84
+ def get_special_tokens_mask(
85
+ self,
86
+ token_ids_0: List[int],
87
+ token_ids_1: Optional[List[int]] = None,
88
+ already_has_special_tokens: bool = False,
89
+ ) -> List[int]:
90
+ if already_has_special_tokens:
91
+ return super().get_special_tokens_mask(
92
+ token_ids_0=token_ids_0,
93
+ token_ids_1=token_ids_1,
94
+ already_has_special_tokens=True,
95
+ )
96
+
97
+ result = ([0] * len(token_ids_0)) + [1]
98
+ if token_ids_1 is not None:
99
+ result += ([0] * len(token_ids_1)) + [1]
100
+ return result
101
+
102
+ def build_inputs_with_special_tokens(
103
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
104
+ ) -> List[int]:
105
+ sep = [self.sep_token_id]
106
+ # cls = [self.cls_token_id]
107
+ result = token_ids_0 + sep
108
+ if token_ids_1 is not None:
109
+ result += token_ids_1 + sep
110
+ return result
111
+
112
+ def get_vocab(self) -> Dict[str, int]:
113
+ return self._vocab_str_to_int
114
+
115
+ # HyenaDNA has a fixed vocabulary with no vocab file
116
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
117
+ return ()
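
tokenization_hyena.py is a plain character-level tokenizer: A/C/G/T/N map to fixed ids starting at 7, and build_inputs_with_special_tokens appends a single [SEP]. A small sketch of the expected mapping, assuming the file above is importable as a module (the model_max_length value is just an example):

# Sketch: character-level encoding produced by the tokenizer above.
from tokenization_hyena import HyenaDNATokenizer

tok = HyenaDNATokenizer(model_max_length=32770)
ids = tok("ACGTN")["input_ids"]
# 'A'..'N' map to 7..11 and a trailing [SEP] (id 1) is appended:
assert ids == [7, 8, 9, 10, 11, 1]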
checkpoint-600/tokenizer_config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "[CLS]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "[SEP]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "[BOS]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "[MASK]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "6": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ }
52
+ },
53
+ "auto_map": {
54
+ "AutoTokenizer": [
55
+ "tokenization_hyena.HyenaDNATokenizer",
56
+ null
57
+ ]
58
+ },
59
+ "bos_token": "[BOS]",
60
+ "clean_up_tokenization_spaces": true,
61
+ "cls_token": "[CLS]",
62
+ "eos_token": "[SEP]",
63
+ "mask_token": "[MASK]",
64
+ "model_max_length": 256,
65
+ "name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
66
+ "pad_token": "[PAD]",
67
+ "padding_side": "right",
68
+ "sep_token": "[SEP]",
69
+ "special_tokens_map_file": "/home/hlv8980/.cache/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/special_tokens_map.json",
70
+ "tokenizer_class": "HyenaDNATokenizer",
71
+ "unk_token": "[UNK]"
72
+ }
checkpoint-600/trainer_state.json ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 0.39216598868370056,
3
+ "best_model_checkpoint": "/scratch/hlv8980/Attack_Benchmark/models/hyena/tf4/origin/checkpoint-600",
4
+ "epoch": 2.0202020202020203,
5
+ "global_step": 600,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.34,
12
+ "learning_rate": 2.8760984182776802e-05,
13
+ "loss": 0.5992,
14
+ "step": 100
15
+ },
16
+ {
17
+ "epoch": 0.67,
18
+ "learning_rate": 2.615114235500879e-05,
19
+ "loss": 0.4813,
20
+ "step": 200
21
+ },
22
+ {
23
+ "epoch": 0.67,
24
+ "eval_accuracy": 0.774,
25
+ "eval_f1": 0.7713328260834371,
26
+ "eval_loss": 0.48207539319992065,
27
+ "eval_matthews_correlation": 0.5579338694412199,
28
+ "eval_precision": 0.785067107786007,
29
+ "eval_recall": 0.772997299729973,
30
+ "eval_runtime": 0.1057,
31
+ "eval_samples_per_second": 9462.679,
32
+ "eval_steps_per_second": 151.403,
33
+ "step": 200
34
+ },
35
+ {
36
+ "epoch": 1.01,
37
+ "learning_rate": 2.3514938488576452e-05,
38
+ "loss": 0.4431,
39
+ "step": 300
40
+ },
41
+ {
42
+ "epoch": 1.35,
43
+ "learning_rate": 2.087873462214411e-05,
44
+ "loss": 0.377,
45
+ "step": 400
46
+ },
47
+ {
48
+ "epoch": 1.35,
49
+ "eval_accuracy": 0.816,
50
+ "eval_f1": 0.8159933757615274,
51
+ "eval_loss": 0.427643358707428,
52
+ "eval_matthews_correlation": 0.6320128653971173,
53
+ "eval_precision": 0.815991263965056,
54
+ "eval_recall": 0.816021602160216,
55
+ "eval_runtime": 0.1039,
56
+ "eval_samples_per_second": 9625.637,
57
+ "eval_steps_per_second": 154.01,
58
+ "step": 400
59
+ },
60
+ {
61
+ "epoch": 1.68,
62
+ "learning_rate": 1.82688927943761e-05,
63
+ "loss": 0.3443,
64
+ "step": 500
65
+ },
66
+ {
67
+ "epoch": 2.02,
68
+ "learning_rate": 1.563268892794376e-05,
69
+ "loss": 0.33,
70
+ "step": 600
71
+ },
72
+ {
73
+ "epoch": 2.02,
74
+ "eval_accuracy": 0.824,
75
+ "eval_f1": 0.8239746523499383,
76
+ "eval_loss": 0.39216598868370056,
77
+ "eval_matthews_correlation": 0.6479558982194922,
78
+ "eval_precision": 0.8239935027265344,
79
+ "eval_recall": 0.8239623962396239,
80
+ "eval_runtime": 0.1031,
81
+ "eval_samples_per_second": 9696.512,
82
+ "eval_steps_per_second": 155.144,
83
+ "step": 600
84
+ }
85
+ ],
86
+ "max_steps": 1188,
87
+ "num_train_epochs": 4,
88
+ "total_flos": 76909184335872.0,
89
+ "trial_name": null,
90
+ "trial_params": null
91
+ }
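
trainer_state.json keeps the full evaluation history; best_model_checkpoint points at checkpoint-600 because its eval_loss (0.392) is the lowest value logged so far. A short sketch of recovering that entry from the log, assuming the file path used in this repo:

# Sketch: find the best eval entry recorded in a Trainer state file.
import json

with open("checkpoint-600/trainer_state.json") as f:
    state = json.load(f)

evals = [e for e in state["log_history"] if "eval_loss" in e]
best = min(evals, key=lambda e: e["eval_loss"])
print(best["step"], best["eval_loss"], best["eval_accuracy"])  # 600 0.3921... 0.824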
checkpoint-600/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4e3a85efd6ca1a2228fc0bf6f5ca43150a2981352327acbf61aa5be7e43d49
3
+ size 3707
checkpoint-800/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
3
+ "activation_freq": 10,
4
+ "architectures": [
5
+ "HyenaDNAForSequenceClassification"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hyena.HyenaConfig",
9
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
10
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
11
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
12
+ },
13
+ "d_inner": 1024,
14
+ "d_model": 256,
15
+ "emb_dim": 5,
16
+ "embed_dropout": 0.1,
17
+ "filter_order": 64,
18
+ "hyena_dropout": 0.0,
19
+ "hyena_filter_dropout": 0.0,
20
+ "hyena_order": 2,
21
+ "initializer_range": 0.02,
22
+ "layer_norm_epsilon": 1e-05,
23
+ "max_seq_len": 32770,
24
+ "model_type": "hyenadna",
25
+ "n_layer": 4,
26
+ "num_inner_mlps": 2,
27
+ "pad_token_id": 4,
28
+ "pad_vocab_size_multiple": 8,
29
+ "problem_type": "single_label_classification",
30
+ "short_filter_order": 3,
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "float32",
33
+ "train_freq": true,
34
+ "transformers_version": "4.26.1",
35
+ "use_bias": true,
36
+ "vocab_size": 12
37
+ }
checkpoint-800/configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
+ class HyenaConfig(PretrainedConfig):
6
+ model_type = "hyenadna"
7
+ def __init__(
8
+ self,
9
+ vocab_size=12,
10
+ d_model=256,
11
+ d_inner=None,
12
+ use_bias=True,
13
+ train_freq=True,
14
+ max_seq_len=1024,
15
+ emb_dim=3,
16
+ n_layer=12,
17
+ num_inner_mlps=2,
18
+ hyena_order=2,
19
+ short_filter_order=3,
20
+ filter_order=64,
21
+ activation_freq=1,
22
+ embed_dropout=0.1,
23
+ hyena_dropout=0.0,
24
+ hyena_filter_dropout=0.0,
25
+ layer_norm_epsilon=1e-5,
26
+ initializer_range=0.02,
27
+ pad_vocab_size_multiple=8,
28
+ **kwargs,
29
+ ):
30
+ self.vocab_size = vocab_size
31
+ self.d_model = d_model
32
+ if d_inner is None:
33
+ self.d_inner = 4 * d_model
34
+ else:
35
+ self.d_inner = d_inner
36
+ self.use_bias = use_bias
37
+ self.train_freq = train_freq
38
+ self.max_seq_len = max_seq_len
39
+ self.emb_dim = emb_dim
40
+ self.n_layer = n_layer
41
+ self.hyena_order = hyena_order
42
+ self.filter_order = filter_order
43
+ self.short_filter_order = short_filter_order
44
+ self.activation_freq = activation_freq
45
+ self.num_inner_mlps = num_inner_mlps
46
+ self.embed_dropout = embed_dropout
47
+ self.hyena_dropout = hyena_dropout
48
+ self.hyena_filter_dropout = hyena_filter_dropout
49
+ self.layer_norm_epsilon = layer_norm_epsilon
50
+ self.initializer_range = initializer_range
51
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
52
+ super().__init__(**kwargs)
53
+
54
+ @classmethod
55
+ def from_original_config(cls, config_path, **kwargs):
56
+ with open(config_path, "r") as f:
57
+ config = json.load(f)
58
+
59
+ vocab_size = config["vocab_size"]
60
+ d_model = config["d_model"]
61
+ d_inner = config["d_inner"]
62
+ max_seq_len = config["layer"]["l_max"]
63
+ emb_dim = config["layer"]["emb_dim"]
64
+ filter_order = config["layer"]["filter_order"]
65
+ if "local_order" in config["layer"]:
66
+ short_filter_order = config["layer"]["local_order"]
67
+ elif "short_filter_order" in config["layer"]:
68
+ short_filter_order = config["layer"]["short_filter_order"]
69
+ else:
70
+ short_filter_order = 3
71
+ n_layer = config["n_layer"]
72
+ activation_freq = config["layer"]["w"]
73
+ embed_dropout = config["embed_dropout"]
74
+ pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
75
+ return cls(vocab_size=vocab_size,
76
+ d_model=d_model,
77
+ d_inner=d_inner,
78
+ max_seq_len=max_seq_len,
79
+ emb_dim=emb_dim,
80
+ filter_order=filter_order,
81
+ short_filter_order=short_filter_order,
82
+ n_layer=n_layer,
83
+ activation_freq=activation_freq,
84
+ embed_dropout=embed_dropout,
85
+ pad_vocab_size_multiple=pad_vocab_size_multiple,
86
+ tie_word_embeddings=False,
87
+ **kwargs
88
+ )
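
HyenaConfig defaults d_inner to 4 * d_model when it is not supplied, and from_original_config flattens the nested original HyenaDNA config (layer.l_max, layer.w, ...) into these flat fields. A quick sketch of the default behaviour, assuming configuration_hyena.py is importable as a module:

# Sketch: d_inner default behaviour of HyenaConfig as defined above.
from configuration_hyena import HyenaConfig

cfg = HyenaConfig(d_model=256)                 # d_inner omitted
assert cfg.d_inner == 4 * 256                  # falls back to 4 * d_model
cfg2 = HyenaConfig(d_model=256, d_inner=1024)  # explicit value, as in config.json here
assert cfg2.d_inner == 1024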
checkpoint-800/modeling_hyena.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HyenaDNA custom code port to Hugging Face Hub"""
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from .configuration_hyena import HyenaConfig
9
+ from transformers import PreTrainedModel
10
+ from typing import Optional, Tuple, Union
11
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
12
+
13
+
14
+ def fftconv(u, k, D):
15
+ """
16
+ We apply a convolution through the fourier domain (from the Convolution Theorem)
17
+
18
+ """
19
+ seqlen = u.shape[-1]
20
+ fft_size = 2 * seqlen
21
+
22
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
24
+
25
+ if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
27
+
28
+ out = y + u * D.unsqueeze(-1)
29
+ return out.to(dtype=u.dtype)
30
+
31
+
32
+ @torch.jit.script
33
+ def mul_sum(q, y):
34
+ return (q * y).sum(dim=1)
35
+
36
+
37
+ class HyenaSin(nn.Module):
38
+ """The Sin activation function for the Hyena Filter function."""
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)
42
+
43
+ def forward(self, x):
44
+ return torch.sin(self.freq * x)
45
+
46
+
47
+ class HyenaPositionalEmbedding(nn.Module):
48
+ def __init__(self, config):
49
+ """Complex exponential positional embeddings for Hyena filters."""
50
+ super().__init__()
51
+
52
+ self.seq_len = config.max_seq_len
53
+ # The time embedding fed to the filters is normalized so that t_f = 1
54
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
55
+
56
+ if config.emb_dim > 1:
57
+ bands = (config.emb_dim - 1) // 2
58
+ # To compute the right embeddings we use the "proper" linspace
59
+ t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
60
+ w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
+
62
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+
64
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
+
66
+ self.register_buffer("z", z)
67
+ self.register_buffer("t", t)
68
+
69
+ def forward(self, L):
70
+ return self.z[:, :L], self.t[:, :L]
71
+
72
+
73
+ class HyenaExponentialModulation(nn.Module):
74
+ """The window function applied to the output of the (MLP) filter function."""
75
+ def __init__(
76
+ self,
77
+ d_model,
78
+ fast_decay_pct=0.3,
79
+ slow_decay_pct=1.5,
80
+ target=1e-2,
81
+ modulate: bool=True,
82
+ shift: float = 0.05,
83
+ **kwargs
84
+ ):
85
+ super().__init__()
86
+ self.modulate = modulate
87
+ self.shift = shift
88
+ max_decay = math.log(target) / fast_decay_pct
89
+ min_decay = math.log(target) / slow_decay_pct
90
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
91
+ self.register_buffer("deltas", deltas)
92
+
93
+ def forward(self, t, x):
94
+ if self.modulate:
95
+ decay = torch.exp(-t * self.deltas.abs())
96
+ x = x * (decay + self.shift)
97
+ return x
98
+
99
+
100
+ class HyenaFilter(nn.Module):
101
+ def __init__(
102
+ self,
103
+ config,
104
+ **kwargs
105
+ ):
106
+ """
107
+ Implicit long filter with modulation.
108
+
109
+ Args:
110
+ d_model: number of channels in the input
111
+ emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
112
+ order: width of the FFN
113
+ num_inner_mlps: number of inner linear layers inside filter MLP
114
+
115
+ Note:
116
+ filter_dropout is not implemented
117
+ """
118
+ super().__init__()
119
+
120
+ self.d_model = config.d_model * (config.hyena_order - 1)
121
+ self.use_bias = config.use_bias
122
+ self.bias = nn.Parameter(torch.randn(self.d_model))
123
+ self.dropout = nn.Dropout(config.hyena_filter_dropout)
124
+
125
+ act = HyenaSin(config)
126
+ self.emb_dim = config.emb_dim
127
+ assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
128
+ self.seq_len = config.max_seq_len
129
+
130
+ self.pos_emb = HyenaPositionalEmbedding(config)
131
+
132
+ self.implicit_filter = nn.Sequential(
133
+ nn.Linear(self.emb_dim, config.filter_order),
134
+ act,
135
+ )
136
+ for i in range(config.num_inner_mlps):
137
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
138
+ self.implicit_filter.append(act)
139
+
140
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
141
+
142
+ self.modulation = HyenaExponentialModulation(config.d_model)
143
+
144
+ self.normalized = False
145
+
146
+ def filter(self, L, *args, **kwargs):
147
+ z, t = self.pos_emb(L)
148
+ h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
149
+ h = self.modulation(t, h)
150
+ return h
151
+
152
+ def forward(self, x, L, k=None, bias=None, *args, **kwargs):
153
+ if k is None: k = self.filter(L)
154
+
155
+ # Ensure compatibility with filters that return a tuple
156
+ k = k[0] if type(k) is tuple else k
157
+
158
+ y = fftconv(x, k, bias)
159
+ return y
160
+
161
+
162
+ class HyenaOperator(nn.Module):
163
+ def __init__(
164
+ self,
165
+ config,
166
+ **filter_args,
167
+ ):
168
+ r"""
169
+ Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
170
+
171
+ Args:
172
+ d_model (int): Dimension of the input and output embeddings (width of the layer)
173
+ l_max: (int): Maximum input sequence length. Defaults to None
174
+ order: (int): Depth of the Hyena recurrence. Defaults to 2
175
+ dropout: (float): Dropout probability. Defaults to 0.0
176
+ filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
177
+ """
178
+ super().__init__()
179
+
180
+ self.d_model = config.d_model
181
+ self.l_max = config.max_seq_len
182
+ self.order = config.hyena_order
183
+ inner_width = config.d_model * (self.order + 1)
184
+ self.dropout = nn.Dropout(config.hyena_dropout)
185
+ self.in_proj = nn.Linear(self.d_model, inner_width)
186
+ self.out_proj = nn.Linear(self.d_model, self.d_model)
187
+
188
+ self.short_filter = nn.Conv1d(
189
+ inner_width,
190
+ inner_width,
191
+ config.short_filter_order,
192
+ padding=2,
193
+ groups=inner_width
194
+ )
195
+ self.filter_fn = HyenaFilter(config)
196
+
197
+ def forward(self, u):
198
+ l = u.size(-2)
199
+ l_filter = min(l, self.l_max)
200
+ u = self.in_proj(u).transpose(1, 2)
201
+
202
+ uc = self.short_filter(u)[...,:l_filter]
203
+ *x, v = uc.split(self.d_model, dim=1)
204
+
205
+ k = self.filter_fn.filter(l_filter)[0]
206
+ k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
207
+ bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)
208
+
209
+ for o, x_i in enumerate(reversed(x[1:])):
210
+ v = self.dropout(v * x_i)
211
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
212
+
213
+ y = (v * x[0]).transpose(1, 2)
214
+
215
+ y = self.out_proj(y)
216
+ return y
217
+
218
+ class HyenaMlp(nn.Module):
219
+
220
+ def __init__(self, config):
221
+ """
222
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
223
+ """
224
+ super().__init__()
225
+ in_features = config.d_model
226
+ hidden_features = config.d_inner
227
+ self.fc1 = nn.Linear(in_features, hidden_features)
228
+ self.fc2 = nn.Linear(hidden_features, config.d_model)
229
+
230
+ def forward(self, x):
231
+ y = self.fc1(x)
232
+ y = F.gelu(y, approximate="tanh")
233
+ y = self.fc2(y)
234
+ return y
235
+
236
+ class HyenaBlock(nn.Module):
237
+
238
+ def __init__(self, config):
239
+ """
240
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
241
+ For prenorm=True, this Block has a slightly different structure compared to a regular
242
+ prenorm Transformer block.
243
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
244
+ [Ref: https://arxiv.org/abs/2002.04745]
245
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
246
+ the hidden_states (output of the MLP) and the residual.
247
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
248
+ The residual needs to be provided (except for the very first block).
249
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
250
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
251
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
252
+ This is for performance reason: for post-norm architecture, returning the input allows us
253
+ to fuse the backward of nn.Linear with the residual connection.
254
+ """
255
+ super().__init__()
256
+ self.mixer = HyenaOperator(config)
257
+ self.norm1 = nn.LayerNorm(config.d_model)
258
+ self.mlp = HyenaMlp(config)
259
+ self.norm2 = nn.LayerNorm(config.d_model)
260
+
261
+ def forward(self, hidden_states):
262
+ r"""Pass the input through the encoder layer.
263
+ Args:
264
+ hidden_states: the sequence to the encoder layer (required).
265
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
266
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
267
+ before applying the query projection. Useful for e.g., ViT where we only care
268
+ about the CLS token in the last layer.
269
+ """
270
+ residual = hidden_states
271
+ residual = residual.to(torch.float32)
272
+ hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
273
+ hidden_states = self.mixer(hyena_normed)
274
+ # Tested above here and all is equivalent. That means the mixer is fine!!!
275
+ residual = hidden_states + residual
276
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
277
+ residual = residual.to(torch.float32)
278
+
279
+ hidden_states = self.mlp(hidden_states)
280
+ return hidden_states + residual
281
+
282
+
283
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
284
+
285
+
286
+ class HyenaEmbeddings(nn.Module):
287
+
288
+ def __init__(self, config, padding_idx=None):
289
+ """
290
+ If max_position_embeddings <= 0, there's no position embeddings
291
+ If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
292
+ the project up to embed_dim
293
+ """
294
+ super().__init__()
295
+ vocab_size = config.vocab_size
296
+ if vocab_size % config.pad_vocab_size_multiple != 0:
297
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
298
+ self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
299
+
300
+ def forward(self, input_ids):
301
+ """
302
+ input_ids: (batch, seqlen)
303
+ """
304
+ embeddings = self.word_embeddings(input_ids)
305
+ return embeddings
306
+
307
+ class HyenaLMBackbone(nn.Module):
308
+
309
+ def __init__(self, config) -> None:
310
+ super().__init__()
311
+ # note max_position_embeddings is 0 for Hyena, and therefore isn't used
312
+ self.embeddings = HyenaEmbeddings(config)
313
+ self.dropout = nn.Dropout(config.embed_dropout)
314
+
315
+ self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])
316
+
317
+ self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
321
+ all_hidden_states = []
322
+ if inputs_embeds is not None:
323
+ hidden_states = inputs_embeds
324
+ else:
325
+ hidden_states = self.embeddings(input_ids)
326
+ if output_hidden_states:
327
+ all_hidden_states.append(hidden_states)
328
+
329
+ for layer in self.layers:
330
+ if self.gradient_checkpointing and self.training:
331
+ hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
332
+ else:
333
+ hidden_states = layer(hidden_states)
334
+ if output_hidden_states:
335
+ all_hidden_states.append(hidden_states)
336
+
337
+ hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
338
+ if output_hidden_states:
339
+ all_hidden_states.append(hidden_states)
340
+
341
+ return hidden_states, all_hidden_states
342
+
343
+
344
+ class HyenaDNAPreTrainedModel(PreTrainedModel):
345
+ config_class = HyenaConfig
346
+ base_model_prefix = "hyena"
347
+ supports_gradient_checkpointing = True
348
+ _no_split_modules = ["HyenaBlock"]
349
+ _skip_keys_device_placement = "past_key_values"
350
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
351
+
352
+ def _init_weights(self, module, initializer_range=0.02):
353
+ if isinstance(module, nn.Linear):
354
+ nn.init.normal_(module.weight, std=initializer_range)
355
+ if module.bias is not None:
356
+ nn.init.zeros_(module.bias)
357
+ elif isinstance(module, nn.Embedding):
358
+ nn.init.normal_(module.weight, std=initializer_range)
359
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
360
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
361
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
362
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
363
+ #
364
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
365
+ for name, p in self.named_parameters():
366
+ if name in ["out_proj.weight", "fc2.weight"]:
367
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
368
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
369
+ # If using GLU activation for now, we scale the std by 2
370
+ elif name in ["output_linear.0.weight"]:
371
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
372
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
373
+
374
+
375
+ class HyenaDNAModel(HyenaDNAPreTrainedModel):
376
+ def __init__(self, config, **kwargs) -> None:
377
+ super().__init__(config, **kwargs)
378
+
379
+ self.backbone = HyenaLMBackbone(config)
380
+ self.config = config
381
+
382
+ # Initialize weights and apply final processing
383
+ self.post_init()
384
+
385
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
386
+ output_hidden_states = (
387
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
388
+ )
389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
390
+
391
+ hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
392
+ if return_dict:
393
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states,
394
+ hidden_states=all_hidden_states if output_hidden_states else None)
395
+ elif output_hidden_states:
396
+ return hidden_states, all_hidden_states
397
+ else:
398
+ return hidden_states
399
+
400
+
401
+ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
402
+
403
+ def __init__(self, config, **kwargs):
404
+ super().__init__(config, **kwargs)
405
+ self.hyena = HyenaDNAModel(config)
406
+ vocab_size = config.vocab_size
407
+ if vocab_size % config.pad_vocab_size_multiple != 0:
408
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
409
+ self.vocab_size = vocab_size
410
+ self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
411
+
412
+ # Initialize weights and apply final processing
413
+ self.post_init()
414
+
415
+ def get_input_embeddings(self):
416
+ return self.hyena.backbone.embeddings.word_embeddings
417
+
418
+ def set_input_embeddings(self, value):
419
+ self.hyena.backbone.embeddings.word_embeddings = value
420
+
421
+ def get_output_embeddings(self):
422
+ return self.lm_head
423
+
424
+ def set_output_embeddings(self, new_embeddings):
425
+ self.lm_head = new_embeddings
426
+
427
+ def set_decoder(self, decoder):
428
+ self.hyena = decoder
429
+
430
+ def get_decoder(self):
431
+ return self.hyena
432
+
433
+ def forward(
434
+ self,
435
+ input_ids: torch.LongTensor = None,
436
+ inputs_embeds: Optional[torch.FloatTensor] = None,
437
+ labels: Optional[torch.LongTensor] = None,
438
+ output_hidden_states: Optional[bool] = None,
439
+ return_dict: Optional[bool] = None,
440
+ ) -> Union[Tuple, CausalLMOutput]:
441
+
442
+ output_hidden_states = (
443
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
444
+ )
445
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
446
+
447
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
448
+ outputs = self.hyena(
449
+ input_ids=input_ids,
450
+ inputs_embeds=inputs_embeds,
451
+ output_hidden_states=output_hidden_states,
452
+ return_dict=return_dict,
453
+ )
454
+
455
+ hidden_states = outputs[0]
456
+ logits = self.lm_head(hidden_states)
457
+ logits = logits.float()
458
+
459
+ loss = None
460
+ if labels is not None:
461
+ # Shift so that tokens < n predict n
462
+ shift_logits = logits[..., :-1, :].contiguous()
463
+ shift_labels = labels[..., 1:].contiguous()
464
+ # Flatten the tokens
465
+ loss_fct = nn.CrossEntropyLoss()
466
+ shift_logits = shift_logits.view(-1, self.vocab_size)
467
+ shift_labels = shift_labels.view(-1)
468
+ # Enable model parallelism
469
+ shift_labels = shift_labels.to(shift_logits.device)
470
+ loss = loss_fct(shift_logits, shift_labels)
471
+
472
+ if not return_dict:
473
+ output = (logits,) + outputs[1:]
474
+ return (loss,) + output if loss is not None else output
475
+
476
+ return CausalLMOutput(
477
+ loss=loss,
478
+ logits=logits,
479
+ hidden_states=outputs.hidden_states,
480
+ )
481
+
482
+
483
+ class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
484
+ def __init__(self, config, **kwargs):
485
+ super().__init__(config, **kwargs)
486
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
487
+ self.hyena = HyenaDNAModel(config)
488
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.hyena.backbone.embeddings.word_embeddings
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.hyena.backbone.embeddings.word_embeddings = value
498
+
499
+ def forward(
500
+ self,
501
+ input_ids: torch.LongTensor = None,
502
+ inputs_embeds: Optional[torch.FloatTensor] = None,
503
+ labels: Optional[torch.LongTensor] = None,
504
+ output_hidden_states: Optional[bool] = None,
505
+ return_dict: Optional[bool] = None,
506
+ ) -> Union[Tuple, SequenceClassifierOutput]:
507
+ r"""
508
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
509
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
510
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
511
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
512
+ """
513
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
514
+
515
+ transformer_outputs = self.hyena(
516
+ input_ids,
517
+ inputs_embeds=inputs_embeds,
518
+ output_hidden_states=output_hidden_states,
519
+ return_dict=return_dict,
520
+ )
521
+ hidden_states = transformer_outputs[0]
522
+ logits = self.score(hidden_states)
523
+
524
+ if input_ids is not None:
525
+ batch_size = input_ids.shape[0]
526
+ else:
527
+ batch_size = inputs_embeds.shape[0]
528
+
529
+ if self.config.pad_token_id is None and batch_size != 1:
530
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
531
+ if self.config.pad_token_id is None:
532
+ sequence_lengths = -1
533
+ else:
534
+ if input_ids is not None:
535
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
536
+ logits.device
537
+ )
538
+ else:
539
+ sequence_lengths = -1
540
+
541
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
542
+
543
+ loss = None
544
+ if labels is not None:
545
+ labels = labels.to(logits.device)
546
+ if self.config.problem_type is None:
547
+ if self.num_labels == 1:
548
+ self.config.problem_type = "regression"
549
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
550
+ self.config.problem_type = "single_label_classification"
551
+ else:
552
+ self.config.problem_type = "multi_label_classification"
553
+
554
+ if self.config.problem_type == "regression":
555
+ loss_fct = nn.MSELoss()
556
+ if self.num_labels == 1:
557
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
558
+ else:
559
+ loss = loss_fct(pooled_logits, labels)
560
+ elif self.config.problem_type == "single_label_classification":
561
+ loss_fct = nn.CrossEntropyLoss()
562
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
563
+ elif self.config.problem_type == "multi_label_classification":
564
+ loss_fct = nn.BCEWithLogitsLoss()
565
+ loss = loss_fct(pooled_logits, labels)
566
+ if not return_dict:
567
+ output = (pooled_logits,) + transformer_outputs[1:]
568
+ return ((loss,) + output) if loss is not None else output
569
+
570
+ return SequenceClassifierOutput(
571
+ loss=loss,
572
+ logits=pooled_logits,
573
+ hidden_states=transformer_outputs.hidden_states,
574
+ )
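
The core of the operator above is fftconv: the input and filter are zero-padded to twice the sequence length, so the pointwise product in frequency space reduces to an ordinary causal (linear) convolution plus the u * D skip term. A small self-contained numerical sketch of that equivalence (shapes chosen arbitrarily, nothing imported from the file above):

# Sketch: the FFT path used by fftconv equals a direct causal convolution + skip term.
import torch

torch.manual_seed(0)
B, H, L = 2, 4, 16
u = torch.randn(B, H, L)   # (batch, channels, length)
k = torch.randn(H, L)      # implicit long filter
D = torch.randn(H)         # skip/bias term

# FFT path, mirroring fftconv: zero-pad to 2L to avoid circular wrap-around
fft_size = 2 * L
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u, n=fft_size)
y_fft = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :L] + u * D.unsqueeze(-1)

# Direct causal convolution: y[t] = sum_{s <= t} u[s] * k[t - s], plus the same skip term
y_direct = torch.zeros_like(u)
for t in range(L):
    for s in range(t + 1):
        y_direct[..., t] += u[..., s] * k[..., t - s]
y_direct = y_direct + u * D.unsqueeze(-1)

assert torch.allclose(y_fft, y_direct, atol=1e-4)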
checkpoint-800/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8763978a2444824f5a38c3580aba497e275fdda9db740c2f56f67645d5be8636
3
+ size 26304517
checkpoint-800/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:299e2c7579e7fa27e18571ea2c3b4590da8d13bf1e459e0bb5e600e6d1482acd
3
+ size 16300157
checkpoint-800/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ae46c9fc2e6261b72f159e1aea97f31830fdd07cc60689547b223b43e934178
3
+ size 14575
checkpoint-800/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:639595adfe39430588607a0328b8b87bd7572031124a42ca29c480181cdf81a1
3
+ size 557
checkpoint-800/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dac5d5b84a5c940b5d42dcc0f70b744b0e74ac1348849e3191a87cec9e5c4661
3
+ size 627
checkpoint-800/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "[BOS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "[CLS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "[SEP]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "[MASK]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "[SEP]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
checkpoint-800/tokenization_hyena.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedTokenizer, AddedToken
2
+ from typing import List, Optional, Union, Dict, Sequence, Tuple
3
+ from pathlib import Path
4
+ import json
5
+ import os
6
+
7
+
8
+ class HyenaDNATokenizer(PreTrainedTokenizer):
9
+ model_input_names = ["input_ids"]
10
+
11
+ def __init__(self,
12
+ model_max_length: int,
13
+ bos_token="[BOS]",
14
+ eos_token="[SEP]",
15
+ sep_token="[SEP]",
16
+ cls_token="[CLS]",
17
+ pad_token="[PAD]",
18
+ mask_token="[MASK]",
19
+ unk_token="[UNK]",
20
+ **kwargs):
21
+ """Character tokenizer for Hugging Face transformers.
22
+ Args:
23
+ characters (Sequence[str]): List of desired characters. Any character which
24
+ is not included in this list will be replaced by a special token called
25
+ [UNK] with id=6. Following are list of all of the special tokens with
26
+ their corresponding ids:
27
+ "[CLS]": 0
28
+ "[SEP]": 1
29
+ "[BOS]": 2
30
+ "[MASK]": 3
31
+ "[PAD]": 4
32
+ "[RESERVED]": 5
33
+ "[UNK]": 6
34
+ an id (starting at 7) will be assigned to each character.
35
+ model_max_length (int): Model maximum sequence length.
36
+ """
37
+ self.characters = ('A', 'C', 'G', 'T', 'N')
38
+ self.model_max_length = model_max_length
39
+
40
+ self._vocab_str_to_int = {
41
+ "[CLS]": 0,
42
+ "[SEP]": 1,
43
+ "[BOS]": 2,
44
+ "[MASK]": 3,
45
+ "[PAD]": 4,
46
+ "[RESERVED]": 5,
47
+ "[UNK]": 6,
48
+ **{ch: i + 7 for i, ch in enumerate(self.characters)},
49
+ }
50
+ self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
51
+ add_prefix_space = kwargs.pop("add_prefix_space", False)
52
+ padding_side = kwargs.pop("padding_side", "left")
53
+
54
+ super().__init__(
55
+ bos_token=bos_token,
56
+ eos_token=eos_token,
57
+ sep_token=sep_token,
58
+ cls_token=cls_token,
59
+ pad_token=pad_token,
60
+ mask_token=mask_token,
61
+ unk_token=unk_token,
62
+ add_prefix_space=add_prefix_space,
63
+ model_max_length=model_max_length,
64
+ padding_side=padding_side,
65
+ **kwargs,
66
+ )
67
+
68
+ @property
69
+ def vocab_size(self) -> int:
70
+ return len(self._vocab_str_to_int)
71
+
72
+ def _tokenize(self, text: str) -> List[str]:
73
+ return list(text)
74
+
75
+ def _convert_token_to_id(self, token: str) -> int:
76
+ return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
77
+
78
+ def _convert_id_to_token(self, index: int) -> str:
79
+ return self._vocab_int_to_str[index]
80
+
81
+ def convert_tokens_to_string(self, tokens):
82
+ return "".join(tokens)
83
+
84
+ def get_special_tokens_mask(
85
+ self,
86
+ token_ids_0: List[int],
87
+ token_ids_1: Optional[List[int]] = None,
88
+ already_has_special_tokens: bool = False,
89
+ ) -> List[int]:
90
+ if already_has_special_tokens:
91
+ return super().get_special_tokens_mask(
92
+ token_ids_0=token_ids_0,
93
+ token_ids_1=token_ids_1,
94
+ already_has_special_tokens=True,
95
+ )
96
+
97
+ result = ([0] * len(token_ids_0)) + [1]
98
+ if token_ids_1 is not None:
99
+ result += ([0] * len(token_ids_1)) + [1]
100
+ return result
101
+
102
+ def build_inputs_with_special_tokens(
103
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
104
+ ) -> List[int]:
105
+ sep = [self.sep_token_id]
106
+ # cls = [self.cls_token_id]
107
+ result = token_ids_0 + sep
108
+ if token_ids_1 is not None:
109
+ result += token_ids_1 + sep
110
+ return result
111
+
112
+ def get_vocab(self) -> Dict[str, int]:
113
+ return self._vocab_str_to_int
114
+
115
+ # HyenaDNA has a fixed vocabulary with no vocab file
116
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
117
+ return ()
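
Because model_input_names only contains input_ids, batched encoding pads with [PAD] (id 4) on the configured side but returns no attention mask; the classification head instead locates the last non-padding token via pad_token_id. A short sketch, again assuming the file above is importable and using padding_side="right" as in tokenizer_config.json:

# Sketch: batch padding behaviour of the character tokenizer above.
from tokenization_hyena import HyenaDNATokenizer

tok = HyenaDNATokenizer(model_max_length=32770, padding_side="right")
batch = tok(["ACGT", "AC"], padding=True)
# Each sequence gets a trailing [SEP] (id 1); the shorter one is right-padded with [PAD] (id 4):
assert batch["input_ids"] == [[7, 8, 9, 10, 1], [7, 8, 1, 4, 4]]
assert "attention_mask" not in batch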
checkpoint-800/tokenizer_config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "[CLS]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "[SEP]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "[BOS]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "[MASK]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "6": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ }
52
+ },
53
+ "auto_map": {
54
+ "AutoTokenizer": [
55
+ "tokenization_hyena.HyenaDNATokenizer",
56
+ null
57
+ ]
58
+ },
59
+ "bos_token": "[BOS]",
60
+ "clean_up_tokenization_spaces": true,
61
+ "cls_token": "[CLS]",
62
+ "eos_token": "[SEP]",
63
+ "mask_token": "[MASK]",
64
+ "model_max_length": 256,
65
+ "name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
66
+ "pad_token": "[PAD]",
67
+ "padding_side": "right",
68
+ "sep_token": "[SEP]",
69
+ "special_tokens_map_file": "/home/hlv8980/.cache/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/special_tokens_map.json",
70
+ "tokenizer_class": "HyenaDNATokenizer",
71
+ "unk_token": "[UNK]"
72
+ }
checkpoint-800/trainer_state.json ADDED
@@ -0,0 +1,116 @@
1
+ {
2
+ "best_metric": 0.39216598868370056,
3
+ "best_model_checkpoint": "/scratch/hlv8980/Attack_Benchmark/models/hyena/tf4/origin/checkpoint-600",
4
+ "epoch": 2.6936026936026938,
5
+ "global_step": 800,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.34,
12
+ "learning_rate": 2.8760984182776802e-05,
13
+ "loss": 0.5992,
14
+ "step": 100
15
+ },
16
+ {
17
+ "epoch": 0.67,
18
+ "learning_rate": 2.615114235500879e-05,
19
+ "loss": 0.4813,
20
+ "step": 200
21
+ },
22
+ {
23
+ "epoch": 0.67,
24
+ "eval_accuracy": 0.774,
25
+ "eval_f1": 0.7713328260834371,
26
+ "eval_loss": 0.48207539319992065,
27
+ "eval_matthews_correlation": 0.5579338694412199,
28
+ "eval_precision": 0.785067107786007,
29
+ "eval_recall": 0.772997299729973,
30
+ "eval_runtime": 0.1057,
31
+ "eval_samples_per_second": 9462.679,
32
+ "eval_steps_per_second": 151.403,
33
+ "step": 200
34
+ },
35
+ {
36
+ "epoch": 1.01,
37
+ "learning_rate": 2.3514938488576452e-05,
38
+ "loss": 0.4431,
39
+ "step": 300
40
+ },
41
+ {
42
+ "epoch": 1.35,
43
+ "learning_rate": 2.087873462214411e-05,
44
+ "loss": 0.377,
45
+ "step": 400
46
+ },
47
+ {
48
+ "epoch": 1.35,
49
+ "eval_accuracy": 0.816,
50
+ "eval_f1": 0.8159933757615274,
51
+ "eval_loss": 0.427643358707428,
52
+ "eval_matthews_correlation": 0.6320128653971173,
53
+ "eval_precision": 0.815991263965056,
54
+ "eval_recall": 0.816021602160216,
55
+ "eval_runtime": 0.1039,
56
+ "eval_samples_per_second": 9625.637,
57
+ "eval_steps_per_second": 154.01,
58
+ "step": 400
59
+ },
60
+ {
61
+ "epoch": 1.68,
62
+ "learning_rate": 1.82688927943761e-05,
63
+ "loss": 0.3443,
64
+ "step": 500
65
+ },
66
+ {
67
+ "epoch": 2.02,
68
+ "learning_rate": 1.563268892794376e-05,
69
+ "loss": 0.33,
70
+ "step": 600
71
+ },
72
+ {
73
+ "epoch": 2.02,
74
+ "eval_accuracy": 0.824,
75
+ "eval_f1": 0.8239746523499383,
76
+ "eval_loss": 0.39216598868370056,
77
+ "eval_matthews_correlation": 0.6479558982194922,
78
+ "eval_precision": 0.8239935027265344,
79
+ "eval_recall": 0.8239623962396239,
80
+ "eval_runtime": 0.1031,
81
+ "eval_samples_per_second": 9696.512,
82
+ "eval_steps_per_second": 155.144,
83
+ "step": 600
84
+ },
85
+ {
86
+ "epoch": 2.36,
87
+ "learning_rate": 1.2996485061511423e-05,
88
+ "loss": 0.227,
89
+ "step": 700
90
+ },
91
+ {
92
+ "epoch": 2.69,
93
+ "learning_rate": 1.0360281195079087e-05,
94
+ "loss": 0.2219,
95
+ "step": 800
96
+ },
97
+ {
98
+ "epoch": 2.69,
99
+ "eval_accuracy": 0.838,
100
+ "eval_f1": 0.8379766686402841,
101
+ "eval_loss": 0.4026987850666046,
102
+ "eval_matthews_correlation": 0.6767651028795362,
103
+ "eval_precision": 0.8385613769517563,
104
+ "eval_recall": 0.8382038203820381,
105
+ "eval_runtime": 0.1026,
106
+ "eval_samples_per_second": 9746.534,
107
+ "eval_steps_per_second": 155.945,
108
+ "step": 800
109
+ }
110
+ ],
111
+ "max_steps": 1188,
112
+ "num_train_epochs": 4,
113
+ "total_flos": 102556265398272.0,
114
+ "trial_name": null,
115
+ "trial_params": null
116
+ }
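As a quick consistency check on the fields above: with max_steps 1188 spread over num_train_epochs 4, one epoch is 297 optimizer steps, so global_step 800 lands at 800 / 297 ≈ 2.6936, which is exactly the "epoch" value recorded for this checkpoint.

# Sanity check of the fractional "epoch" field (plain Python, no dependencies).
max_steps, num_train_epochs, global_step = 1188, 4, 800
steps_per_epoch = max_steps / num_train_epochs   # 297.0
print(global_step / steps_per_epoch)             # 2.6936..., matches the logged epoch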
checkpoint-800/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4e3a85efd6ca1a2228fc0bf6f5ca43150a2981352327acbf61aa5be7e43d49
3
+ size 3707
config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
3
+ "activation_freq": 10,
4
+ "architectures": [
5
+ "HyenaDNAForSequenceClassification"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hyena.HyenaConfig",
9
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
10
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
11
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
12
+ },
13
+ "d_inner": 1024,
14
+ "d_model": 256,
15
+ "emb_dim": 5,
16
+ "embed_dropout": 0.1,
17
+ "filter_order": 64,
18
+ "hyena_dropout": 0.0,
19
+ "hyena_filter_dropout": 0.0,
20
+ "hyena_order": 2,
21
+ "initializer_range": 0.02,
22
+ "layer_norm_epsilon": 1e-05,
23
+ "max_seq_len": 32770,
24
+ "model_type": "hyenadna",
25
+ "n_layer": 4,
26
+ "num_inner_mlps": 2,
27
+ "pad_token_id": 4,
28
+ "pad_vocab_size_multiple": 8,
29
+ "problem_type": "single_label_classification",
30
+ "short_filter_order": 3,
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "float32",
33
+ "train_freq": true,
34
+ "transformers_version": "4.26.1",
35
+ "use_bias": true,
36
+ "vocab_size": 12
37
+ }
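Since auto_map in this config routes the Auto* classes to the modeling_hyena.py and configuration_hyena.py files uploaded here, the fine-tuned classifier is loaded with trust_remote_code. A minimal sketch, where MODEL_DIR stands in for the root of this upload (the name is illustrative):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_DIR = "."  # illustrative: point this at the root of this upload

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, trust_remote_code=True)
model.eval()

inputs = tokenizer("ACGTACGTACGT", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, num_labels)
print(logits.argmax(-1))

No attention mask is needed: the classification head in modeling_hyena.py pools the logit at the last token before any [PAD] padding.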
configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
+ class HyenaConfig(PretrainedConfig):
6
+ model_type = "hyenadna"
7
+ def __init__(
8
+ self,
9
+ vocab_size=12,
10
+ d_model=256,
11
+ d_inner=None,
12
+ use_bias=True,
13
+ train_freq=True,
14
+ max_seq_len=1024,
15
+ emb_dim=3,
16
+ n_layer=12,
17
+ num_inner_mlps=2,
18
+ hyena_order=2,
19
+ short_filter_order=3,
20
+ filter_order=64,
21
+ activation_freq=1,
22
+ embed_dropout=0.1,
23
+ hyena_dropout=0.0,
24
+ hyena_filter_dropout=0.0,
25
+ layer_norm_epsilon=1e-5,
26
+ initializer_range=0.02,
27
+ pad_vocab_size_multiple=8,
28
+ **kwargs,
29
+ ):
30
+ self.vocab_size = vocab_size
31
+ self.d_model = d_model
32
+ if d_inner is None:
33
+ self.d_inner = 4 * d_model
34
+ else:
35
+ self.d_inner = d_inner
36
+ self.use_bias = use_bias
37
+ self.train_freq = train_freq
38
+ self.max_seq_len = max_seq_len
39
+ self.emb_dim = emb_dim
40
+ self.n_layer = n_layer
41
+ self.hyena_order = hyena_order
42
+ self.filter_order = filter_order
43
+ self.short_filter_order = short_filter_order
44
+ self.activation_freq = activation_freq
45
+ self.num_inner_mlps = num_inner_mlps
46
+ self.embed_dropout = embed_dropout
47
+ self.hyena_dropout = hyena_dropout
48
+ self.hyena_filter_dropout = hyena_filter_dropout
49
+ self.layer_norm_epsilon = layer_norm_epsilon
50
+ self.initializer_range = initializer_range
51
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
52
+ super().__init__(**kwargs)
53
+
54
+ @classmethod
55
+ def from_original_config(cls, config_path, **kwargs):
56
+ with open(config_path, "r") as f:
57
+ config = json.load(f)
58
+
59
+ vocab_size = config["vocab_size"]
60
+ d_model = config["d_model"]
61
+ d_inner = config["d_inner"]
62
+ max_seq_len = config["layer"]["l_max"]
63
+ emb_dim = config["layer"]["emb_dim"]
64
+ filter_order = config["layer"]["filter_order"]
65
+ if "local_order" in config["layer"]:
66
+ short_filter_order = config["layer"]["local_order"]
67
+ elif "short_filter_order" in config["layer"]:
68
+ short_filter_order = config["layer"]["short_filter_order"]
69
+ else:
70
+ short_filter_order = 3
71
+ n_layer = config["n_layer"]
72
+ activation_freq = config["layer"]["w"]
73
+ embed_dropout = config["embed_dropout"]
74
+ pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
75
+ return cls(vocab_size=vocab_size,
76
+ d_model=d_model,
77
+ d_inner=d_inner,
78
+ max_seq_len=max_seq_len,
79
+ emb_dim=emb_dim,
80
+ filter_order=filter_order,
81
+ short_filter_order=short_filter_order,
82
+ n_layer=n_layer,
83
+ activation_freq=activation_freq,
84
+ embed_dropout=embed_dropout,
85
+ pad_vocab_size_multiple=pad_vocab_size_multiple,
86
+ tie_word_embeddings=False,
87
+ **kwargs
88
+ )
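One default in the class above is easy to miss: when d_inner is not passed, it falls back to 4 * d_model, which is exactly the 256 → 1024 pairing seen in the uploaded config.json. A small sketch, assuming configuration_hyena.py is importable from the working directory:

from configuration_hyena import HyenaConfig

cfg = HyenaConfig(d_model=256)          # d_inner omitted
assert cfg.d_inner == 4 * cfg.d_model   # 1024, matching config.json

cfg_wide = HyenaConfig(d_model=256, d_inner=2048)
assert cfg_wide.d_inner == 2048         # an explicit value wins over the default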
modeling_hyena.py ADDED
@@ -0,0 +1,574 @@
1
+ # -*- coding: utf-8 -*-
2
+ """HyenaDNA custom code port to Hugging Face Hub"""
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from .configuration_hyena import HyenaConfig
9
+ from transformers import PreTrainedModel
10
+ from typing import Optional, Tuple, Union
11
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
12
+
13
+
14
+ def fftconv(u, k, D):
15
+ """
16
+ We apply a convolution through the Fourier domain (via the Convolution Theorem)
17
+
18
+ """
19
+ seqlen = u.shape[-1]
20
+ fft_size = 2 * seqlen
21
+
22
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
24
+
25
+ if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
27
+
28
+ out = y + u * D.unsqueeze(-1)
29
+ return out.to(dtype=u.dtype)
30
+
31
+
32
+ @torch.jit.script
33
+ def mul_sum(q, y):
34
+ return (q * y).sum(dim=1)
35
+
36
+
37
+ class HyenaSin(nn.Module):
38
+ """The Sin activation function for the Hyena Filter function."""
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)
42
+
43
+ def forward(self, x):
44
+ return torch.sin(self.freq * x)
45
+
46
+
47
+ class HyenaPositionalEmbedding(nn.Module):
48
+ def __init__(self, config):
49
+ """Complex exponential positional embeddings for Hyena filters."""
50
+ super().__init__()
51
+
52
+ self.seq_len = config.max_seq_len
53
+ # The time embedding fed to the filters is normalized so that t_f = 1
54
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
55
+
56
+ if config.emb_dim > 1:
57
+ bands = (config.emb_dim - 1) // 2
58
+ # To compute the right embeddings we use the "proper" linspace
59
+ t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
60
+ w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
+
62
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+
64
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
+
66
+ self.register_buffer("z", z)
67
+ self.register_buffer("t", t)
68
+
69
+ def forward(self, L):
70
+ return self.z[:, :L], self.t[:, :L]
71
+
72
+
73
+ class HyenaExponentialModulation(nn.Module):
74
+ """The window function applied to the output of the (MLP) filter function."""
75
+ def __init__(
76
+ self,
77
+ d_model,
78
+ fast_decay_pct=0.3,
79
+ slow_decay_pct=1.5,
80
+ target=1e-2,
81
+ modulate: bool=True,
82
+ shift: float = 0.05,
83
+ **kwargs
84
+ ):
85
+ super().__init__()
86
+ self.modulate = modulate
87
+ self.shift = shift
88
+ max_decay = math.log(target) / fast_decay_pct
89
+ min_decay = math.log(target) / slow_decay_pct
90
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
91
+ self.register_buffer("deltas", deltas)
92
+
93
+ def forward(self, t, x):
94
+ if self.modulate:
95
+ decay = torch.exp(-t * self.deltas.abs())
96
+ x = x * (decay + self.shift)
97
+ return x
98
+
99
+
100
+ class HyenaFilter(nn.Module):
101
+ def __init__(
102
+ self,
103
+ config,
104
+ **kwargs
105
+ ):
106
+ """
107
+ Implicit long filter with modulation.
108
+
109
+ Args:
110
+ d_model: number of channels in the input
111
+ emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
112
+ order: width of the FFN
113
+ num_inner_mlps: number of inner linear layers inside filter MLP
114
+
115
+ Note:
116
+ filter_dropout is not implemented
117
+ """
118
+ super().__init__()
119
+
120
+ self.d_model = config.d_model * (config.hyena_order - 1)
121
+ self.use_bias = config.use_bias
122
+ self.bias = nn.Parameter(torch.randn(self.d_model))
123
+ self.dropout = nn.Dropout(config.hyena_filter_dropout)
124
+
125
+ act = HyenaSin(config)
126
+ self.emb_dim = config.emb_dim
127
+ assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater than or equal to 3 (time, sine and cosine)"
128
+ self.seq_len = config.max_seq_len
129
+
130
+ self.pos_emb = HyenaPositionalEmbedding(config)
131
+
132
+ self.implicit_filter = nn.Sequential(
133
+ nn.Linear(self.emb_dim, config.filter_order),
134
+ act,
135
+ )
136
+ for i in range(config.num_inner_mlps):
137
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
138
+ self.implicit_filter.append(act)
139
+
140
+ self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))
141
+
142
+ self.modulation = HyenaExponentialModulation(config.d_model)
143
+
144
+ self.normalized = False
145
+
146
+ def filter(self, L, *args, **kwargs):
147
+ z, t = self.pos_emb(L)
148
+ h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
149
+ h = self.modulation(t, h)
150
+ return h
151
+
152
+ def forward(self, x, L, k=None, bias=None, *args, **kwargs):
153
+ if k is None: k = self.filter(L)
154
+
155
+ # Ensure compatibility with filters that return a tuple
156
+ k = k[0] if type(k) is tuple else k
157
+
158
+ y = fftconv(x, k, bias)
159
+ return y
160
+
161
+
162
+ class HyenaOperator(nn.Module):
163
+ def __init__(
164
+ self,
165
+ config,
166
+ **filter_args,
167
+ ):
168
+ r"""
169
+ Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
170
+
171
+ Args:
172
+ d_model (int): Dimension of the input and output embeddings (width of the layer)
173
+ l_max: (int): Maximum input sequence length. Defaults to None
174
+ order: (int): Depth of the Hyena recurrence. Defaults to 2
175
+ dropout: (float): Dropout probability. Defaults to 0.0
176
+ filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
177
+ """
178
+ super().__init__()
179
+
180
+ self.d_model = config.d_model
181
+ self.l_max = config.max_seq_len
182
+ self.order = config.hyena_order
183
+ inner_width = config.d_model * (self.order + 1)
184
+ self.dropout = nn.Dropout(config.hyena_dropout)
185
+ self.in_proj = nn.Linear(self.d_model, inner_width)
186
+ self.out_proj = nn.Linear(self.d_model, self.d_model)
187
+
188
+ self.short_filter = nn.Conv1d(
189
+ inner_width,
190
+ inner_width,
191
+ config.short_filter_order,
192
+ padding=2,
193
+ groups=inner_width
194
+ )
195
+ self.filter_fn = HyenaFilter(config)
196
+
197
+ def forward(self, u):
198
+ l = u.size(-2)
199
+ l_filter = min(l, self.l_max)
200
+ u = self.in_proj(u).transpose(1, 2)
201
+
202
+ uc = self.short_filter(u)[...,:l_filter]
203
+ *x, v = uc.split(self.d_model, dim=1)
204
+
205
+ k = self.filter_fn.filter(l_filter)[0]
206
+ k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
207
+ bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)
208
+
209
+ for o, x_i in enumerate(reversed(x[1:])):
210
+ v = self.dropout(v * x_i)
211
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
212
+
213
+ y = (v * x[0]).transpose(1, 2)
214
+
215
+ y = self.out_proj(y)
216
+ return y
217
+
218
+ class HyenaMlp(nn.Module):
219
+
220
+ def __init__(self, config):
221
+ """
222
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
223
+ """
224
+ super().__init__()
225
+ in_features = config.d_model
226
+ hidden_features = config.d_inner
227
+ self.fc1 = nn.Linear(in_features, hidden_features)
228
+ self.fc2 = nn.Linear(hidden_features, config.d_model)
229
+
230
+ def forward(self, x):
231
+ y = self.fc1(x)
232
+ y = F.gelu(y, approximate="tanh")
233
+ y = self.fc2(y)
234
+ return y
235
+
236
+ class HyenaBlock(nn.Module):
237
+
238
+ def __init__(self, config):
239
+ """
240
+ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
241
+ For prenorm=True, this Block has a slightly different structure compared to a regular
242
+ prenorm Transformer block.
243
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
244
+ [Ref: https://arxiv.org/abs/2002.04745]
245
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
246
+ the hidden_states (output of the MLP) and the residual.
247
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
248
+ The residual needs to be provided (except for the very first block).
249
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
250
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
251
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
252
+ This is for performance reasons: for the post-norm architecture, returning the input allows us
253
+ to fuse the backward of nn.Linear with the residual connection.
254
+ """
255
+ super().__init__()
256
+ self.mixer = HyenaOperator(config)
257
+ self.norm1 = nn.LayerNorm(config.d_model)
258
+ self.mlp = HyenaMlp(config)
259
+ self.norm2 = nn.LayerNorm(config.d_model)
260
+
261
+ def forward(self, hidden_states):
262
+ r"""Pass the input through the encoder layer.
263
+ Args:
264
+ hidden_states: the sequence to the encoder layer (required).
265
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
266
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
267
+ before applying the query projection. Useful for e.g., ViT where we only care
268
+ about the CLS token in the last layer.
269
+ """
270
+ residual = hidden_states
271
+ residual = residual.to(torch.float32)
272
+ hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
273
+ hidden_states = self.mixer(hyena_normed)
274
+ # Tested above here and all is equivalent. That means the mixer is fine!!!
275
+ residual = hidden_states + residual
276
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
277
+ residual = residual.to(torch.float32)
278
+
279
+ hidden_states = self.mlp(hidden_states)
280
+ return hidden_states + residual
281
+
282
+
283
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
284
+
285
+
286
+ class HyenaEmbeddings(nn.Module):
287
+
288
+ def __init__(self, config, padding_idx=None):
289
+ """
290
+ If max_position_embeddings <= 0, there are no position embeddings.
291
+ If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension,
292
+ then project up to embed_dim.
293
+ """
294
+ super().__init__()
295
+ vocab_size = config.vocab_size
296
+ if vocab_size % config.pad_vocab_size_multiple != 0:
297
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
298
+ self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)
299
+
300
+ def forward(self, input_ids):
301
+ """
302
+ input_ids: (batch, seqlen)
303
+ """
304
+ embeddings = self.word_embeddings(input_ids)
305
+ return embeddings
306
+
307
+ class HyenaLMBackbone(nn.Module):
308
+
309
+ def __init__(self, config) -> None:
310
+ super().__init__()
311
+ # note max_position_embeddings is 0 for Hyena, and therefore isn't used
312
+ self.embeddings = HyenaEmbeddings(config)
313
+ self.dropout = nn.Dropout(config.embed_dropout)
314
+
315
+ self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])
316
+
317
+ self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
321
+ all_hidden_states = []
322
+ if inputs_embeds is not None:
323
+ hidden_states = inputs_embeds
324
+ else:
325
+ hidden_states = self.embeddings(input_ids)
326
+ if output_hidden_states:
327
+ all_hidden_states.append(hidden_states)
328
+
329
+ for layer in self.layers:
330
+ if self.gradient_checkpointing and self.training:
331
+ hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
332
+ else:
333
+ hidden_states = layer(hidden_states)
334
+ if output_hidden_states:
335
+ all_hidden_states.append(hidden_states)
336
+
337
+ hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
338
+ if output_hidden_states:
339
+ all_hidden_states.append(hidden_states)
340
+
341
+ return hidden_states, all_hidden_states
342
+
343
+
344
+ class HyenaDNAPreTrainedModel(PreTrainedModel):
345
+ config_class = HyenaConfig
346
+ base_model_prefix = "hyena"
347
+ supports_gradient_checkpointing = True
348
+ _no_split_modules = ["HyenaBlock"]
349
+ _skip_keys_device_placement = "past_key_values"
350
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
351
+
352
+ def _init_weights(self, module, initializer_range=0.02):
353
+ if isinstance(module, nn.Linear):
354
+ nn.init.normal_(module.weight, std=initializer_range)
355
+ if module.bias is not None:
356
+ nn.init.zeros_(module.bias)
357
+ elif isinstance(module, nn.Embedding):
358
+ nn.init.normal_(module.weight, std=initializer_range)
359
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
360
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
361
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
362
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
363
+ #
364
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
365
+ for name, p in self.named_parameters():
366
+ if name in ["out_proj.weight", "fc2.weight"]:
367
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
368
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.n_layer))  # HyenaConfig defines n_layer, not num_layers
369
+ # If using GLU activation for now, we scale the std by 2
370
+ elif name in ["output_linear.0.weight"]:
371
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
372
+ nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.n_layer))
373
+
374
+
375
+ class HyenaDNAModel(HyenaDNAPreTrainedModel):
376
+ def __init__(self, config, **kwargs) -> None:
377
+ super().__init__(config, **kwargs)
378
+
379
+ self.backbone = HyenaLMBackbone(config)
380
+ self.config = config
381
+
382
+ # Initialize weights and apply final processing
383
+ self.post_init()
384
+
385
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
386
+ output_hidden_states = (
387
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
388
+ )
389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
390
+
391
+ hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
392
+ if return_dict:
393
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states,
394
+ hidden_states=all_hidden_states if output_hidden_states else None)
395
+ elif output_hidden_states:
396
+ return hidden_states, all_hidden_states
397
+ else:
398
+ return hidden_states
399
+
400
+
401
+ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
402
+
403
+ def __init__(self, config, **kwargs):
404
+ super().__init__(config, **kwargs)
405
+ self.hyena = HyenaDNAModel(config)
406
+ vocab_size = config.vocab_size
407
+ if vocab_size % config.pad_vocab_size_multiple != 0:
408
+ vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
409
+ self.vocab_size = vocab_size
410
+ self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
411
+
412
+ # Initialize weights and apply final processing
413
+ self.post_init()
414
+
415
+ def get_input_embeddings(self):
416
+ return self.hyena.backbone.embeddings.word_embeddings
417
+
418
+ def set_input_embeddings(self, value):
419
+ self.hyena.backbone.embeddings.word_embeddings = value
420
+
421
+ def get_output_embeddings(self):
422
+ return self.lm_head
423
+
424
+ def set_output_embeddings(self, new_embeddings):
425
+ self.lm_head = new_embeddings
426
+
427
+ def set_decoder(self, decoder):
428
+ self.hyena = decoder
429
+
430
+ def get_decoder(self):
431
+ return self.hyena
432
+
433
+ def forward(
434
+ self,
435
+ input_ids: torch.LongTensor = None,
436
+ inputs_embeds: Optional[torch.FloatTensor] = None,
437
+ labels: Optional[torch.LongTensor] = None,
438
+ output_hidden_states: Optional[bool] = None,
439
+ return_dict: Optional[bool] = None,
440
+ ) -> Union[Tuple, CausalLMOutput]:
441
+
442
+ output_hidden_states = (
443
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
444
+ )
445
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
446
+
447
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
448
+ outputs = self.hyena(
449
+ input_ids=input_ids,
450
+ inputs_embeds=inputs_embeds,
451
+ output_hidden_states=output_hidden_states,
452
+ return_dict=return_dict,
453
+ )
454
+
455
+ hidden_states = outputs[0]
456
+ logits = self.lm_head(hidden_states)
457
+ logits = logits.float()
458
+
459
+ loss = None
460
+ if labels is not None:
461
+ # Shift so that tokens < n predict n
462
+ shift_logits = logits[..., :-1, :].contiguous()
463
+ shift_labels = labels[..., 1:].contiguous()
464
+ # Flatten the tokens
465
+ loss_fct = nn.CrossEntropyLoss()
466
+ shift_logits = shift_logits.view(-1, self.vocab_size)
467
+ shift_labels = shift_labels.view(-1)
468
+ # Enable model parallelism
469
+ shift_labels = shift_labels.to(shift_logits.device)
470
+ loss = loss_fct(shift_logits, shift_labels)
471
+
472
+ if not return_dict:
473
+ output = (logits,) + outputs[1:]
474
+ return (loss,) + output if loss is not None else output
475
+
476
+ return CausalLMOutput(
477
+ loss=loss,
478
+ logits=logits,
479
+ hidden_states=outputs.hidden_states,
480
+ )
481
+
482
+
483
+ class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
484
+ def __init__(self, config, **kwargs):
485
+ super().__init__(config, **kwargs)
486
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
487
+ self.hyena = HyenaDNAModel(config)
488
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.hyena.backbone.embeddings.word_embeddings
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.hyena.backbone.embeddings.word_embeddings = value
498
+
499
+ def forward(
500
+ self,
501
+ input_ids: torch.LongTensor = None,
502
+ inputs_embeds: Optional[torch.FloatTensor] = None,
503
+ labels: Optional[torch.LongTensor] = None,
504
+ output_hidden_states: Optional[bool] = None,
505
+ return_dict: Optional[bool] = None,
506
+ ) -> Union[Tuple, SequenceClassifierOutput]:
507
+ r"""
508
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
509
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
510
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
511
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
512
+ """
513
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
514
+
515
+ transformer_outputs = self.hyena(
516
+ input_ids,
517
+ inputs_embeds=inputs_embeds,
518
+ output_hidden_states=output_hidden_states,
519
+ return_dict=return_dict,
520
+ )
521
+ hidden_states = transformer_outputs[0]
522
+ logits = self.score(hidden_states)
523
+
524
+ if input_ids is not None:
525
+ batch_size = input_ids.shape[0]
526
+ else:
527
+ batch_size = inputs_embeds.shape[0]
528
+
529
+ if self.config.pad_token_id is None and batch_size != 1:
530
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
531
+ if self.config.pad_token_id is None:
532
+ sequence_lengths = -1
533
+ else:
534
+ if input_ids is not None:
535
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
536
+ logits.device
537
+ )
538
+ else:
539
+ sequence_lengths = -1
540
+
541
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
542
+
543
+ loss = None
544
+ if labels is not None:
545
+ labels = labels.to(logits.device)
546
+ if self.config.problem_type is None:
547
+ if self.num_labels == 1:
548
+ self.config.problem_type = "regression"
549
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
550
+ self.config.problem_type = "single_label_classification"
551
+ else:
552
+ self.config.problem_type = "multi_label_classification"
553
+
554
+ if self.config.problem_type == "regression":
555
+ loss_fct = nn.MSELoss()
556
+ if self.num_labels == 1:
557
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
558
+ else:
559
+ loss = loss_fct(pooled_logits, labels)
560
+ elif self.config.problem_type == "single_label_classification":
561
+ loss_fct = nn.CrossEntropyLoss()
562
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
563
+ elif self.config.problem_type == "multi_label_classification":
564
+ loss_fct = nn.BCEWithLogitsLoss()
565
+ loss = loss_fct(pooled_logits, labels)
566
+ if not return_dict:
567
+ output = (pooled_logits,) + transformer_outputs[1:]
568
+ return ((loss,) + output) if loss is not None else output
569
+
570
+ return SequenceClassifierOutput(
571
+ loss=loss,
572
+ logits=pooled_logits,
573
+ hidden_states=transformer_outputs.hidden_states,
574
+ )
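The fftconv helper at the top of this file computes a causal convolution via the Convolution Theorem: both signals are zero-padded to 2 * seqlen, multiplied in the frequency domain, truncated back to seqlen, and a per-channel skip term u * D is added. The sketch below checks that numerically against an explicit direct convolution; fftconv is copied from the file above because modeling_hyena.py uses a relative import and is normally loaded through transformers' remote-code machinery rather than imported directly.

import torch

def fftconv(u, k, D):
    # Copied from modeling_hyena.py above: FFT-domain causal convolution plus skip term.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
    return (y + u * D.unsqueeze(-1)).to(dtype=u.dtype)

torch.manual_seed(0)
batch, channels, seqlen = 2, 3, 16
u = torch.randn(batch, channels, seqlen)
k = torch.randn(channels, seqlen)
D = torch.randn(channels)

# Reference: y[b, c, t] = sum_{s <= t} k[c, s] * u[b, c, t - s] + D[c] * u[b, c, t]
y_ref = torch.zeros_like(u)
for t in range(seqlen):
    for s in range(t + 1):
        y_ref[:, :, t] += k[:, s] * u[:, :, t - s]
y_ref += u * D[None, :, None]

print(torch.allclose(fftconv(u, k, D), y_ref, atol=1e-4))  # expected: True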
optimizer_state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbc2fac38ec5a3bae03043becad4825c1a4afe885a12bdd597f5d263a40d8b55
3
+ size 26307771
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dbaff97807d7961eb4208709c13b67c9b7fb92e5f7418a96202fef0ae7e5dd5
3
+ size 16300157
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "[BOS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "[CLS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "[SEP]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "[MASK]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "[SEP]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenization_hyena.py ADDED
@@ -0,0 +1,117 @@
1
+ from transformers import PreTrainedTokenizer, AddedToken
2
+ from typing import List, Optional, Union, Dict, Sequence, Tuple
3
+ from pathlib import Path
4
+ import json
5
+ import os
6
+
7
+
8
+ class HyenaDNATokenizer(PreTrainedTokenizer):
9
+ model_input_names = ["input_ids"]
10
+
11
+ def __init__(self,
12
+ model_max_length: int,
13
+ bos_token="[BOS]",
14
+ eos_token="[SEP]",
15
+ sep_token="[SEP]",
16
+ cls_token="[CLS]",
17
+ pad_token="[PAD]",
18
+ mask_token="[MASK]",
19
+ unk_token="[UNK]",
20
+ **kwargs):
21
+ """Character tokenizer for Hugging Face transformers.
22
+ Args:
23
+ characters (Sequence[str]): List of desired characters. Any character which
24
+ is not included in this list will be replaced by a special token called
25
+ [UNK] with id=6. Following are list of all of the special tokens with
26
+ their corresponding ids:
27
+ "[CLS]": 0
28
+ "[SEP]": 1
29
+ "[BOS]": 2
30
+ "[MASK]": 3
31
+ "[PAD]": 4
32
+ "[RESERVED]": 5
33
+ "[UNK]": 6
34
+ an id (starting at 7) will be assigned to each character.
35
+ model_max_length (int): Model maximum sequence length.
36
+ """
37
+ self.characters = ('A', 'C', 'G', 'T', 'N')
38
+ self.model_max_length = model_max_length
39
+
40
+ self._vocab_str_to_int = {
41
+ "[CLS]": 0,
42
+ "[SEP]": 1,
43
+ "[BOS]": 2,
44
+ "[MASK]": 3,
45
+ "[PAD]": 4,
46
+ "[RESERVED]": 5,
47
+ "[UNK]": 6,
48
+ **{ch: i + 7 for i, ch in enumerate(self.characters)},
49
+ }
50
+ self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
51
+ add_prefix_space = kwargs.pop("add_prefix_space", False)
52
+ padding_side = kwargs.pop("padding_side", "left")
53
+
54
+ super().__init__(
55
+ bos_token=bos_token,
56
+ eos_token=eos_token,
57
+ sep_token=sep_token,
58
+ cls_token=cls_token,
59
+ pad_token=pad_token,
60
+ mask_token=mask_token,
61
+ unk_token=unk_token,
62
+ add_prefix_space=add_prefix_space,
63
+ model_max_length=model_max_length,
64
+ padding_side=padding_side,
65
+ **kwargs,
66
+ )
67
+
68
+ @property
69
+ def vocab_size(self) -> int:
70
+ return len(self._vocab_str_to_int)
71
+
72
+ def _tokenize(self, text: str) -> List[str]:
73
+ return list(text)
74
+
75
+ def _convert_token_to_id(self, token: str) -> int:
76
+ return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
77
+
78
+ def _convert_id_to_token(self, index: int) -> str:
79
+ return self._vocab_int_to_str[index]
80
+
81
+ def convert_tokens_to_string(self, tokens):
82
+ return "".join(tokens)
83
+
84
+ def get_special_tokens_mask(
85
+ self,
86
+ token_ids_0: List[int],
87
+ token_ids_1: Optional[List[int]] = None,
88
+ already_has_special_tokens: bool = False,
89
+ ) -> List[int]:
90
+ if already_has_special_tokens:
91
+ return super().get_special_tokens_mask(
92
+ token_ids_0=token_ids_0,
93
+ token_ids_1=token_ids_1,
94
+ already_has_special_tokens=True,
95
+ )
96
+
97
+ result = ([0] * len(token_ids_0)) + [1]
98
+ if token_ids_1 is not None:
99
+ result += ([0] * len(token_ids_1)) + [1]
100
+ return result
101
+
102
+ def build_inputs_with_special_tokens(
103
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
104
+ ) -> List[int]:
105
+ sep = [self.sep_token_id]
106
+ # cls = [self.cls_token_id]
107
+ result = token_ids_0 + sep
108
+ if token_ids_1 is not None:
109
+ result += token_ids_1 + sep
110
+ return result
111
+
112
+ def get_vocab(self) -> Dict[str, int]:
113
+ return self._vocab_str_to_int
114
+
115
+ # HyenaDNA has a fixed vocabulary with no vocab file
116
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
117
+ return ()
tokenizer_config.json ADDED
@@ -0,0 +1,72 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "[CLS]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "[SEP]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "[BOS]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "[MASK]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "6": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ }
52
+ },
53
+ "auto_map": {
54
+ "AutoTokenizer": [
55
+ "tokenization_hyena.HyenaDNATokenizer",
56
+ null
57
+ ]
58
+ },
59
+ "bos_token": "[BOS]",
60
+ "clean_up_tokenization_spaces": true,
61
+ "cls_token": "[CLS]",
62
+ "eos_token": "[SEP]",
63
+ "mask_token": "[MASK]",
64
+ "model_max_length": 256,
65
+ "name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
66
+ "pad_token": "[PAD]",
67
+ "padding_side": "right",
68
+ "sep_token": "[SEP]",
69
+ "special_tokens_map_file": "/home/hlv8980/.cache/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/special_tokens_map.json",
70
+ "tokenizer_class": "HyenaDNATokenizer",
71
+ "unk_token": "[UNK]"
72
+ }
trainer_state.json ADDED
@@ -0,0 +1,156 @@
1
+ {
2
+ "best_metric": 0.39216598868370056,
3
+ "best_model_checkpoint": "/scratch/hlv8980/Attack_Benchmark/models/hyena/tf4/origin/checkpoint-600",
4
+ "epoch": 4.0,
5
+ "global_step": 1188,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.34,
12
+ "learning_rate": 2.8760984182776802e-05,
13
+ "loss": 0.5992,
14
+ "step": 100
15
+ },
16
+ {
17
+ "epoch": 0.67,
18
+ "learning_rate": 2.615114235500879e-05,
19
+ "loss": 0.4813,
20
+ "step": 200
21
+ },
22
+ {
23
+ "epoch": 0.67,
24
+ "eval_accuracy": 0.774,
25
+ "eval_f1": 0.7713328260834371,
26
+ "eval_loss": 0.48207539319992065,
27
+ "eval_matthews_correlation": 0.5579338694412199,
28
+ "eval_precision": 0.785067107786007,
29
+ "eval_recall": 0.772997299729973,
30
+ "eval_runtime": 0.1057,
31
+ "eval_samples_per_second": 9462.679,
32
+ "eval_steps_per_second": 151.403,
33
+ "step": 200
34
+ },
35
+ {
36
+ "epoch": 1.01,
37
+ "learning_rate": 2.3514938488576452e-05,
38
+ "loss": 0.4431,
39
+ "step": 300
40
+ },
41
+ {
42
+ "epoch": 1.35,
43
+ "learning_rate": 2.087873462214411e-05,
44
+ "loss": 0.377,
45
+ "step": 400
46
+ },
47
+ {
48
+ "epoch": 1.35,
49
+ "eval_accuracy": 0.816,
50
+ "eval_f1": 0.8159933757615274,
51
+ "eval_loss": 0.427643358707428,
52
+ "eval_matthews_correlation": 0.6320128653971173,
53
+ "eval_precision": 0.815991263965056,
54
+ "eval_recall": 0.816021602160216,
55
+ "eval_runtime": 0.1039,
56
+ "eval_samples_per_second": 9625.637,
57
+ "eval_steps_per_second": 154.01,
58
+ "step": 400
59
+ },
60
+ {
61
+ "epoch": 1.68,
62
+ "learning_rate": 1.82688927943761e-05,
63
+ "loss": 0.3443,
64
+ "step": 500
65
+ },
66
+ {
67
+ "epoch": 2.02,
68
+ "learning_rate": 1.563268892794376e-05,
69
+ "loss": 0.33,
70
+ "step": 600
71
+ },
72
+ {
73
+ "epoch": 2.02,
74
+ "eval_accuracy": 0.824,
75
+ "eval_f1": 0.8239746523499383,
76
+ "eval_loss": 0.39216598868370056,
77
+ "eval_matthews_correlation": 0.6479558982194922,
78
+ "eval_precision": 0.8239935027265344,
79
+ "eval_recall": 0.8239623962396239,
80
+ "eval_runtime": 0.1031,
81
+ "eval_samples_per_second": 9696.512,
82
+ "eval_steps_per_second": 155.144,
83
+ "step": 600
84
+ },
85
+ {
86
+ "epoch": 2.36,
87
+ "learning_rate": 1.2996485061511423e-05,
88
+ "loss": 0.227,
89
+ "step": 700
90
+ },
91
+ {
92
+ "epoch": 2.69,
93
+ "learning_rate": 1.0360281195079087e-05,
94
+ "loss": 0.2219,
95
+ "step": 800
96
+ },
97
+ {
98
+ "epoch": 2.69,
99
+ "eval_accuracy": 0.838,
100
+ "eval_f1": 0.8379766686402841,
101
+ "eval_loss": 0.4026987850666046,
102
+ "eval_matthews_correlation": 0.6767651028795362,
103
+ "eval_precision": 0.8385613769517563,
104
+ "eval_recall": 0.8382038203820381,
105
+ "eval_runtime": 0.1026,
106
+ "eval_samples_per_second": 9746.534,
107
+ "eval_steps_per_second": 155.945,
108
+ "step": 800
109
+ },
110
+ {
111
+ "epoch": 3.03,
112
+ "learning_rate": 7.724077328646749e-06,
113
+ "loss": 0.2121,
114
+ "step": 900
115
+ },
116
+ {
117
+ "epoch": 3.37,
118
+ "learning_rate": 5.087873462214412e-06,
119
+ "loss": 0.1388,
120
+ "step": 1000
121
+ },
122
+ {
123
+ "epoch": 3.37,
124
+ "eval_accuracy": 0.857,
125
+ "eval_f1": 0.8566558306493891,
126
+ "eval_loss": 0.393052339553833,
127
+ "eval_matthews_correlation": 0.7159331394438886,
128
+ "eval_precision": 0.859342750257998,
129
+ "eval_recall": 0.8565956595659566,
130
+ "eval_runtime": 0.1036,
131
+ "eval_samples_per_second": 9656.574,
132
+ "eval_steps_per_second": 154.505,
133
+ "step": 1000
134
+ },
135
+ {
136
+ "epoch": 3.7,
137
+ "learning_rate": 2.4516695957820737e-06,
138
+ "loss": 0.1394,
139
+ "step": 1100
140
+ },
141
+ {
142
+ "epoch": 4.0,
143
+ "step": 1188,
144
+ "total_flos": 152279543808000.0,
145
+ "train_loss": 0.30630172623528373,
146
+ "train_runtime": 38.7189,
147
+ "train_samples_per_second": 1962.866,
148
+ "train_steps_per_second": 30.683
149
+ }
150
+ ],
151
+ "max_steps": 1188,
152
+ "num_train_epochs": 4,
153
+ "total_flos": 152279543808000.0,
154
+ "trial_name": null,
155
+ "trial_params": null
156
+ }
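The eval_* keys logged above (accuracy, f1, matthews_correlation, precision, recall) are what a standard scikit-learn based compute_metrics hook would report. The exact function used for this run is not part of the upload, so the following is only a plausible reconstruction, with macro averaging assumed for f1, precision and recall:

import numpy as np
from sklearn.metrics import (accuracy_score, f1_score, matthews_corrcoef,
                             precision_score, recall_score)

def compute_metrics(eval_pred):
    # eval_pred unpacks to (logits, labels) as passed by the transformers Trainer.
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="macro"),           # averaging mode is an assumption
        "matthews_correlation": matthews_corrcoef(labels, preds),
        "precision": precision_score(labels, preds, average="macro"),
        "recall": recall_score(labels, preds, average="macro"),
    }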
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4e3a85efd6ca1a2228fc0bf6f5ca43150a2981352327acbf61aa5be7e43d49
3
+ size 3707