Update modeling_timer.py
Browse files- modeling_timer.py +2 -2
modeling_timer.py
CHANGED
@@ -472,9 +472,9 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
|
|
472 |
else:
|
473 |
output_token_len = h
|
474 |
lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
|
475 |
-
predictions = lm_head(hidden_states)
|
476 |
if output_token_len > max_output_length:
|
477 |
-
predictions = predictions[:,
|
478 |
if revin:
|
479 |
predictions = predictions * std + mean
|
480 |
if not return_dict:
|
|
|
472 |
else:
|
473 |
output_token_len = h
|
474 |
lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
|
475 |
+
predictions = lm_head(hidden_states)[:, -1, :]
|
476 |
if output_token_len > max_output_length:
|
477 |
+
predictions = predictions[:, :max_output_length]
|
478 |
if revin:
|
479 |
predictions = predictions * std + mean
|
480 |
if not return_dict:
|