hvlgo committed
Commit: d935394
Parent: fbd9db4

Upload TimerForPrediction

config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "checkpoints/timer_base",
+   "architectures": [
+     "TimerForPrediction"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_timer.TimerConfig",
+     "AutoModelForCausalLM": "modeling_timer.TimerForPrediction"
+   },
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "input_token_len": 96,
+   "intermediate_size": 2048,
+   "max_position_embeddings": 10000,
+   "model_type": "timer",
+   "num_attention_heads": 8,
+   "num_hidden_layers": 8,
+   "output_token_lens": [
+     96
+   ],
+   "rope_theta": 10000,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.1",
+   "use_cache": true
+ }
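
The auto_map entries above register the custom classes added below, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal usage sketch, not part of this commit; the repo/path placeholder and series lengths are assumptions:

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder: substitute the Hub repo id or a local clone of this checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        "<repo-or-local-path>", trust_remote_code=True)

    # Inputs are raw float series of shape [batch_size, seq_len]; with
    # input_token_len = 96, the context is split into non-overlapping 96-point tokens.
    context = torch.randn(1, 96 * 3)
    forecast = model.generate(context, max_new_tokens=96)
    # generate() returns context plus predicted points, so the expected shape
    # is torch.Size([1, 96 * 3 + 96]).
    print(forecast.shape)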
configuration_timer.py ADDED
@@ -0,0 +1,40 @@
+ from typing import List
+ from transformers import PretrainedConfig
+
+
+ class TimerConfig(PretrainedConfig):
+     model_type = "timer"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         input_token_len: int = 1,
+         hidden_size: int = 1024,
+         intermediate_size: int = 2048,
+         output_token_lens: List[int] = [1, 8, 32, 64],
+         num_hidden_layers: int = 8,
+         num_attention_heads: int = 8,
+         hidden_act: str = "silu",
+         use_cache: bool = True,
+         rope_theta: int = 10000,
+         attention_dropout: float = 0.0,
+         initializer_range: float = 0.02,
+         max_position_embeddings: int = 10000,
+         **kwargs,
+     ):
+         self.input_token_len = input_token_len
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.output_token_lens = output_token_lens
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.attention_dropout = attention_dropout
+         self.initializer_range = initializer_range
+         self.max_position_embeddings = max_position_embeddings
+
+         super().__init__(
+             **kwargs,
+         )
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.40.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c3d18f12ffe1ea7d4fa70eb3304b26e3841164a6a265fbae4f7a05cd213aa3d
+ size 336580760
modeling_timer.py ADDED
@@ -0,0 +1,565 @@
+ from typing import Optional, Tuple, List, Union
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, Cache, DynamicCache
+ from transformers.activations import ACT2FN
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
+ from .configuration_timer import TimerConfig
+ from .ts_generation_mixin import TSGenerationMixin
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class TimerPatchEmbedding(nn.Module):
+     def __init__(self, config: TimerConfig):
+         super().__init__()
+         self.input_token_len = config.input_token_len
+         self.emb = nn.Linear(config.input_token_len,
+                              config.hidden_size, bias=False)
+
+     def forward(self, hidden_state: torch.Tensor):
+         hidden_state = hidden_state.unfold(
+             dimension=-1, size=self.input_token_len, step=self.input_token_len)
+         return self.emb(hidden_state)
+
+
+ class TimerPointEmbedding(nn.Module):
+     def __init__(self, config: TimerConfig):
+         super().__init__()
+         self.emb_layer = nn.Linear(
+             config.input_token_len, config.hidden_size, bias=False)
+         self.gate_layer = nn.Linear(
+             config.input_token_len, config.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x)
+         return emb
+
+
+ class TimeMoeRotaryEmbedding(torch.nn.Module):
+     def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
+                                                      2, dtype=torch.int64).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device,
+                          dtype=torch.int64).type_as(self.inv_freq)
+
+         freqs = torch.outer(t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer(
+             "cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer(
+             "sin_cached", emb.sin().to(dtype), persistent=False)
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(
+                 seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+
+ class TimerAttention(nn.Module):
+     def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.attention_dropout = config.attention_dropout
+         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.rotary_emb = TimeMoeRotaryEmbedding(
+             self.head_dim, max_position_embeddings=config.max_position_embeddings)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value.get_usable_length(
+                 kv_seq_len, self.layer_idx)
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx)
+
+         attn_output = F.scaled_dot_product_attention(
+             query_states, key_states, value_states, attention_mask, dropout_p=self.attention_dropout)
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         # F.scaled_dot_product_attention does not expose attention weights.
+         attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
+ class TimerMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[hidden_act]
+
+     def forward(self, hidden_state):
+         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+ class TimerDecoderLayer(nn.Module):
+     def __init__(self, config: TimerConfig, layer_idx: int):
+         super().__init__()
+         self.self_attn = TimerAttention(config, layer_idx)
+
+         self.ffn_layer = TimerMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+         )
+         self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+         self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+         residual = hidden_states
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.norm1(hidden_states)
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.ffn_layer(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.norm2(hidden_states)
+
+         if not output_attentions:
+             self_attn_weights = None
+
+         if not use_cache:
+             present_key_value = None
+         return hidden_states, self_attn_weights, present_key_value
+
+
+ class TimerPreTrainedModel(PreTrainedModel):
+     config_class = TimerConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["TimeMoeDecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+     _supports_sdpa = False
+     _supports_cache_class = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, torch.nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, torch.nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class TimerModel(TimerPreTrainedModel):
+     def __init__(self, config: TimerConfig):
+         super().__init__(config)
+         self.embed_layer = TimerPatchEmbedding(config)
+         self.layers = nn.ModuleList(
+             [TimerDecoderLayer(config, layer_idx)
+              for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = torch.nn.LayerNorm(config.hidden_size)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MoeModelOutputWithPast]:
+         # input_ids is the input of time series, its shape is [batch_size, seq_len]
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_layer(input_ids)
+             seq_length = inputs_embeds.shape[1]
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = DynamicCache.from_legacy_cache(
+                     past_key_values)
+             past_key_values_length = past_key_values.get_usable_length(
+                 seq_length)
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+             position_ids = position_ids.view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         # 4d mask is passed through the layers
+         attention_mask = _prepare_4d_causal_attention_mask(
+             attention_mask,
+             (batch_size, seq_length),
+             inputs_embeds,
+             past_key_values_length,
+             sliding_window=None,
+         )
+
+         hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2]
+
+         hidden_states = self.norm(hidden_states)
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = None
+         if use_cache:
+             next_cache = next_decoder_cache.to_legacy_cache(
+             ) if use_legacy_cache else next_decoder_cache
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+         return MoeModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
+     def __init__(self, config: TimerConfig):
+         super().__init__(config)
+         self.config = config
+         self.model = TimerModel(self.config)
+         lm_head_list = []
+         self.output_token_len_map = {}
+         for i, output_token_len in enumerate(self.config.output_token_lens):
+             lm_head_list.append(
+                 nn.Linear(self.config.hidden_size, output_token_len, bias=False))
+             self.output_token_len_map[output_token_len] = i
+         self.lm_heads = nn.ModuleList(lm_head_list)
+         self.loss_function = torch.nn.MSELoss(reduction='none')
+         self.post_init()
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.FloatTensor] = None,
+         loss_masks: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         max_output_length: Optional[int] = None,
+     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+         predictions = None
+
+         loss = None
+         if labels is not None:
+             ar_loss = 0.0
+             for lm_head, output_token_len in zip(self.lm_heads, self.config.output_token_lens):
+                 one_predictions = lm_head(hidden_states)
+                 one_loss = self.calc_ar_loss(
+                     one_predictions, labels, loss_masks, output_token_len)
+                 ar_loss += one_loss
+                 if predictions is None:
+                     predictions = one_predictions
+             loss = ar_loss / len(self.config.output_token_lens)
+         else:
+             if max_output_length is None:
+                 output_token_len = self.config.output_token_lens[0]
+                 max_output_length = output_token_len
+             else:
+                 output_token_len = self.config.output_token_lens[0]
+                 for h in self.config.output_token_lens[1:]:
+                     if h > max_output_length:
+                         break
+                     else:
+                         output_token_len = h
+             lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
+             predictions = lm_head(hidden_states)
+             if output_token_len > max_output_length:
+                 predictions = predictions[:, :, :max_output_length]
+         if not return_dict:
+             output = (predictions,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return MoeCausalLMOutputWithPast(
+             loss=loss,
+             logits=predictions,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len):
+         seq_len = predictions.shape[1] * self.config.input_token_len
+         labels = labels[:, :seq_len -
+                         self.config.input_token_len + output_token_len]
+         shift_labels = labels.unfold(
+             dimension=-1, size=output_token_len, step=self.config.input_token_len)
+
+         # Calculate loss with mask
+         losses = self.loss_function(predictions, shift_labels).mean(dim=-1)
+         if loss_masks is not None:
+             losses = losses * loss_masks
+             loss = losses.sum() / loss_masks.sum()
+         else:
+             loss = torch.mean(losses)
+
+         return loss
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         # Omit tokens covered by past_key_values
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 if isinstance(past_key_values, DynamicCache):
+                     past_length = past_key_values.seen_tokens
+                 else:
+                     past_length = cache_length
+
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, -
+                                       (attention_mask.shape[1] - past_length):]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, past_length *
+                                       self.config.input_token_len:]
+             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -
+                                             (input_ids.shape[1] // self.config.input_token_len):]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
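
For orientation, the patching done by TimerPatchEmbedding and the mapping back to future points by the lm_heads can be checked with plain tensors. This is a toy sketch using the shapes implied by config.json (input_token_len = 96, hidden_size = 1024); it is not code from this commit:

    import torch
    from torch import nn

    batch, context_len, patch, hidden = 2, 288, 96, 1024

    # TimerPatchEmbedding: non-overlapping 96-point patches, each projected to hidden_size.
    series = torch.randn(batch, context_len)
    patches = series.unfold(dimension=-1, size=patch, step=patch)   # [2, 3, 96]
    tokens = nn.Linear(patch, hidden, bias=False)(patches)          # [2, 3, 1024]

    # Each lm_head in TimerForPrediction maps a token embedding to output_token_len
    # future points, so logits have shape [batch, num_tokens, output_token_len].
    logits = nn.Linear(hidden, patch, bias=False)(tokens)           # [2, 3, 96]
    print(patches.shape, tokens.shape, logits.shape)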
ts_generation_mixin.py ADDED
@@ -0,0 +1,251 @@
+ import warnings
+ from typing import Any, Dict, List, Optional, Union
+ import torch
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
+ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
+ from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
+ from transformers.utils import ModelOutput
+
+
+ class TSGenerationMixin(GenerationMixin):
+
+     def _greedy_search(
+         self,
+         input_ids: torch.Tensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_scores: Optional[bool] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+         input_ids_origin_device = input_ids.device
+         input_ids = input_ids.to(self.device)
+         if len(input_ids.shape) == 2:
+             batch_size, cur_len = input_ids.shape
+         else:
+             raise ValueError('Input shape must be: [batch_size, seq_len]')
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(
+                 stopping_criteria, max_length)
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         if eos_token_id is not None:
+             stopping_criteria.append(
+                 EosTokenCriteria(eos_token_id=eos_token_id))
+         else:
+             # remove when the method is totally private
+             # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+             eos_token_id = [
+                 criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+             ]
+             eos_token_id = eos_token_id[0] if eos_token_id else None
+             if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                 eos_token_id = self.generation_config.eos_token_id
+                 stopping_criteria.append(
+                     EosTokenCriteria(eos_token_id=eos_token_id))
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+         output_attentions = (
+             output_attentions if output_attentions is not None else self.generation_config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+         )
+         return_dict_in_generate = (
+             return_dict_in_generate
+             if return_dict_in_generate is not None
+             else self.generation_config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+         scores = () if (return_dict_in_generate and output_scores) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = () if (
+             return_dict_in_generate and output_hidden_states) else None
+
+         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+         if return_dict_in_generate and self.config.is_encoder_decoder:
+             encoder_attentions = model_kwargs["encoder_outputs"].get(
+                 "attentions") if output_attentions else None
+             encoder_hidden_states = (
+                 model_kwargs["encoder_outputs"].get(
+                     "hidden_states") if output_hidden_states else None
+             )
+
+         # keep track of which sequences are already finished
+         if "inputs_embeds" in model_kwargs:
+             cur_len = model_kwargs["inputs_embeds"].shape[1]
+         this_peer_finished = False
+         unfinished_sequences = torch.ones(
+             batch_size, dtype=torch.long, device=input_ids.device)
+         model_kwargs["cache_position"] = torch.arange(
+             cur_len, device=input_ids.device)
+         true_seq_len = input_ids.shape[1] // self.config.input_token_len
+         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
+
+         max_length = stopping_criteria.max_length
+         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(
+                 input_ids, **model_kwargs)
+
+             input_length = input_ids.shape[1]
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 max_output_length=max_length - input_length,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 continue  # don't waste resources running the code we don't need
+
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_tokens_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+                 if output_attentions:
+                     decoder_attentions += (
+                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
+                             outputs.attentions,)
+                     )
+                     if self.config.is_encoder_decoder:
+                         cross_attentions += (outputs.cross_attentions,)
+
+                 if output_hidden_states:
+                     decoder_hidden_states += (
+                         (outputs.decoder_hidden_states,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.hidden_states,)
+                     )
+
+             # argmax
+             # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+             next_tokens = next_tokens_scores
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError(
+                         "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + \
+                     pad_token_id * (1 - unfinished_sequences)
+
+             # update generated ids, model inputs, and length for next step
+             horizon_length = next_tokens.shape[1] // self.config.input_token_len
+
+             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 horizon_length=horizon_length,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+             unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+                 input_ids, scores)
+             this_peer_finished = unfinished_sequences.max() == 0
+
+         if input_ids.shape[1] > max_length:
+             input_ids = input_ids[:, :max_length]
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             if self.config.is_encoder_decoder:
+                 return GenerateEncoderDecoderOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     encoder_attentions=encoder_attentions,
+                     encoder_hidden_states=encoder_hidden_states,
+                     decoder_attentions=decoder_attentions,
+                     cross_attentions=cross_attentions,
+                     decoder_hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+             else:
+                 return GenerateDecoderOnlyOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     attentions=decoder_attentions,
+                     hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+         else:
+             return input_ids
+
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: ModelOutput,
+         model_kwargs: Dict[str, Any],
+         horizon_length: int = 1,
+         is_encoder_decoder: bool = False,
+         standardize_cache_format: bool = False,
+     ) -> Dict[str, Any]:
+         # update past_key_values
+         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+             outputs, standardize_cache_format=standardize_cache_format
+         )
+         if getattr(outputs, "state", None) is not None:
+             model_kwargs["state"] = outputs.state
+
+         # update token_type_ids with last value
+         if "token_type_ids" in model_kwargs:
+             token_type_ids = model_kwargs["token_type_ids"]
+             model_kwargs["token_type_ids"] = torch.cat(
+                 [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+         if not is_encoder_decoder:
+             # update attention mask
+             if "attention_mask" in model_kwargs:
+                 attention_mask = model_kwargs["attention_mask"]
+                 model_kwargs["attention_mask"] = torch.cat(
+                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+                 )
+         else:
+             # update decoder attention mask
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones(
+                         (decoder_attention_mask.shape[0], horizon_length))],
+                     dim=-1,
+                 )
+
+         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
+
+         return model_kwargs
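
A note on length accounting in the loop above: input_ids holds raw points while the cache, position ids, and attention mask are indexed in tokens, so each step appends next_tokens.shape[1] points and advances the mask by horizon_length = points // input_token_len token positions. A rough sketch of how many forward passes a forecast takes under this scheme, with assumed values for illustration only:

    input_token_len = 96      # points per token, as in config.json
    context_points = 288      # three context tokens
    new_points = 192          # horizon requested via max_new_tokens

    # Each greedy step emits one output token of input_token_len points, so the loop
    # runs until the series reaches context_points + new_points points.
    steps = -(-new_points // input_token_len)   # ceil division -> 2 forward passes
    print(steps)

When the requested horizon is not a multiple of input_token_len, the final step overshoots and the sequence is cropped back to max_length after the loop.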