hvlgo commited on
Commit
d809d25
1 Parent(s): 51e990c

Update modeling_timer.py

Browse files
Files changed (1) hide show
  1. modeling_timer.py +2 -2
modeling_timer.py CHANGED
@@ -472,9 +472,9 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
472
  else:
473
  output_token_len = h
474
  lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
475
- predictions = lm_head(hidden_states)
476
  if output_token_len > max_output_length:
477
- predictions = predictions[:, :, :max_output_length]
478
  if revin:
479
  predictions = predictions * std + mean
480
  if not return_dict:
 
472
  else:
473
  output_token_len = h
474
  lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
475
+ predictions = lm_head(hidden_states)[:, -1, :]
476
  if output_token_len > max_output_length:
477
+ predictions = predictions[:, :max_output_length]
478
  if revin:
479
  predictions = predictions * std + mean
480
  if not return_dict: