Fix BF16 training
#19
by
alexanderchemeris
- opened
- modeling_qwen2.py +1 -1
modeling_qwen2.py
CHANGED
@@ -131,7 +131,7 @@ class TimeSeriesEmbedding(nn.Module):
|
|
131 |
x = x.reshape(batch_size, -1, self.num_features)
|
132 |
|
133 |
mask = x[:, :, -1].long()
|
134 |
-
valid_lengths = mask.sum(dim=1)
|
135 |
|
136 |
patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size # 向上取整
|
137 |
|
|
|
131 |
x = x.reshape(batch_size, -1, self.num_features)
|
132 |
|
133 |
mask = x[:, :, -1].long()
|
134 |
+
valid_lengths = (mask > 0.5).long().sum(dim=1) # Shape: (batch_size)
|
135 |
|
136 |
patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size # 向上取整
|
137 |
|