dongyh committed on
Commit 55c82d2 · verified · 1 Parent(s): 3dedebd

Upload 15 files

Files changed (14)
  1. beam_search.py +1078 -0
  2. checkpoint.py +2023 -0
  3. config.json +2 -13
  4. config.py +1371 -0
  5. exceptions.py +50 -0
  6. initialization.py +22 -0
  7. model.py +1959 -0
  8. modeling_fan.py +271 -0
  9. optim.py +1040 -0
  10. safetensors_util.py +81 -0
  11. torch_util.py +158 -0
  12. train.py +1384 -0
  13. util.py +929 -0
  14. version.py +11 -0
beam_search.py ADDED
@@ -0,0 +1,1078 @@
1
+ """
2
+ This is a self-contained and flexible beam search implementation adapted from
3
+ AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
+ """
5
+
6
+ import copy
7
+ import warnings
8
+ from abc import abstractmethod
9
+ from inspect import signature
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
+
12
+ import torch
13
+
14
+ __all__ = [
15
+ "Sampler",
16
+ "DeterministicSampler",
17
+ "MultinomialSampler",
18
+ "TopKSampler",
19
+ "TopPSampler",
20
+ "GumbelSampler",
21
+ "FinalSequenceScorer",
22
+ "SequenceLogProbabilityScorer",
23
+ "LengthNormalizedSequenceLogProbabilityScorer",
24
+ "Constraint",
25
+ "RepeatedNGramBlockingConstraint",
26
+ "BeamSearch",
27
+ ]
28
+
29
+ StateType = Dict[str, torch.Tensor]
30
+ StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
+ StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
+
33
+ StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
+ """
35
+ The type of step function that can be passed to [`BeamSearch.search`](#search).
36
+
37
+ This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
+ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
+ """
40
+
41
+ ConstraintStateType = List[List[Dict[str, Any]]]
42
+
43
+
44
+ class Sampler:
45
+ """
46
+ An abstract class that can be used to sample candidates (either nodes or beams)
47
+ within `BeamSearch`.
48
+
49
+ A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
+
51
+ `init_state()` takes three arguments:
52
+
53
+ - a tensor of starting log probs with shape `(batch_size, num_classes)`,
54
+ - the batch size, an int,
55
+ - and the number of classes, also an int.
56
+
57
+ It returns a state dictionary with any state tensors needed for subsequent
58
+ calls to `sample_nodes()` and `sample_beams()`.
59
+
60
+ By default this method just returns an empty dictionary.
61
+
62
+ Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
+
64
+ - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
+ - an integer representing the number of samples to take for each example in the batch,
66
+ - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
+ track of state.
68
+
69
+ For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
+ `num_examples = beam_size * per_node_beam_size`.
71
+
72
+ The return value should be a tuple containing:
73
+
74
+ - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
+ - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
+ - and the updated state dictionary.
77
+
78
+ A default implementation of `sample_beams` is provided, which just deterministically
79
+ picks the `k` examples with highest log probability.
80
+ """
81
+
82
+ def init_state(
83
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
+ ) -> StateType:
85
+ del start_class_log_probabilities, batch_size, num_classes
86
+ return {}
87
+
88
+ @abstractmethod
89
+ def sample_nodes(
90
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
+ raise NotImplementedError
93
+
94
+ def sample_beams(
95
+ self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
+ del state
98
+ selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
+ return selected_log_probs, selected_indices, {}
100
+
101
+
102
+ class DeterministicSampler(Sampler):
103
+ """
104
+ A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
+ log probability.
106
+ """
107
+
108
+ def sample_nodes(
109
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
+ del state
112
+ selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
+ return selected_log_probs, selected_indices, {}
114
+
115
+
116
+ class MultinomialSampler(Sampler):
117
+ """
118
+ A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
+ in the default, deterministic way.
120
+
121
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
+ above 1.0 produces a flatter probability distribution.
123
+ :param with_replacement: Whether to sample with replacement.
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ temperature: float = 1.0,
130
+ with_replacement: bool = False,
131
+ ) -> None:
132
+ self.temperature = temperature
133
+ self.with_replacement = with_replacement
134
+
135
+ def sample_nodes(
136
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
+ if self.temperature != 1.0:
139
+ _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
+ else:
141
+ _probabilities = log_probs.exp()
142
+
143
+ selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
+
145
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
+
147
+
148
+ class TopKSampler(Sampler):
149
+ """
150
+ A `Sampler` which redistributes the probability mass function for nodes among the
151
+ top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
+
153
+ Beams are sampled in the default, deterministic way.
154
+
155
+ :param k: The number of top choices to be selected from.
156
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
+ above 1.0 produces a flatter probability distribution.
158
+ :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ k: int = 1,
164
+ temperature: float = 1.0,
165
+ with_replacement: bool = False,
166
+ ):
167
+ self.k = k
168
+ self.temperature = temperature or 1.0
169
+ self.with_replacement = with_replacement
170
+
171
+ def sample_nodes(
172
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
+ if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
+ raise ValueError(
176
+ "k must be a positive integer no less than per_node_beam_size and no greater than vocabulary size"
177
+ )
178
+
179
+ # shape (both): (batch_size, k)
180
+ top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
+
182
+ # Apply temperature if necessary.
183
+ # shape: (batch_size, k)
184
+ if self.temperature != 1.0:
185
+ top_k_log_probs = top_k_log_probs / self.temperature
186
+
187
+ # Re-normalize the subset.
188
+ # shape: (batch_size, k)
189
+ normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
+
191
+ # Sample from the re-normalized subset.
192
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
+ # shape: (batch_size, per_node_beam_size)
194
+ sampled_indices = torch.multinomial(
195
+ normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
+ )
197
+
198
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
+ # shape: (batch_size, per_node_beam_size)
200
+ indices = top_k_indices.gather(-1, sampled_indices)
201
+
202
+ return log_probs.gather(1, indices), indices, state
203
+
204
+
205
+ class TopPSampler(Sampler):
206
+ """
207
+ A `Sampler` which redistributes the probability mass function for nodes among
208
+ the top choices with a cumulative probability of at least `p`, then samples from that subset
209
+ after re-normalizing the probabilities.
210
+
211
+ Beams are sampled in the default, deterministic way.
212
+
213
+ :param p:
214
+ The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
+ examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
+ insufficient to sample from without replacement when calling `sample_nodes`, then the top
217
+ `per_node_beam_size` examples will be chosen.
218
+ :param temperature:
219
+ A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
+ above 1.0 produces a flatter probability distribution.
221
+ :param with_replacement:
222
+ If set to `True`, samples will be selected with replacement from the top choices.
223
+
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ p: float = 0.9,
229
+ temperature: float = 1.0,
230
+ with_replacement: bool = False,
231
+ ):
232
+ if p < 0.0 or p > 1.0:
233
+ raise ValueError("p must be a positive float no greater than 1.0")
234
+ self.p = p
235
+ self.temperature = temperature or 1.0
236
+ self.with_replacement = with_replacement
237
+
238
+ def sample_nodes(
239
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
+ if not per_node_beam_size <= log_probs.size()[1]:
242
+ raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
+
244
+ # First apply temperature coefficient:
245
+ if self.temperature != 1.0:
246
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
+ else:
248
+ _log_probs = log_probs
249
+
250
+ # Sort the probabilities in descending order to then find cumulative sum
251
+ log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
+
253
+ # shape: (batch_size, num_classes)
254
+ probabilities_descending = log_probs_descending.exp()
255
+ probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
+
257
+ # Create a mask for filtering out probabilities that don't make the top `p`.
258
+ # shape: (batch_size, num_classes)
259
+ exclusion_mask = probabilities_summed >= self.p
260
+
261
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
+ exclusion_mask[..., 0] = False
264
+
265
+ # Make sure there's at least `per_node_beam_size` options to be selected.
266
+ if not self.with_replacement:
267
+ exclusion_mask[..., :per_node_beam_size] = False
268
+
269
+ log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
+
271
+ # Now re-normalize the included log probs.
272
+ # shape: (batch_size, num_classes)
273
+ filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
+
275
+ # Sample from the re-normalized subset.
276
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
+ # shape: (batch_size, per_node_beam_size)
278
+ sampled_indices = torch.multinomial(
279
+ filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
+ )
281
+
282
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
+ # shape: (batch_size, per_node_beam_size)
284
+ selected_indices = sorting_indices.gather(-1, sampled_indices)
285
+
286
+ # Return (selected log probabilities, selected classes)
287
+ # shape (both): (batch_size, per_node_beam_size)
288
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
+
290
+
291
+ class GumbelSampler(Sampler):
292
+ """
293
+ A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
+ [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
+ Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
296
+ (https://api.semanticscholar.org/CorpusID:76662039).
297
+
298
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
+ above 1.0 produces a flatter probability distribution.
300
+ """
301
+
302
+ def __init__(self, temperature: float = 1.0):
303
+ self.temperature = temperature
304
+
305
+ def init_state(
306
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
+ ) -> StateType:
308
+ # shape: (batch_size, num_classes)
309
+ zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
+
311
+ # shape: (batch_size, num_classes)
312
+ G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
+
314
+ return {"G_phi_S": G_phi_S}
315
+
316
+ def sample_nodes(
317
+ self,
318
+ log_probs: torch.Tensor,
319
+ per_node_beam_size: int,
320
+ state: StateType,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
+ # First apply temperature coefficient:
323
+ # shape: (batch_size * beam_size, num_classes)
324
+ if self.temperature != 1.0:
325
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
+ else:
327
+ _log_probs = log_probs
328
+
329
+ # shape: (group_size,)
330
+ phi_S = state["phi_S"]
331
+
332
+ # shape: (group_size, num_classes)
333
+ phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
+
335
+ # shape: (group_size, num_classes)
336
+ phi_S_new = phi_S + _log_probs
337
+
338
+ # shape: (group_size, 1)
339
+ G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
+
341
+ # shape: (group_size, num_classes)
342
+ G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
+
344
+ # Replace NaNs with very negative number.
345
+ # shape: (group_size, num_classes)
346
+ # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
+
348
+ # shape (both): (group_size, per_node_beam_size)
349
+ top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
+
351
+ # shape: (group_size, per_node_beam_size)
352
+ top_log_probs = log_probs.gather(1, top_indices)
353
+
354
+ return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
+
356
+ def sample_beams(
357
+ self,
358
+ log_probs: torch.Tensor,
359
+ beam_size: int,
360
+ state: StateType,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
+ """
363
+ Returns the beams with the highest perturbed log probabilities.
364
+ """
365
+ # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
+
367
+ batch_size = log_probs.size()[0]
368
+
369
+ # shape: (batch_size * beam_size, per_node_beam_size)
370
+ G_phi_S = state["G_phi_S"]
371
+
372
+ # shape: (batch_size, beam_size * per_node_beam_size)
373
+ G_phi_S = G_phi_S.reshape_as(log_probs)
374
+
375
+ # shape (both): (batch_size, beam_size)
376
+ G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
+
378
+ # shape: (batch_size, beam_size)
379
+ selected_log_probs = log_probs.gather(1, selected_indices)
380
+
381
+ # Now sort the selected beams by their true log prob.
382
+ # shape (all): (batch_size, beam_size)
383
+ selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
+ selected_indices = selected_indices.gather(1, sort_indices)
385
+ G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
+
387
+ # shape: (batch_size * beam_size,)
388
+ G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
+
390
+ # shape: (batch_size * beam_size,)
391
+ phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
+
393
+ return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
+
395
+ def gumbel(self, phi) -> torch.Tensor:
396
+ """
397
+ Sample `Gumbel(phi)`.
398
+
399
+ `phi` should have shape `(batch_size, num_classes)`.
400
+ """
401
+ return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
+
403
+ def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
+ """
405
+ Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
+
407
+ `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
+ shape `(batch_size, 1)`.
409
+ """
410
+ # Shape: (batch_size, num_classes)
411
+ G_phi = self.gumbel(phi)
412
+
413
+ # Now we find the maximum from these samples.
414
+ # Shape: (batch_size, )
415
+ Z, _ = G_phi.max(dim=-1)
416
+
417
+ # Shape: (batch_size, num_classes)
418
+ v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
+
420
+ # Shape: (batch_size, num_classes)
421
+ return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
+
423
+
424
+ class FinalSequenceScorer:
425
+ """
426
+ An abstract class that can be used to score the final generated sequences found
427
+ by beam search. Given the predicted sequences and the corresponding log probabilities of
428
+ those sequences, the class calculates and returns the final score of the sequences.
429
+
430
+ The default implementation scores the sequences using the sum of the log probabilities of
431
+ the sequence, which is passed as input.
432
+ """
433
+
434
+ @abstractmethod
435
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
+ """
437
+ Score the final predictions found by beam search.
438
+ Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
+
440
+ :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
+ :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
+ of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
+ :param end_index: The index of the end symbol.
444
+
445
+ """
446
+ raise NotImplementedError
447
+
448
+
449
+ class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
+ """
451
+ A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
+ across the sequence's tokens.
453
+ """
454
+
455
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
+ del predictions, end_index
457
+ # The sum of the sequence log probabilities is the input parameter, so just
458
+ # return it.
459
+ return log_probabilities
460
+
461
+
462
+ class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
+ """
464
+ A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
+ tokens in the sequence. It optionally includes a length penalty which promotes
466
+ or demotes sequences based on their lengths. The final score for a sequence will
467
+ be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
+ here includes the end token.
469
+
470
+ :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
+ A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
+ """
473
+
474
+ def __init__(self, length_penalty: float = 1.0):
475
+ super().__init__()
476
+ self.length_penalty = length_penalty
477
+
478
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
+ # shape: (batch_size, beam_size)
480
+ lengths = (predictions != end_index).long().sum(dim=2)
481
+
482
+ # If the sequence ended during beam search, the `log_probabilities` will include
483
+ # the transition to the end token. Therefore, in such situations, `lengths` is
484
+ # actually off by 1. This corrects for that.
485
+ # shape: (batch_size, beam_size)
486
+ is_end_token = predictions[:, :, -1] == end_index
487
+ lengths += is_end_token.long()
488
+
489
+ # shape: (batch_size, beam_size)
490
+ average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
+ return average_log_probs
492
+
493
+
494
+ class Constraint:
495
+ """
496
+ An abstract class that can be used to enforce constraints on the output predictions
497
+ by manipulating the class log probabilities during beam search.
498
+
499
+ A `Constraint` just has three methods that need to be implemented by subclasses:
500
+ `init_state()`, `apply()` and `_update_state()`.
501
+
502
+ `init_state()` takes one argument:
503
+
504
+ - the batch size, an int
505
+
506
+ It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
+ calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
+ Each inner list should be of length 1.
509
+
510
+ `apply()` takes two arguments:
511
+
512
+ - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
+ and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
+ - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
+ log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
+
517
+ The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
+ for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
+ the corresponding log probability to a negligible value such as `float("-inf")` or
520
+ `torch.finfo(class_log_probabilities.dtype).min`.
521
+
522
+ `_update_state()` takes two arguments:
523
+
524
+ - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
+ copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
+ directly edited in-place without affecting the others.
527
+ - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
+ step of beam search.
529
+
530
+ The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
+ length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
+
533
+ """
534
+
535
+ @abstractmethod
536
+ def init_state(
537
+ self,
538
+ batch_size: int,
539
+ ) -> ConstraintStateType:
540
+ raise NotImplementedError
541
+
542
+ @abstractmethod
543
+ def apply(
544
+ self,
545
+ state: ConstraintStateType,
546
+ class_log_probabilities: torch.Tensor,
547
+ ) -> torch.Tensor:
548
+ raise NotImplementedError
549
+
550
+ @staticmethod
551
+ def _copy_state(
552
+ state: ConstraintStateType,
553
+ batch_size: int,
554
+ beam_size: int,
555
+ last_backpointer: Optional[torch.Tensor] = None,
556
+ ) -> ConstraintStateType:
557
+ """
558
+ Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
559
+ is not appropriate for your constraint, you will need to implement the copying yourself.
560
+ """
561
+ new_state = []
562
+ for i in range(batch_size):
563
+ batch_state = []
564
+ for j in range(beam_size):
565
+ if last_backpointer is None:
566
+ # This is the first prediction, so the backpointer is 0
567
+ backpointer = 0
568
+ else:
569
+ backpointer = last_backpointer[i, j].item()
570
+ batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
+ new_state.append(batch_state)
572
+ return new_state
573
+
574
+ def update_state(
575
+ self,
576
+ state: ConstraintStateType,
577
+ last_prediction: torch.Tensor,
578
+ last_backpointer: Optional[torch.Tensor] = None,
579
+ ) -> ConstraintStateType:
580
+ batch_size, beam_size = last_prediction.size()
581
+ new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
+ return self._update_state(new_state, last_prediction)
583
+
584
+ @abstractmethod
585
+ def _update_state(
586
+ self,
587
+ state: ConstraintStateType,
588
+ last_prediction: torch.Tensor,
589
+ ) -> ConstraintStateType:
590
+ raise NotImplementedError
591
+
592
+
593
+ class RepeatedNGramBlockingConstraint(Constraint):
594
+ def __init__(self, ngram_size: int, **kwargs) -> None:
595
+ super().__init__(**kwargs)
596
+ self.ngram_size = ngram_size
597
+
598
+ def init_state(
599
+ self,
600
+ batch_size: int,
601
+ ) -> ConstraintStateType:
602
+ return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
+
604
+ def apply(
605
+ self,
606
+ state: ConstraintStateType,
607
+ class_log_probabilities: torch.Tensor,
608
+ ) -> torch.Tensor:
609
+ for i, batch in enumerate(state):
610
+ for j, beam in enumerate(batch):
611
+ current_prefix = tuple(beam["current_prefix"])
612
+ seen_ngrams = beam["seen_ngrams"]
613
+ try:
614
+ disallowed_indices = seen_ngrams[current_prefix]
615
+ class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
+ class_log_probabilities.dtype
617
+ ).min
618
+ except KeyError:
619
+ # We have not seen this prefix before, so there is no index
620
+ # that needs to be blocked
621
+ pass
622
+ return class_log_probabilities
623
+
624
+ def _update_state(
625
+ self,
626
+ state: ConstraintStateType,
627
+ last_prediction: torch.Tensor,
628
+ ) -> ConstraintStateType:
629
+ for i, batch in enumerate(state):
630
+ for j, beam in enumerate(batch):
631
+ prediction = last_prediction[i, j].item()
632
+ prefix = beam["current_prefix"]
633
+ seen_ngrams = beam["seen_ngrams"]
634
+
635
+ if len(prefix) == self.ngram_size - 1:
636
+ # This is a new ngram that we have to remember
637
+ if tuple(prefix) not in seen_ngrams:
638
+ seen_ngrams[tuple(prefix)] = []
639
+ seen_ngrams[tuple(prefix)].append(prediction)
640
+
641
+ # Create the new prefix, removing the oldest index if the prefix
642
+ # is too long
643
+ prefix.append(prediction)
644
+ if len(prefix) == self.ngram_size:
645
+ prefix.pop(0)
646
+ return state
647
+
648
+
649
+ class BeamSearch:
650
+ """
651
+ Implements the beam search algorithm for decoding the most likely sequences.
652
+
653
+ :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
+
655
+ :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
+ of the predicted sequences.
657
+
658
+ :param beam_size: The width of the beam used.
659
+
660
+ :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
+ If not given, this just defaults to `beam_size`. Setting this parameter
662
+ to a number smaller than `beam_size` may give better results, as it can introduce
663
+ more diversity into the search. See
664
+ [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
+ (https://api.semanticscholar.org/CorpusID:2229477).
666
+
667
+ :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
+ If not specified, `DeterministicSampler` will be used, which just takes the
669
+ `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
+
671
+ Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
+ [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
+
674
+ :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
+ the predicted sequences. This does not include the start or end tokens. If `None`,
676
+ no minimum is enforced.
677
+
678
+ :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
+ The output from this module is what is returned by the `search` method. If not
680
+ specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
+ by the sum of the token log probabilities.
682
+
683
+ :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
+ provided, no constraints will be enforced.
685
+
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ end_index: int,
691
+ *,
692
+ max_steps: int = 50,
693
+ beam_size: int = 10,
694
+ per_node_beam_size: Optional[int] = None,
695
+ sampler: Optional[Sampler] = None,
696
+ min_steps: Optional[int] = None,
697
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
+ constraints: Optional[List[Constraint]] = None,
699
+ ) -> None:
700
+ if not max_steps > 0:
701
+ raise ValueError("max_steps must be positive")
702
+ if not beam_size > 0:
703
+ raise ValueError("beam_size must be positive")
704
+ if per_node_beam_size is not None and not per_node_beam_size > 0:
705
+ raise ValueError("per_node_beam_size must be positive")
706
+ if min_steps is not None:
707
+ if not min_steps >= 0:
708
+ raise ValueError("min_steps must be non-negative")
709
+ if not min_steps <= max_steps:
710
+ raise ValueError("min_steps must be less than or equal to max_steps")
711
+
712
+ self._end_index = end_index
713
+ self.max_steps = max_steps
714
+ self.beam_size = beam_size
715
+ self.per_node_beam_size = per_node_beam_size or beam_size
716
+ self.sampler = sampler or DeterministicSampler()
717
+ self.min_steps = min_steps or 0
718
+ self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
+ self.constraints = constraints or []
720
+
721
+ @staticmethod
722
+ def _reconstruct_sequences(predictions, backpointers):
723
+ # Reconstruct the sequences.
724
+ # shape: [(batch_size, beam_size, 1)]
725
+ reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
+
727
+ if not backpointers:
728
+ return reconstructed_predictions
729
+
730
+ # shape: (batch_size, beam_size)
731
+ cur_backpointers = backpointers[-1]
732
+
733
+ for timestep in range(len(predictions) - 2, 0, -1):
734
+ # shape: (batch_size, beam_size, 1)
735
+ cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
+
737
+ reconstructed_predictions.append(cur_preds)
738
+
739
+ # shape: (batch_size, beam_size)
740
+ cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
+
742
+ # shape: (batch_size, beam_size, 1)
743
+ final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
+
745
+ reconstructed_predictions.append(final_preds)
746
+
747
+ return reconstructed_predictions
748
+
749
+ def search(
750
+ self,
751
+ start_predictions: torch.Tensor,
752
+ start_state: StateType,
753
+ step: StepFunctionType,
754
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
755
+ """
756
+ Given a starting state and a step function, apply beam search to find the
757
+ most likely target sequences.
758
+
759
+ Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
+ has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
+ has shape `(batch_size, beam_size)`.
762
+
763
+ .. note::
764
+ If your step function returns `-inf` for some log probabilities
765
+ (like if you're using a masked log-softmax) then some of the "best"
766
+ sequences returned may also have `-inf` log probability. Specifically
767
+ this happens when the beam size is larger than the number of actions
768
+ with finite log probability (non-zero probability) returned by the step function.
769
+ Therefore if you're using a mask you may want to check the results from `search`
770
+ and potentially discard sequences with non-finite log probability.
771
+
772
+ :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
+ Usually the initial predictions are just the index of the "start" token
774
+ in the target vocabulary.
775
+
776
+ :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
+ should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
+ number of dimensions.
779
+
780
+ :param step: A function that is responsible for computing the next most likely tokens,
781
+ given the current state and the predictions from the last time step.
782
+ The function should accept two or three arguments:
783
+
784
+ - a tensor of shape `(group_size,)` representing the index of the predicted
785
+ tokens from the last time step,
786
+ - the current state, a `StateType`, and
787
+ - optionally, the timestep, an `int`.
788
+
789
+ The `group_size` will be `batch_size * beam_size`, except in the initial
790
+ step, for which it will just be `batch_size`.
791
+
792
+ The function is expected to return a tuple, where the first element
793
+ is a tensor of shape `(group_size, vocab_size)` containing
794
+ the log probabilities of the tokens for the next step, and the second
795
+ element is the updated state. The tensor in the state should have shape
796
+ `(group_size, *)`, where `*` means any other number of dimensions.
797
+
798
+ """
799
+ step_signature = signature(step)
800
+ if len(step_signature.parameters) < 3:
801
+ # If the step function we're given does not take the time step argument, wrap it
802
+ # in one that does.
803
+ old_step = cast(StepFunctionTypeNoTimestep, step)
804
+
805
+ def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
+ del time_step
807
+ return old_step(last_predictions, state)
808
+
809
+ return self._search(start_predictions, start_state, new_step)
810
+ else:
811
+ return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
+
813
+ def _search(
814
+ self,
815
+ start_predictions: torch.Tensor,
816
+ start_state: StateType,
817
+ step: StepFunctionTypeWithTimestep,
818
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
819
+ batch_size = start_predictions.size()[0]
820
+
821
+ # List of (batch_size, beam_size) tensors. One for each time step. Does not
822
+ # include the start symbols, which are implicit.
823
+ predictions: List[torch.Tensor] = []
824
+
825
+ # List of (batch_size, beam_size) tensors. One for each time step. None for
826
+ # the first. Stores the index n for the parent prediction, i.e.
827
+ # predictions[t-1][i][n], that it came from.
828
+ backpointers: List[torch.Tensor] = []
829
+
830
+ constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
831
+
832
+ # Calculate the first timestep. This is done outside the main loop
833
+ # because we are going from a single decoder input (the output from the
834
+ # encoder) to the top `beam_size` decoder outputs. On the other hand,
835
+ # within the main loop we are going from the `beam_size` elements of the
836
+ # beam to `beam_size`^2 candidates from which we will select the top
837
+ # `beam_size` elements for the next iteration.
838
+ # shape: (batch_size, num_classes)
839
+ start_class_log_probabilities, state = step(start_predictions, start_state, 0)
840
+
841
+ num_classes = start_class_log_probabilities.size()[1]
842
+
843
+ # Make sure `per_node_beam_size` is not larger than `num_classes`.
844
+ if self.per_node_beam_size > num_classes:
845
+ raise ValueError(
846
+ f"Vocab size ({num_classes:d}) too small "
847
+ f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
848
+ f"Please decrease beam_size or per_node_beam_size."
849
+ )
850
+
851
+ sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
852
+
853
+ # Apply all constraints.
854
+ if self.constraints:
855
+ # shape: (batch_size, 1, num_classes)
856
+ expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
857
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
858
+ expanded_start_class_log_probabilities = constraint.apply(
859
+ constraint_state, expanded_start_class_log_probabilities
860
+ )
861
+ start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
862
+
863
+ # Prevent selecting the end symbol if there is any min_steps constraint
864
+ if self.min_steps >= 1:
865
+ start_class_log_probabilities[:, self._end_index] = torch.finfo(
866
+ start_class_log_probabilities.dtype
867
+ ).min
868
+
869
+ # Get the initial predicted classes and their log probabilities.
870
+ # shape: (batch_size, beam_size), (batch_size, beam_size)
871
+ (
872
+ start_top_log_probabilities,
873
+ start_predicted_classes,
874
+ sampler_state,
875
+ ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
876
+
877
+ if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
878
+ warnings.warn(
879
+ "Empty sequences predicted. You may want to increase the beam size or ensure "
880
+ "your step function is working properly.",
881
+ RuntimeWarning,
882
+ )
883
+ return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
884
+
885
+ # The log probabilities for the last time step.
886
+ # shape: (batch_size, beam_size)
887
+ last_log_probabilities = start_top_log_probabilities
888
+
889
+ # shape: [(batch_size, beam_size)]
890
+ predictions.append(start_predicted_classes)
891
+
892
+ # Log probability tensor that mandates that the end token is selected.
893
+ # shape: (batch_size * beam_size, num_classes)
894
+ log_probs_after_end = start_class_log_probabilities.new_full(
895
+ (batch_size * self.beam_size, num_classes),
896
+ torch.finfo(start_class_log_probabilities.dtype).min,
897
+ )
898
+ log_probs_after_end[:, self._end_index] = 0.0
899
+
900
+ # Set the same state for each element in the beam.
901
+ self._update_initial_state(state, batch_size)
902
+
903
+ for i, constraint in enumerate(self.constraints):
904
+ constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
905
+
906
+ for timestep in range(self.max_steps - 1):
907
+ # shape: (batch_size * beam_size,)
908
+ last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
909
+
910
+ # If every predicted token from the last step is `self._end_index`,
911
+ # then we can stop early.
912
+ if (last_predictions == self._end_index).all():
913
+ break
914
+ # Take a step. This gets the predicted log probs of the next classes
915
+ # and updates the state.
916
+ # shape: (batch_size * beam_size, num_classes)
917
+ class_log_probabilities, state = step(last_predictions, state, timestep + 1)
918
+
919
+ # Apply all constraints.
920
+ if self.constraints:
921
+ # shape: (batch_size, beam_size, num_classes)
922
+ reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
923
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
924
+ reshaped_class_log_probabilities = constraint.apply(
925
+ constraint_state, reshaped_class_log_probabilities
926
+ )
927
+ # shape: (batch_size * beam_size, num_classes)
928
+ class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
929
+
930
+ # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
931
+ # of the sequence (because `timestep` is 0-indexed and we generated the first token
932
+ # before the for loop). Here we block the end index if the search is not allowed to
933
+ # terminate on this iteration.
934
+ if timestep + 2 <= self.min_steps:
935
+ class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
936
+
937
+ # shape: (batch_size * beam_size, num_classes)
938
+ last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
939
+ batch_size * self.beam_size, num_classes
940
+ )
941
+
942
+ # Here we are finding any beams where we predicted the end token in
943
+ # the previous timestep and replacing the distribution with a
944
+ # one-hot distribution, forcing the beam to predict the end token
945
+ # this timestep as well.
946
+ # shape: (batch_size * beam_size, num_classes)
947
+ cleaned_log_probabilities = torch.where(
948
+ last_predictions_expanded == self._end_index,
949
+ log_probs_after_end,
950
+ class_log_probabilities,
951
+ )
952
+
953
+ # shape (both): (batch_size * beam_size, per_node_beam_size)
954
+ top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
955
+ cleaned_log_probabilities, self.per_node_beam_size, sampler_state
956
+ )
957
+
958
+ # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
959
+ # so that we can add them to the current log probs for this timestep.
960
+ # This lets us maintain the log probability of each element on the beam.
961
+ # shape: (batch_size * beam_size, per_node_beam_size)
962
+ expanded_last_log_probabilities = (
963
+ last_log_probabilities.unsqueeze(2)
964
+ .expand(batch_size, self.beam_size, self.per_node_beam_size)
965
+ .reshape(batch_size * self.beam_size, self.per_node_beam_size)
966
+ )
967
+
968
+ # shape: (batch_size * beam_size, per_node_beam_size)
969
+ summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
970
+
971
+ # shape: (batch_size, beam_size * per_node_beam_size)
972
+ reshaped_summed = summed_top_log_probabilities.reshape(
973
+ batch_size, self.beam_size * self.per_node_beam_size
974
+ )
975
+
976
+ # shape: (batch_size, beam_size * per_node_beam_size)
977
+ reshaped_predicted_classes = predicted_classes.reshape(
978
+ batch_size, self.beam_size * self.per_node_beam_size
979
+ )
980
+
981
+ # Keep only the top `beam_size` beam indices.
982
+ # shape (both): (batch_size, beam_size)
983
+ (
984
+ restricted_beam_log_probs,
985
+ restricted_beam_indices,
986
+ sampler_state,
987
+ ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
988
+
989
+ # Use the beam indices to extract the corresponding classes.
990
+ # shape: (batch_size, beam_size)
991
+ restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
992
+
993
+ predictions.append(restricted_predicted_classes)
994
+
995
+ # shape: (batch_size, beam_size)
996
+ last_log_probabilities = restricted_beam_log_probs
997
+
998
+ # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
999
+ # indices with a common ancestor are grouped together. Hence
1000
+ # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1001
+ # division as the tensor is a LongTensor.)
1002
+ # shape: (batch_size, beam_size)
1003
+ backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1004
+ backpointers.append(backpointer)
1005
+
1006
+ # Keep only the pieces of the state tensors corresponding to the
1007
+ # ancestors created this iteration.
1008
+ self._update_state(state, backpointer)
1009
+
1010
+ for i, constraint in enumerate(self.constraints):
1011
+ constraint_states[i] = constraint.update_state(
1012
+ constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1013
+ )
1014
+
1015
+ # Warn about "-inf" log probabilities if not using any constraints (negligible
1016
+ # log probabilities are expected when using constraints).
1017
+ if not self.constraints and (
1018
+ not torch.isfinite(last_log_probabilities).all()
1019
+ or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1020
+ ):
1021
+ warnings.warn(
1022
+ "Negligible log probabilities encountered ('-inf' or equivalent). "
1023
+ "Some final sequences may not make sense. "
1024
+ "This can happen when the beam size is larger than the number of valid (non-zero "
1025
+ "probability) transitions that the step function produces.",
1026
+ RuntimeWarning,
1027
+ )
1028
+
1029
+ reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1030
+
1031
+ # shape: (batch_size, beam_size, max_steps)
1032
+ all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1033
+
1034
+ # Calculate the final sequence scores
1035
+ # shape: (batch_size, beam_size)
1036
+ final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1037
+
1038
+ # Sort the sequences based on the final scores so the best scoring
1039
+ # sequence is at index 0
1040
+ sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1041
+ sorted_all_predictions = torch.gather(
1042
+ all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1043
+ )
1044
+
1045
+ return sorted_all_predictions, sorted_final_scores
1046
+
1047
+ def _update_initial_state(self, state: StateType, batch_size: int):
1048
+ """
1049
+ Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1050
+ """
1051
+ for key, state_tensor in state.items():
1052
+ if state_tensor is None:
1053
+ continue
1054
+ # shape: (batch_size * beam_size, *)
1055
+ _, *last_dims = state_tensor.size()
1056
+ state[key] = (
1057
+ state_tensor.unsqueeze(1)
1058
+ .expand(batch_size, self.beam_size, *last_dims)
1059
+ .reshape(batch_size * self.beam_size, *last_dims)
1060
+ )
1061
+
1062
+ def _update_state(self, state: StateType, backpointer: torch.Tensor):
1063
+ batch_size = backpointer.size()[0]
1064
+
1065
+ for key, state_tensor in state.items():
1066
+ if state_tensor is None:
1067
+ continue
1068
+ _, *last_dims = state_tensor.size()
1069
+ # shape: (batch_size, beam_size, *)
1070
+ expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1071
+ batch_size, self.beam_size, *last_dims
1072
+ )
1073
+ # shape: (batch_size * beam_size, *)
1074
+ state[key] = (
1075
+ state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1076
+ .gather(1, expanded_backpointer)
1077
+ .reshape(batch_size * self.beam_size, *last_dims)
1078
+ )
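
For orientation, here is a minimal usage sketch of the `BeamSearch` class added above. It is a hypothetical example, not part of the uploaded files: the toy `step` function, vocabulary size, and token IDs are assumptions made for illustration, and the import assumes `beam_search.py` is importable as a module.

import torch

from beam_search import BeamSearch  # assumption: beam_search.py is on the Python path

VOCAB_SIZE = 8          # hypothetical vocabulary size
BOS_ID, EOS_ID = 0, 1   # hypothetical start / end token ids

def step(last_predictions: torch.Tensor, state: dict, timestep: int):
    # A real step function would run the decoder on `last_predictions` and `state`
    # (e.g. carrying a KV cache); here it just returns uniform log probabilities.
    group_size = last_predictions.size(0)
    log_probs = torch.log_softmax(torch.zeros(group_size, VOCAB_SIZE), dim=-1)
    return log_probs, state

beam_search = BeamSearch(end_index=EOS_ID, max_steps=10, beam_size=4)
start_predictions = torch.full((2,), BOS_ID, dtype=torch.long)  # batch_size = 2
predictions, scores = beam_search.search(start_predictions, {}, step)
# predictions: (batch_size, beam_size, num_generated_steps), best beam at index 0
# scores:      (batch_size, beam_size), sorted in descending order

In practice the step function wraps the model's forward pass and keeps decoder state (such as a key/value cache) in the `state` dictionary, as described in the `search` docstring.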
checkpoint.py ADDED
@@ -0,0 +1,2023 @@
1
+ import gc
2
+ import io
3
+ import logging
4
+ import pickle
5
+ import shutil
6
+ import traceback
7
+ from abc import ABCMeta, abstractmethod
8
+ from collections import defaultdict
9
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
+ from contextlib import contextmanager
11
+ from copy import deepcopy
12
+ from dataclasses import dataclass, field, replace
13
+ from functools import reduce
14
+ from multiprocessing import shared_memory
15
+ from pathlib import Path
16
+ from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed.checkpoint as dist_cp
21
+ import torch.multiprocessing as mp
22
+ import torch.nn as nn
23
+ from packaging import version
24
+ from torch.distributed import _remote_device
25
+ from torch.distributed._shard._utils import narrow_tensor_by_index
26
+ from torch.distributed._shard.metadata import ShardMetadata
27
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
28
+ from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
29
+ from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
30
+ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
31
+ from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
32
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
33
+ from torch.distributed.fsdp import StateDictType
34
+ from torch.distributed.fsdp.api import (
35
+ FullOptimStateDictConfig,
36
+ FullStateDictConfig,
37
+ ShardedOptimStateDictConfig,
38
+ ShardedStateDictConfig,
39
+ )
40
+ from torch.futures import Future
41
+ from torch.nn.parallel import DistributedDataParallel as DDP
42
+
43
+ try:
44
+ from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
45
+ except ModuleNotFoundError:
46
+ from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
47
+
48
+ from olmo import util
49
+
50
+ from .aliases import PathOrStr
51
+ from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
52
+ from .exceptions import OLMoCheckpointError
53
+ from .optim import Optimizer, fix_optim_state_dict
54
+ from .safetensors_util import safetensors_file_to_state_dict
55
+ from .torch_util import (
56
+ barrier,
57
+ gc_cuda,
58
+ get_fs_local_rank,
59
+ get_global_rank,
60
+ get_local_rank,
61
+ get_local_world_size,
62
+ get_world_size,
63
+ )
64
+ from .util import (
65
+ _get_s3_client,
66
+ default_thread_count,
67
+ dir_is_empty,
68
+ get_bytes_range,
69
+ get_progress_bar,
70
+ resource_path,
71
+ upload,
72
+ wait_for,
73
+ )
74
+
75
+ __all__ = [
76
+ "save_fsdp_model_and_optim_state",
77
+ "load_fsdp_model_and_optim_state",
78
+ "load_fsdp_optim_state",
79
+ "save_state_dict",
80
+ "load_state_dict",
81
+ "load_model_state",
82
+ "RemoteFileSystemWriter",
83
+ "RemoteFileSystemReader",
84
+ "Checkpointer",
85
+ "FullCheckpointer",
86
+ "TorchNewStyleShardedCheckpointer",
87
+ "TorchLegacyShardedCheckpointer",
88
+ "LocalShardedCheckpointer",
89
+ "build_sharded_checkpointer",
90
+ ]
91
+
92
+
93
+ log = logging.getLogger(__name__)
94
+
95
+ MODEL_AND_OPTIM_FOLDER = "model_and_optim"
96
+
97
+
98
+ def save_fsdp_model_and_optim_state(
99
+ checkpoint_dir: PathOrStr,
100
+ fsdp_model: FSDP,
101
+ optim: Optimizer,
102
+ *,
103
+ upload_to: Optional[str] = None,
104
+ save_overwrite: bool = False,
105
+ ):
106
+ """
107
+ Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
108
+ functions. This should be used during distributed training and should be called by all ranks.
109
+
110
+ :param checkpoint_dir: The directory to save to.
111
+ :param fsdp_model: The FSDP model.
112
+ :param optim: The FSDP model's optimizer.
113
+ :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
114
+ :param save_overwrite: Overwrite existing files.
115
+
116
+ :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
117
+ """
118
+ checkpoint_dir = Path(checkpoint_dir)
119
+ target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
120
+ if save_overwrite:
121
+ if get_fs_local_rank() == 0:
122
+ shutil.rmtree(target_dir, ignore_errors=True)
123
+ elif not dir_is_empty(target_dir):
124
+ raise FileExistsError(target_dir)
125
+ barrier()
126
+ if get_fs_local_rank() == 0:
127
+ target_dir.mkdir(exist_ok=True, parents=True)
128
+ barrier()
129
+ with FSDP.state_dict_type(
130
+ fsdp_model,
131
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
132
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
133
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
134
+ ):
135
+ model_and_optim_state = {
136
+ "model": fsdp_model.state_dict(),
137
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
138
+ }
139
+ dist_cp.save_state_dict(
140
+ model_and_optim_state,
141
+ RemoteFileSystemWriter(
142
+ target_dir,
143
+ upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
144
+ save_overwrite=save_overwrite,
145
+ ),
146
+ )
147
+
148
+
149
+ def load_fsdp_model_and_optim_state(
150
+ checkpoint_dir: PathOrStr,
151
+ fsdp_model: FSDP,
152
+ optim: Optimizer,
153
+ *,
154
+ local_cache: Optional[PathOrStr] = None,
155
+ load_optimizer_state: bool = True,
156
+ ):
157
+ """
158
+ Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
159
+ functions. This should be used during distributed training and should be called by all ranks.
160
+
161
+ :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
162
+ :param fsdp_model: The FSDP model.
163
+ :param optim: The FSDP model's optimizer.
164
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
165
+ remote "directory" but there might be a cached version of the same artifacts.
166
+ :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
167
+
168
+ :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
169
+ """
170
+ load_path = str(checkpoint_dir).rstrip("/")
171
+ local_cache = None if local_cache is None else Path(local_cache)
172
+ with FSDP.state_dict_type(
173
+ fsdp_model,
174
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
175
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
176
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
177
+ ):
178
+ # Load the model state dict in place.
179
+ log.info("Loading model state...")
180
+ model_state = {"model": fsdp_model.state_dict()}
181
+ dist_cp.load_state_dict(
182
+ model_state,
183
+ RemoteFileSystemReader(
184
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
185
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
186
+ ),
187
+ )
188
+ fsdp_model.load_state_dict(model_state["model"])
189
+
190
+ if not load_optimizer_state:
191
+ return
192
+
193
+ # Load optim state dict in place.
194
+ log.info("Loading sharded optimizer state...")
195
+ optim_state = load_sharded_optimizer_state_dict(
196
+ model_state_dict=model_state["model"],
197
+ optimizer_key="optim",
198
+ storage_reader=RemoteFileSystemReader(
199
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
200
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
201
+ ),
202
+ )
203
+ # optim_state["optim"] = {
204
+ # 'state': { fqn: { 'grad_norm_exp_avg': Tensor, 'step': Tensor, 'exp_avg': ShardedTensor, 'exp_avg_sq': ShardedTensor } },
205
+ # 'param_groups': [{ 'param_names': [ fsdp_fqn, ... ], 'params': [ fqn, ... ], ... }],
206
+ # }
207
+ del model_state
208
+
209
+ # Make sure tensors are on CPU! PyTorch puts them on GPU even though we have `offload_to_cpu=True`.
210
+ for state in optim_state["optim"]["state"].values():
211
+ for k in state.keys():
212
+ state[k] = state[k].cpu()
213
+ gc_cuda()
214
+
215
+ load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
216
+
217
+
218
+ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
219
+ log.info("Flattening sharded optimizer state...")
220
+ # flattened_osd = {
221
+ # 'state': { id: { 'grad_norm_exp_avg': Tensor, 'step': Tensor, 'exp_avg': Tensor, 'exp_avg_sq': Tensor } },
222
+ # 'param_groups': [{ 'param_names': [ fsdp_fqn, ... ], 'params': [ id, ... ], ... }],
223
+ # }
224
+ # NOTE: Careful! The order of these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
225
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
226
+ flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
227
+ else:
228
+ flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
229
+
230
+ del optim_state
231
+ gc_cuda()
232
+
233
+ log.info("Loading flattened optimizer state...")
234
+
235
+ # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
236
+ # which takes up unnecessary GPU memory.
237
+ for state in flattened_osd["state"].values():
238
+ for k in state.keys():
239
+ state[k] = state[k].cpu()
240
+ gc_cuda()
241
+
242
+ optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
243
+
244
+
245
+ def save_state_dict(
246
+ checkpoint_dir: PathOrStr,
247
+ fname: str,
248
+ state_dict: Dict[str, Any],
249
+ *,
250
+ upload_to: Optional[str] = None,
251
+ save_overwrite: bool = False,
252
+ synchronize: bool = True,
253
+ ):
254
+ """
255
+ Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
256
+ This can be used during distributed training or not. If used during distributed training, the ``fname`` should be unique
257
+ for each rank.
258
+
259
+ :param checkpoint_dir: The directory to save to.
260
+ :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
261
+ :param state_dict: The state dict to save.
262
+ :param upload_to: Optional, a remote "directory" to upload the file to.
263
+ :param save_overwrite: Overwrite existing files.
264
+ :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
265
+ this function from a single rank.
266
+
267
+ :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
268
+ """
269
+ checkpoint_dir = Path(checkpoint_dir)
270
+ target_path = checkpoint_dir / fname
271
+ if save_overwrite:
272
+ target_path.unlink(missing_ok=True)
273
+ elif target_path.is_file():
274
+ raise FileExistsError(target_path)
275
+ if synchronize:
276
+ barrier()
277
+ target_path.parent.mkdir(exist_ok=True, parents=True)
278
+ if synchronize:
279
+ barrier()
280
+ torch.save(state_dict, target_path)
281
+ if upload_to is not None:
282
+ upload_target = f"{upload_to.rstrip('/')}/{fname}"
283
+ log.info(f"Uploading {target_path} to {upload_target}...")
284
+ upload(target_path, upload_target, save_overwrite=save_overwrite)
285
+
286
+
287
+ def load_state_dict(
288
+ checkpoint_dir: PathOrStr,
289
+ fname: str,
290
+ *,
291
+ local_cache: Optional[PathOrStr] = None,
292
+ map_location: Optional[str] = None,
293
+ ):
294
+ """
295
+ Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
296
+ This can be used during distributed training or not.
297
+
298
+ :param checkpoint_dir: A local or remote checkpoint directory.
299
+ :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
300
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
301
+ remote "directory" but there might be a cached version of the same artifacts.
302
+
303
+ :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
304
+ """
305
+ if fname.endswith(".pt"):
306
+ # Try safetensors version first.
307
+ try:
308
+ path = resource_path(
309
+ str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
310
+ )
311
+ return safetensors_file_to_state_dict(path, map_location=map_location)
312
+ except FileNotFoundError:
313
+ pass
314
+
315
+ path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
316
+ return torch.load(path, map_location=map_location)
317
+
318
+
319
+ def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
320
+ """
321
+ Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
322
+ Note that ``model`` should not be wrapped with FSDP.
323
+ """
324
+ state_dict = {"model": model.state_dict()}
325
+ dist_cp.load_state_dict(
326
+ state_dict,
327
+ RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
328
+ no_dist=True,
329
+ )
330
+ model.load_state_dict(state_dict["model"])
331
+
332
+
333
+ class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
334
+ """
335
+ A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
336
+ directly to a cloud bucket when ``upload_to`` is specified.
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ path: PathOrStr,
342
+ single_file_per_rank: bool = True,
343
+ sync_files: bool = True,
344
+ thread_count: Optional[int] = None,
345
+ per_thread_copy_ahead: int = 10_000_000,
346
+ upload_to: Optional[str] = None,
347
+ save_overwrite: bool = False,
348
+ ) -> None:
349
+ if thread_count is not None and thread_count <= 0:
350
+ raise ValueError("thread count must be at least 1")
351
+ super().__init__(
352
+ path,
353
+ single_file_per_rank=single_file_per_rank,
354
+ sync_files=sync_files,
355
+ # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
356
+ # returns because uploading big checkpoint files with multiple threads causes
357
+ # boto3 to fail in weird ways.
358
+ thread_count=thread_count or 1,
359
+ per_thread_copy_ahead=per_thread_copy_ahead,
360
+ )
361
+ self.upload_to = None if upload_to is None else upload_to.rstrip("/")
362
+ self.save_overwrite = save_overwrite
363
+
364
+ def write_data(
365
+ self,
366
+ plan: dist_cp.SavePlan,
367
+ planner: dist_cp.SavePlanner,
368
+ ) -> Future[List[WriteResult]]:
369
+ fut = super().write_data(plan, planner)
370
+ if self.upload_to is not None:
371
+ files_to_upload = set()
372
+ for write_result in fut.wait():
373
+ files_to_upload.add(write_result.storage_data.relative_path)
374
+
375
+ # Create the global S3 client up front to work around a threading issue in boto.
376
+ if self.upload_to.startswith("s3://"):
377
+ _get_s3_client("s3")
378
+ elif self.upload_to.startswith("r2://"):
379
+ _get_s3_client("r2")
380
+ elif self.upload_to.startswith("weka://"):
381
+ _get_s3_client("weka")
382
+
383
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
384
+ futures = []
385
+ for fname in files_to_upload:
386
+ source = self.path / fname
387
+ target = f"{self.upload_to}/{fname}"
388
+ log.info(f"Uploading {source} to {target}...")
389
+ futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
390
+ for f in as_completed(futures):
391
+ try:
392
+ f.result()
393
+ except BaseException:
394
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
395
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
396
+ # sure we're raising a simple error type that can be pickled.
397
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
398
+ return fut
399
+
400
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
401
+ super().finish(metadata, results)
402
+ if self.upload_to is not None:
403
+ source = self.path / ".metadata"
404
+ target = f"{self.upload_to}/.metadata"
405
+ log.info(f"Uploading {source} to {target}...")
406
+ upload(source, target, save_overwrite=self.save_overwrite)
407
+
408
+
409
+ class RemoteFileSystemReader(dist_cp.StorageReader):
410
+ """
411
+ A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
412
+ that can read data directly from cloud storage as well as a local directory.
413
+ """
414
+
415
+ def __init__(
416
+ self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
417
+ ):
418
+ super().__init__()
419
+ if thread_count is not None and thread_count <= 0:
420
+ raise ValueError("thread count must be at least 1")
421
+ self.path = str(path).rstrip("/")
422
+ self.cache = None if local_cache is None else Path(local_cache)
423
+ self.thread_count = thread_count or default_thread_count()
424
+ self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
425
+ self._metadata: Optional[Metadata] = None
426
+
427
+ def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
428
+ if self.cache is not None and (path := self.cache / relative_path).is_file():
429
+ return get_bytes_range(path, offset, length)
430
+ else:
431
+ return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
432
+
433
+ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
434
+ sinfo = self.storage_data[read_item.storage_index]
435
+ content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
436
+ return (read_item, content)
437
+
438
+ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
439
+ # Create the global S3 client up front to work around a threading issue in boto.
440
+ if isinstance(self.path, str):
441
+ if self.path.startswith("s3://"):
442
+ _get_s3_client("s3")
443
+ elif self.path.startswith("r2://"):
444
+ _get_s3_client("r2")
445
+ elif self.path.startswith("weka://"):
446
+ _get_s3_client("weka")
447
+
448
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
449
+ read_item_content_futures = []
450
+ for read_item in plan.items:
451
+ read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
452
+ read_item_content_results = []
453
+ for f in as_completed(read_item_content_futures):
454
+ try:
455
+ read_item_content_results.append(f.result())
456
+ except BaseException:
457
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
458
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
459
+ # sure we're raising a simple error type that can be pickled.
460
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
461
+
462
+ # Modified from `FileSystemReader.read_data()`
463
+ for read_item, content in read_item_content_results:
464
+ bytes = io.BytesIO(content)
465
+ bytes.seek(0)
466
+ if read_item.type == LoadItemType.BYTE_IO:
467
+ planner.load_bytes(read_item, bytes)
468
+ else:
469
+ tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
470
+ tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
471
+ target_tensor = planner.resolve_tensor(read_item).detach()
472
+
473
+ assert (
474
+ target_tensor.size() == tensor.size()
475
+ ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
476
+ target_tensor.copy_(tensor)
477
+ planner.commit_tensor(read_item, target_tensor)
478
+
479
+ fut: Future = Future()
480
+ fut.set_result(None)
481
+ return fut
482
+
483
+ def read_metadata(self) -> Metadata:
484
+ if self._metadata is None:
485
+ with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
486
+ self._metadata = pickle.load(metadata_file)
487
+ return self._metadata
488
+
489
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
490
+ del is_coordinator
491
+ self.storage_data = metadata.storage_data
492
+ assert self.storage_data is not None
493
+
494
+ def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
495
+ return plan
496
+
497
+ def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
498
+ return global_plan
499
+
500
+
501
+ class Checkpointer(metaclass=ABCMeta):
502
+ def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
503
+ self.cfg = cfg
504
+ self.thread_count = thread_count or default_thread_count()
505
+
506
+ @abstractmethod
507
+ def save_checkpoint(
508
+ self,
509
+ dir: PathOrStr,
510
+ dist_model: nn.Module,
511
+ optim: Optimizer,
512
+ train_state: Dict[str, Any],
513
+ *,
514
+ upload_to: Optional[str] = None,
515
+ ) -> None:
516
+ raise NotImplementedError
517
+
518
+ @abstractmethod
519
+ def restore_checkpoint(
520
+ self,
521
+ load_path: PathOrStr,
522
+ dist_model: nn.Module,
523
+ optim: Optimizer,
524
+ *,
525
+ local_cache: Optional[PathOrStr] = None,
526
+ load_optimizer_state: bool = True,
527
+ ) -> Dict[str, Any]:
528
+ """
529
+ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
530
+ """
531
+ raise NotImplementedError
532
+
533
+ def unshard_checkpoint(
534
+ self,
535
+ load_path: PathOrStr,
536
+ *,
537
+ local_cache: Optional[PathOrStr] = None,
538
+ load_optimizer_state: bool = True,
539
+ load_trainer_state: bool = True,
540
+ device: Optional[torch.device] = None,
541
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
542
+ """
543
+ Unshard a checkpoint.
544
+
545
+ Note this is not marked abstract because child classes are not required to implement this.
546
+ """
547
+ raise NotImplementedError
548
+
549
+ @contextmanager
550
+ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
551
+ # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
552
+ checkpoint_dir = Path(dir)
553
+ if not dir_is_empty(checkpoint_dir):
554
+ if self.cfg.save_overwrite:
555
+ if get_fs_local_rank() == 0:
556
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
557
+ else:
558
+ raise FileExistsError(checkpoint_dir)
559
+ # No need to mkdir here since we'll directly replace the temporary directory with
560
+ # this directory below.
561
+ barrier()
562
+
563
+ # Prepare temporary directory. We don't have to be as careful here, we can
564
+ # just remove it if it already exists.
565
+ checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
566
+ if get_fs_local_rank() == 0:
567
+ shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
568
+ checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
569
+
570
+ # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
571
+ # creating the temp directory from rank 0 might not be immediately
572
+ # realized in the file systems of the other ranks.
573
+ # So we wait here across all ranks until that tmp checkpoint directory is visible.
574
+ wait_for(lambda: checkpoint_dir_tmp.exists(), "Waiting for checkpoint directory", timeout=10.0)
575
+
576
+ barrier()
577
+
578
+ # Yield temporary directory for `.save_checkpoint()` to use.
579
+ yield checkpoint_dir_tmp
580
+
581
+ barrier()
582
+
583
+ # Finally if all went well replace the temporary directory with the actual
584
+ # checkpoint directory.
585
+ if get_fs_local_rank() == 0:
586
+ # Replace temp directory with target checkpoint directory.
587
+ try:
588
+ checkpoint_dir_tmp.replace(checkpoint_dir)
589
+ except FileNotFoundError:
590
+ # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
591
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
592
+ # file-systems.
593
+ if not checkpoint_dir.exists():
594
+ raise
595
+
596
+ # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
597
+ # replacing the temp directory with the final directory from rank 0 might not be immediately
598
+ # realized in the file systems of the other ranks.
599
+ # So we wait here across all ranks until that final checkpoint directory is visible.
600
+ wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
601
+
602
+ barrier()
603
+
604
+ def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
605
+ if get_global_rank() == 0:
606
+ log.info("Saving config...")
607
+ self.cfg.save(config_path := Path(dir) / "config.yaml")
608
+ if upload_to is not None:
609
+ upload_target = f"{upload_to}/config.yaml"
610
+ log.info(f"Uploading {config_path} to {upload_target}")
611
+ upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
612
+
613
+
614
+ class FullCheckpointer(Checkpointer):
615
+ """
616
+ A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
617
+ """
618
+
619
+ def save_checkpoint(
620
+ self,
621
+ dir: PathOrStr,
622
+ dist_model: nn.Module,
623
+ optim: Optimizer,
624
+ trainer_state: Dict[str, Any],
625
+ *,
626
+ upload_to: Optional[str] = None,
627
+ ) -> None:
628
+ with self._temporary_wd(dir) as checkpoint_dir:
629
+ if isinstance(dist_model, FSDP):
630
+ with FSDP.state_dict_type(
631
+ dist_model,
632
+ state_dict_type=StateDictType.FULL_STATE_DICT,
633
+ state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
634
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
635
+ ):
636
+ # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
637
+ # First the model state.
638
+ model_state_dict = dist_model.state_dict()
639
+ self._write_model_dict(
640
+ model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
641
+ )
642
+
643
+ # Then the optimizer state.
644
+ optim_state_dict = FSDP.optim_state_dict(dist_model, optim)
645
+ self._write_optim_dict(
646
+ optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
647
+ )
648
+ elif isinstance(dist_model, DDP):
649
+ # _write_model_dict and _write_optim_dict only write checkpoints for rank 0
650
+ # First, get the model state dict from DDP wrapped model
651
+ model_state_dict = dist_model.module.state_dict()
652
+ self._write_model_dict(
653
+ model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
654
+ )
655
+
656
+ # Then get the optimizer state dict
657
+ optim_state_dict = optim.state_dict()
658
+ self._write_optim_dict(
659
+ optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
660
+ )
661
+ else:
662
+ log.info(
663
+ "`FullCheckpointer.save_checkpoint` only supported for FSDP and DDP distributed strategies!"
664
+ )
665
+
666
+ # Save trainer state.
667
+ if get_global_rank() == 0:
668
+ log.info("Saving trainer state...")
669
+ save_state_dict(
670
+ checkpoint_dir,
671
+ "train.pt",
672
+ trainer_state,
673
+ upload_to=upload_to,
674
+ save_overwrite=self.cfg.save_overwrite,
675
+ synchronize=False,
676
+ )
677
+ # Save config.
678
+ self._save_config(checkpoint_dir, upload_to=upload_to)
679
+
680
+ def restore_checkpoint(
681
+ self,
682
+ load_path: PathOrStr,
683
+ dist_model: nn.Module,
684
+ optim: Optimizer,
685
+ *,
686
+ local_cache: Optional[PathOrStr] = None,
687
+ load_optimizer_state: bool = True,
688
+ ) -> Dict[str, Any]:
689
+ if isinstance(dist_model, FSDP):
690
+ with FSDP.state_dict_type(
691
+ dist_model,
692
+ state_dict_type=StateDictType.FULL_STATE_DICT,
693
+ state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
694
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
695
+ ):
696
+ with torch.no_grad():
697
+ # fill everything with NaN, so we can check afterwards that every parameter has been restored
698
+ for module_name, module in dist_model.named_modules():
699
+ if not isinstance(module, FSDP):
700
+ continue
701
+ for param in module.params:
702
+ param.fill_(torch.nan)
703
+
704
+ # restore params from checkpoint
705
+ state_dict_to_load = load_state_dict(
706
+ load_path, "model.pt", local_cache=local_cache, map_location="cpu"
707
+ )
708
+ (
709
+ state_dict_to_load,
710
+ og_keys_to_new,
711
+ ) = dist_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
712
+
713
+ for module_name, module in dist_model.named_modules():
714
+ if not isinstance(module, FSDP):
715
+ continue
716
+ for param in module.params:
717
+ assert param._is_flat_param
718
+ for fqn, spi in zip(param._fqns, param._shard_param_infos):
719
+ if not spi.in_shard:
720
+ continue
721
+ key = f"{module_name}.{fqn}"
722
+ key = key.replace("_fsdp_wrapped_module.", "")
723
+ key = key.lstrip(".")
724
+ t = state_dict_to_load[key]
725
+ t = t.flatten()
726
+ param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
727
+ t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
728
+ )
729
+
730
+ # make sure that every parameter has been restored
731
+ for module_name, module in dist_model.named_modules():
732
+ if not isinstance(module, FSDP):
733
+ continue
734
+ for param in module.params:
735
+ if torch.isnan(param).any():
736
+ raise ValueError(
737
+ f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
738
+ )
739
+
740
+ # Load optimizer state.
741
+ if load_optimizer_state:
742
+ optim_state_dict_to_load = load_state_dict(
743
+ load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
744
+ )
745
+ optim_state_dict_to_load = self._make_optim_state_dict_compatible(
746
+ optim_state_dict_to_load,
747
+ og_keys_to_new,
748
+ )
749
+ gc.collect()
750
+ torch.cuda.empty_cache()
751
+ barrier()
752
+ for turn in range(get_local_world_size()):
753
+ log.info("Loading optimizer state turn %d ...", turn)
754
+ if turn == get_local_rank():
755
+ load_fsdp_optim_state(dist_model, optim, optim_state_dict_to_load)
756
+ gc.collect()
757
+ torch.cuda.empty_cache()
758
+ barrier()
759
+ del optim_state_dict_to_load
760
+ elif isinstance(dist_model, DDP):
761
+ # Load model state.
762
+ with torch.no_grad():
763
+ state_dict_to_load = load_state_dict(
764
+ load_path, "model.pt", local_cache=local_cache, map_location="cpu"
765
+ )
766
+ dist_model.module.load_state_dict(state_dict_to_load, strict=True)
767
+
768
+ # Load optimizer state.
769
+ if load_optimizer_state:
770
+ optim_state_dict_to_load = load_state_dict(
771
+ load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
772
+ )
773
+ optim.load_state_dict(optim_state_dict_to_load)
774
+
775
+ gc.collect()
776
+ torch.cuda.empty_cache()
777
+ barrier()
778
+ else:
779
+ raise NotImplementedError(
780
+ "`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
781
+ )
782
+
783
+ # Load other state.
784
+ try:
785
+ trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
786
+ except FileNotFoundError:
787
+ # for backwards compatibility
788
+ trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
789
+ barrier()
790
+ return trainer_state
791
+
792
+ def _write_model_dict(self, model_state_dict, checkpoint_dir, upload_to, save_overwrite):
793
+ if get_global_rank() == 0:
794
+ log.info("Saving model state...")
795
+ save_state_dict(
796
+ checkpoint_dir,
797
+ "model.pt",
798
+ model_state_dict,
799
+ upload_to=upload_to,
800
+ save_overwrite=save_overwrite,
801
+ synchronize=False,
802
+ )
803
+
804
+ del model_state_dict
805
+ barrier()
806
+
807
+ def _write_optim_dict(self, optim_state_dict, checkpoint_dir, upload_to, save_overwrite):
808
+ if get_global_rank() == 0:
809
+ log.info("Saving optim state...")
810
+ save_state_dict(
811
+ checkpoint_dir,
812
+ "optim.pt",
813
+ optim_state_dict,
814
+ upload_to=upload_to,
815
+ save_overwrite=save_overwrite,
816
+ synchronize=False,
817
+ )
818
+
819
+ del optim_state_dict
820
+ barrier()
821
+
822
+ def _make_optim_state_dict_compatible(
823
+ self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
824
+ ) -> Dict[str, Any]:
825
+ # This state dict comes in two forms: one where the state keys are integers and one where the
826
+ # keys are fully qualified parameter names. The latter case is easier to deal with here so we
827
+ # first transform the integer key form into the FQN key form.
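+ # For instance (hypothetical FQN), the integer-keyed form
+ #   {"state": {0: {...}}, "param_groups": [{"param_names": ["transformer.wte.weight"], "params": [0]}]}
+ # becomes the FQN-keyed form
+ #   {"state": {"transformer.wte.weight": {...}}, "param_groups": [{"param_names": [...], "params": ["transformer.wte.weight"]}]}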
828
+ if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
829
+ id_to_fqn: Dict[int, str] = {}
830
+ for group in optim_state_dict["param_groups"]:
831
+ new_param_names = []
832
+ for fqn, id in zip(group["param_names"], group["params"]):
833
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
834
+ id_to_fqn[id] = fqn
835
+ new_param_names.append(fqn)
836
+ group["param_names"] = new_param_names
837
+ group["params"] = new_param_names
838
+ for id in list(optim_state_dict["state"].keys()):
839
+ optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
840
+ else:
841
+ # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
842
+ for group in optim_state_dict["param_groups"]:
843
+ group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
844
+ group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
845
+ assert group["param_names"] == group["params"]
846
+ for key in list(optim_state_dict["state"].keys()):
847
+ optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
848
+ "state"
849
+ ].pop(key)
850
+
851
+ # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
852
+ # First fix param names in the state.
853
+ for og_key, new_keys in og_keys_to_new.items():
854
+ og_state = optim_state_dict["state"].pop(og_key, None)
855
+ if og_state is None:
856
+ continue
857
+ for i, new_key in enumerate(new_keys):
858
+ if i == len(new_keys) - 1:
859
+ optim_state_dict["state"][new_key] = og_state
860
+ else:
861
+ optim_state_dict["state"][new_key] = deepcopy(og_state)
862
+ # Now fix param names in the param groups.
863
+ for group in optim_state_dict["param_groups"]:
864
+ og_names = group["params"]
865
+ new_names = []
866
+ for og_key in og_names:
867
+ for new_key in og_keys_to_new[og_key]:
868
+ new_names.append(new_key)
869
+ group["params"] = new_names
870
+ group["param_names"] = new_names
871
+
872
+ return optim_state_dict
873
+
874
+ def load_checkpoint(
875
+ self,
876
+ load_path: PathOrStr,
877
+ *,
878
+ local_cache: Optional[PathOrStr] = None,
879
+ load_optimizer_state: bool = True,
880
+ device: Optional[torch.device] = None,
881
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
882
+ device = device if device is not None else torch.device("cpu")
883
+ model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
884
+ optim_state = None
885
+ if load_optimizer_state:
886
+ optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
887
+ return model_state, optim_state
888
+
889
+
890
+ class TorchNewStyleShardedCheckpointer(Checkpointer):
891
+ """
892
+ A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
893
+ """
894
+
895
+ def save_checkpoint(
896
+ self,
897
+ dir: PathOrStr,
898
+ dist_model: nn.Module,
899
+ optim: Optimizer,
900
+ trainer_state: Dict[str, Any],
901
+ *,
902
+ upload_to: Optional[str] = None,
903
+ ) -> None:
904
+ assert isinstance(
905
+ dist_model, FSDP
906
+ ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
907
+ with self._temporary_wd(dir) as checkpoint_dir:
908
+ # Save model and optim state.
909
+ save_fsdp_model_and_optim_state(
910
+ checkpoint_dir,
911
+ dist_model,
912
+ optim,
913
+ upload_to=upload_to,
914
+ save_overwrite=self.cfg.save_overwrite,
915
+ )
916
+
917
+ # Save trainer state.
918
+ log.info("Saving trainer state...")
919
+ save_state_dict(
920
+ checkpoint_dir,
921
+ f"train/rank{get_global_rank()}.pt",
922
+ trainer_state,
923
+ upload_to=upload_to,
924
+ save_overwrite=self.cfg.save_overwrite,
925
+ )
926
+
927
+ # Save config.
928
+ self._save_config(checkpoint_dir, upload_to=upload_to)
929
+
930
+ def restore_checkpoint(
931
+ self,
932
+ load_path: PathOrStr,
933
+ dist_model: nn.Module,
934
+ optim: Optimizer,
935
+ *,
936
+ local_cache: Optional[PathOrStr] = None,
937
+ load_optimizer_state: bool = True,
938
+ ) -> Dict[str, Any]:
939
+ # Load model and optimizer state in place.
940
+ log.info("Loading model and optimizer state...")
941
+ assert isinstance(
942
+ dist_model, FSDP
943
+ ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
944
+
945
+ load_fsdp_model_and_optim_state(
946
+ load_path,
947
+ dist_model,
948
+ optim,
949
+ local_cache=local_cache,
950
+ load_optimizer_state=load_optimizer_state,
951
+ )
952
+
953
+ # Load trainer state dict.
954
+ log.info("Loading trainer state...")
955
+ try:
956
+ trainer_state = load_state_dict(
957
+ load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
958
+ )
959
+ except FileNotFoundError:
960
+ # Fall back to rank 0 train state.
961
+ # This can happen when we're restoring a checkpoint with a different world size.
962
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
963
+ barrier()
964
+ return trainer_state
965
+
966
+
967
+ class TorchLegacyShardedCheckpointer(Checkpointer):
968
+ """
969
+ A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
970
+ and optim state.
971
+
972
+ The world size must be kept consistent when using this checkpointer.
973
+ """
974
+
975
+ def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None, use_shared_mem_impl: bool = False):
976
+ super().__init__(cfg, thread_count)
977
+ self.use_shared_mem_impl = use_shared_mem_impl
978
+
979
+ def save_checkpoint(
980
+ self,
981
+ dir: PathOrStr,
982
+ dist_model: nn.Module,
983
+ optim: Optimizer,
984
+ trainer_state: Dict[str, Any],
985
+ *,
986
+ upload_to: Optional[str] = None,
987
+ ) -> None:
988
+ assert isinstance(
989
+ dist_model, FSDP
990
+ ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
991
+ with self._temporary_wd(dir) as checkpoint_dir:
992
+ with FSDP.state_dict_type(
993
+ dist_model,
994
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
995
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
996
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
997
+ ):
998
+ state_dict = {
999
+ "model": dist_model.state_dict(),
1000
+ "optim": FSDP.optim_state_dict(dist_model, optim),
1001
+ **trainer_state,
1002
+ }
1003
+ save_state_dict(
1004
+ checkpoint_dir,
1005
+ f"rank{get_global_rank()}.pt",
1006
+ state_dict,
1007
+ upload_to=upload_to,
1008
+ save_overwrite=self.cfg.save_overwrite,
1009
+ )
1010
+
1011
+ # Save config.
1012
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1013
+
1014
+ def restore_checkpoint(
1015
+ self,
1016
+ load_path: PathOrStr,
1017
+ dist_model: nn.Module,
1018
+ optim: Optimizer,
1019
+ *,
1020
+ local_cache: Optional[PathOrStr] = None,
1021
+ load_optimizer_state: bool = True,
1022
+ ) -> Dict[str, Any]:
1023
+ assert isinstance(
1024
+ dist_model, FSDP
1025
+ ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
1026
+ with FSDP.state_dict_type(
1027
+ dist_model,
1028
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
1029
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
1030
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
1031
+ ):
1032
+ # Deserialize state dict.
1033
+ state_dict = load_state_dict(
1034
+ load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1035
+ )
1036
+
1037
+ # Load model and optimizer state.
1038
+ log.info("Loading model state...")
1039
+ dist_model.load_state_dict(state_dict["model"])
1040
+ del state_dict["model"]
1041
+ if load_optimizer_state:
1042
+ log.info("Loading optimizer state...")
1043
+ load_fsdp_optim_state(dist_model, optim, state_dict["optim"])
1044
+ del state_dict["optim"]
1045
+
1046
+ barrier()
1047
+ return state_dict
1048
+
1049
+ def unshard_checkpoint(
1050
+ self,
1051
+ load_path: PathOrStr,
1052
+ *,
1053
+ local_cache: Optional[PathOrStr] = None,
1054
+ load_optimizer_state: bool = True,
1055
+ load_trainer_state: bool = True,
1056
+ device: Optional[torch.device] = None,
1057
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1058
+ assert local_cache is None, "this method currently only supports local files"
1059
+ full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
1060
+ model_state = full_state_dict.pop("model")
1061
+ optim_state = full_state_dict.pop("optim")
1062
+ return (
1063
+ model_state,
1064
+ optim_state if load_optimizer_state else None,
1065
+ full_state_dict if load_trainer_state else None,
1066
+ )
1067
+
1068
+ def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
1069
+ key = tuple() if key is None else key
1070
+ if isinstance(state, (list, tuple, set)):
1071
+ for i, sub_state in enumerate(state):
1072
+ self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
1073
+ elif isinstance(state, dict):
1074
+ for name in state.keys():
1075
+ self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
1076
+ elif isinstance(state, ShardedTensor):
1077
+ self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
1078
+ return
1079
+ else:
1080
+ return
1081
+
1082
+ def _get_shard_placement_and_rank_sizes(
1083
+ self, shards_metadata: List[ShardMetadata], world_size: int
1084
+ ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
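+ # Returns, for each shard, its (owning rank, element offset within that rank's flattened
+ # local data), along with the total number of elements stored on each rank.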
1085
+ def shard_size(shard_md):
1086
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1087
+
1088
+ rank_sizes = [0 for _ in range(world_size)]
1089
+ shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
1090
+ for shard_md in shards_metadata:
1091
+ shard_rank = cast(_remote_device, shard_md.placement).rank()
1092
+ assert shard_rank is not None
1093
+ if shard_rank >= world_size:
1094
+ raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
1095
+
1096
+ shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
1097
+ rank_sizes[shard_rank] += shard_size(shard_md)
1098
+
1099
+ return shard_placement, rank_sizes
1100
+
1101
+ def _copy_sharded_tensor_to_shared_mem(
1102
+ self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
1103
+ ) -> Any:
1104
+ shard0_md = sharded_tensor.metadata()
1105
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1106
+ shard0_md.shards_metadata, world_size
1107
+ )
1108
+
1109
+ rank_size = rank_sizes[rank]
1110
+ assert rank_size >= 0
1111
+ if rank_size == 0:
1112
+ return
1113
+
1114
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1115
+ numpy_type = np.float32
1116
+
1117
+ sharded_memory_name = "-".join(key + (str(rank),))
1118
+
1119
+ shm = shared_memory.SharedMemory(
1120
+ create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1121
+ )
1122
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1123
+
1124
+ for local_shard in sharded_tensor.local_shards():
1125
+ shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1126
+ assert shard_rank == rank
1127
+
1128
+ src = local_shard.tensor.flatten()
1129
+ shard_offset = shard_placement[local_shard.metadata][1]
1130
+
1131
+ np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1132
+
1133
+ shm.close()
1134
+
1135
+ def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1136
+ shard_number = int(shard_filepath.name[4:-3])
1137
+ log.info("Starting unsharding shard number %d to shared memory", shard_number)
1138
+
1139
+ with self._patch_sharded_tensor_load():
1140
+ shard = torch.load(shard_filepath, map_location="cpu")
1141
+ log.debug("Done loading shard number %d", shard_number)
1142
+
1143
+ self._copy_sharded_tensors_to_shared_mem(
1144
+ shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1145
+ )
1146
+ log.info("Done unsharding shard number %d to shared memory", shard_number)
1147
+
1148
+ def _unshard_using_sharded_mem(
1149
+ self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1150
+ ) -> Any:
1151
+ return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1152
+
1153
+ def _unshard_state_using_shared_mem(
1154
+ self, state: Any, world_size: int, device: torch.device, key: Tuple
1155
+ ) -> Any:
1156
+ if isinstance(state, (list, tuple, set)):
1157
+ return state.__class__(
1158
+ self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1159
+ for i, sub_state in enumerate(state)
1160
+ )
1161
+ elif isinstance(state, dict):
1162
+ return {
1163
+ name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1164
+ for name in state.keys()
1165
+ }
1166
+ elif isinstance(state, ShardedTensor):
1167
+ return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1168
+ elif isinstance(state, torch.Tensor):
1169
+ return state.to(device=device)
1170
+ else:
1171
+ return state
1172
+
1173
+ def _unshard_tensor_using_shared_mem(
1174
+ self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1175
+ ) -> torch.Tensor:
1176
+ shard0_md = sharded_tensor.metadata()
1177
+
1178
+ def shard_size(shard_md):
1179
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1180
+
1181
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1182
+ shard0_md.shards_metadata, world_size
1183
+ )
1184
+
1185
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1186
+ numpy_type = np.float32
1187
+
1188
+ out = torch.empty(
1189
+ *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1190
+ )
1191
+ dims = len(sharded_tensor.metadata().size)
1192
+ for shard_md, (rank, rank_offset) in shard_placement.items():
1193
+ if rank >= world_size:
1194
+ raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1195
+
1196
+ sharded_memory_name = "-".join(key + (str(rank),))
1197
+ shm = shared_memory.SharedMemory(name=sharded_memory_name)
1198
+
1199
+ rank_size = rank_sizes[rank]
1200
+ assert rank_size >= 0
1201
+ if rank_size == 0:
1202
+ continue
1203
+
1204
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1205
+
1206
+ tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1207
+ tensor = tensor.view(shard_md.shard_sizes)
1208
+
1209
+ out_narrow_view = out
1210
+ for dim in range(dims):
1211
+ out_narrow_view = out_narrow_view.narrow(
1212
+ dim,
1213
+ shard_md.shard_offsets[dim],
1214
+ shard_md.shard_sizes[dim],
1215
+ )
1216
+
1217
+ out_narrow_view.copy_(tensor)
1218
+
1219
+ shm.close()
1220
+ shm.unlink()
1221
+
1222
+ return out
1223
+
1224
+ @contextmanager
1225
+ def _patch_sharded_tensor_load(self):
1226
+ """
1227
+ Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1228
+ """
1229
+
1230
+ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1231
+ ret = func(*args)
1232
+ if type(ret) is not new_type:
1233
+ ret = ret.as_subclass(new_type)
1234
+
1235
+ # Shortcut the construction of ShardedTensor
1236
+ # This is in the top 5 of my worst hacks.
1237
+ if isinstance(ret, ShardedTensor):
1238
+ ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1239
+ return ret
1240
+
1241
+ # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1242
+ # Tensor does define __setstate__ even though it doesn't define
1243
+ # __getstate__. So only use __setstate__ if it is NOT the one defined
1244
+ # on Tensor
1245
+ if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1246
+ ret.__setstate__(state)
1247
+ else:
1248
+ ret = torch._utils._set_obj_state(ret, state)
1249
+ return ret
1250
+
1251
+ original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1252
+ try:
1253
+ torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1254
+ yield
1255
+ finally:
1256
+ torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1257
+
1258
+ def _unshard_using_shared_memory(
1259
+ self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None
1260
+ ):
1261
+ """
1262
+ This unsharding implementation consists of:
1263
+
1264
+ 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1265
+ 2. Loading 1 shard on the main process as a base unsharded object.
1266
+ 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1267
+
1268
+ This implementation is an alternative to a prior implementation that instead loaded
1269
+ all shards using threads, because that implementation turned out to
1270
+ be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1271
+ The current implementation is slower than the old one in many scenarios,
1272
+ but is significantly faster in the above-mentioned case (e.g. 30 minutes)
1273
+ if there are enough CPUs.
1274
+
1275
+ We keep the other implementation since this one can be more unreliable,
1276
+ likely due to its dependence on a large amount of shared memory.
1277
+ """
1278
+
1279
+ input_dir = Path(input_dir)
1280
+ skip_keys = skip_keys or set()
1281
+
1282
+ shard_filepaths = list(input_dir.glob("rank*.pt"))
1283
+ world_size = len(shard_filepaths)
1284
+ if world_size == 0:
1285
+ raise RuntimeError("No shards found for unsharding")
1286
+
1287
+ log.info("Number of shards: %d", world_size)
1288
+ shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1289
+ min_ram_required_estimate_gb = shard_size_gb * world_size
1290
+ log.info(
1291
+ "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1292
+ )
1293
+
1294
+ log.info("Copying sharded tensors to shared memory using multiple processes")
1295
+ # Copy sharded data to shared memory using multiple processes, so this process can load
1296
+ # from memory rather than disk. We spawn a new process instead of forking since shared memory
1297
+ # appears to get deleted when forked processes end for some reason.
1298
+ executor = ProcessPoolExecutor(
1299
+ mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1300
+ )
1301
+ futures = []
1302
+ for shard_filepath in shard_filepaths:
1303
+ shard_rank = int(shard_filepath.name[4:-3])
1304
+
1305
+ if shard_rank >= world_size:
1306
+ raise RuntimeError(
1307
+ f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1308
+ )
1309
+
1310
+ futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1311
+
1312
+ for f in as_completed(futures):
1313
+ f.result()
1314
+ executor.shutdown()
1315
+
1316
+ log.info("Loading a shard on the main process to be unsharded state")
1317
+ with self._patch_sharded_tensor_load():
1318
+ state = torch.load(shard_filepaths[0], map_location="cpu")
1319
+
1320
+ for key in skip_keys:
1321
+ if key in state:
1322
+ del state[key]
1323
+
1324
+ log.info("Unsharding from %d shards ...", world_size)
1325
+ return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
1326
+
1327
+ def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1328
+ if self.use_shared_mem_impl:
1329
+ return self._unshard_using_shared_memory(input_dir, device, skip_keys)
1330
+
1331
+ input_dir = Path(input_dir)
1332
+ skip_keys = skip_keys or set()
1333
+
1334
+ with self._patch_sharded_tensor_load():
1335
+ # We load in threads because it's faster.
1336
+ executor = ThreadPoolExecutor()
1337
+ shards_dict = {}
1338
+ for shard_name in input_dir.glob("rank*.pt"):
1339
+ log.info("Loading %s ...", shard_name)
1340
+ shard_number = int(shard_name.name[4:-3]) # shard names look like "rankXX.pt"
1341
+ shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
1342
+ shards = [None] * len(shards_dict)
1343
+ for rank, shard_future in shards_dict.items():
1344
+ shard = shard_future.result()
1345
+ for key in skip_keys:
1346
+ if key in shard:
1347
+ del shard[key]
1348
+ shards[rank] = shard
1349
+ assert all(shard is not None for shard in shards)
1350
+ executor.shutdown()
1351
+ del shards_dict
1352
+
1353
+ log.info("Unsharding from %d shards ...", len(shards))
1354
+
1355
+ unsharded_state_dict = self._unshard_object(shards, device=device)
1356
+ # At this point in time we need 2x memory :-(
1357
+ del shards
1358
+
1359
+ return unsharded_state_dict
1360
+
1361
+ def _unshard_object(self, os: List[Any], device: torch.device) -> Any:
1362
+ rank0_item = os[0]
1363
+ assert all(type(o) is type(rank0_item) for o in os)
1364
+ if isinstance(rank0_item, str):
1365
+ assert all(o == rank0_item for o in os)
1366
+ return rank0_item
1367
+ elif isinstance(rank0_item, (list, tuple, set)):
1368
+ assert all(len(o) == len(rank0_item) for o in os)
1369
+ return rank0_item.__class__(self._unshard_object(o, device=device) for o in zip(*os))
1370
+ elif isinstance(rank0_item, dict):
1371
+ assert all(o.keys() == rank0_item.keys() for o in os)
1372
+ return {key: self._unshard_object([o[key] for o in os], device=device) for key in rank0_item.keys()}
1373
+ elif isinstance(rank0_item, ShardedTensor):
1374
+ return self._gather(os, device=device)
1375
+ else:
1376
+ assert all(self._objects_are_equal(o, rank0_item) for o in os)
1377
+ return rank0_item
1378
+
1379
+ def _gather(self, shards: List[ShardedTensor], device: torch.device) -> torch.Tensor:
1380
+ world_size = len(shards)
1381
+ shard0_md = shards[0].metadata()
1382
+ # Make sure all shards agree on the metadata
1383
+ assert all(shard.metadata() == shard0_md for shard in shards)
1384
+ # Make sure the nth shard expects to be the nth shard.
1385
+ assert all(
1386
+ shard_md.placement.rank() == rank # type: ignore
1387
+ for rank, shard_md in enumerate(shard0_md.shards_metadata)
1388
+ )
1389
+
1390
+ def shard_size(shard_md):
1391
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1392
+
1393
+ rank_sizes = [0 for _ in range(world_size)]
1394
+ max_rank_size = 0
1395
+ shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
1396
+ for shard_md in shard0_md.shards_metadata:
1397
+ shard_rank = cast(_remote_device, shard_md.placement).rank()
1398
+ assert shard_rank is not None
1399
+
1400
+ shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
1401
+ rank_sizes[shard_rank] += shard_size(shard_md)
1402
+ max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
1403
+
1404
+ gather_list: List[torch.Tensor] = [torch.empty((max_rank_size,)) for _ in range(world_size)]
1405
+
1406
+ datas = []
1407
+ with torch.no_grad():
1408
+ for shard in shards:
1409
+ data = torch.empty(max_rank_size)
1410
+
1411
+ for local_shard in shard.local_shards():
1412
+ src = local_shard.tensor.flatten()
1413
+ shard_offset = shard_placement[local_shard.metadata][1]
1414
+ data[shard_offset : shard_offset + src.numel()].copy_(src)
1415
+
1416
+ datas.append(data)
1417
+
1418
+ # torch.gather in a nutshell
1419
+ for rank, data in enumerate(datas):
1420
+ gather_list[rank].copy_(data)
1421
+
1422
+ full_size = shard0_md.size
1423
+ out = torch.empty(*full_size, dtype=shard0_md.tensor_properties.dtype, device=device)
1424
+ dims = len(full_size)
1425
+ for shard_md in shard0_md.shards_metadata:
1426
+ rank, rank_offset = shard_placement[shard_md]
1427
+ tensor = gather_list[rank]
1428
+ tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
1429
+ tensor = tensor.view(shard_md.shard_sizes)
1430
+
1431
+ out_narrow_view = out
1432
+ for dim in range(dims):
1433
+ out_narrow_view = out_narrow_view.narrow(
1434
+ dim,
1435
+ shard_md.shard_offsets[dim],
1436
+ shard_md.shard_sizes[dim],
1437
+ )
1438
+
1439
+ out_narrow_view.copy_(tensor)
1440
+
1441
+ return out
1442
+
1443
+ def _objects_are_equal(self, a: Any, b: Any) -> bool:
1444
+ if type(a) is not type(b):
1445
+ return False
1446
+ if isinstance(a, np.ndarray):
1447
+ return np.array_equal(a, b)
1448
+ elif isinstance(a, torch.Tensor):
1449
+ return torch.equal(a, b)
1450
+ else:
1451
+ return a == b
1452
+
1453
+
1454
+ @dataclass
1455
+ class _LocalShardedCheckpointerMetadata(BaseConfig):
1456
+ world_size: int = field(default_factory=get_world_size)
1457
+
1458
+
1459
+ @dataclass
1460
+ class _FlatParamShard:
1461
+ full_shape: torch.Size
1462
+ shard_offsets: Tuple[int, int]
1463
+ shard_data: Optional[torch.Tensor]
1464
+
1465
+ def copy_into(self, full_tensor: torch.Tensor) -> None:
1466
+ assert self.shard_data is not None
1467
+ full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1468
+ assert self.shard_data.shape == full_tensor_shard_view.shape
1469
+ full_tensor_shard_view.copy_(self.shard_data)
1470
+
1471
+
1472
+ class LocalShardedCheckpointer(Checkpointer):
1473
+ """
1474
+ A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1475
+ The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1476
+
1477
+ The world size must be kept consistent when using this checkpointer. However, you can easily
1478
+ reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1479
+ using :meth:`unshard_checkpoint()` (no distributed initialization required).
1480
+ """
1481
+
1482
+ # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1483
+ _FLAT_PARAM_METADATA_TO_SAVE = (
1484
+ "_fqns",
1485
+ "_shard_param_offsets",
1486
+ "_shard_indices",
1487
+ "_numels",
1488
+ "_numels_with_padding",
1489
+ "_shapes",
1490
+ "_shard_numel_padded",
1491
+ "_shard_param_infos",
1492
+ )
1493
+
1494
+ def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1495
+ """
1496
+ Returns a list of FSDP modules with their FQN.
1497
+ """
1498
+ modules = []
1499
+ for name, module in fsdp_model.named_modules():
1500
+ if isinstance(module, FSDP):
1501
+ modules.append((name, module))
1502
+ return modules
1503
+
1504
+ def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1505
+ from torch.distributed.fsdp._runtime_utils import _lazy_init
1506
+
1507
+ # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1508
+ # an FSDP state dict through the built-in methods.
1509
+ if torch.cuda.is_available():
1510
+ torch.cuda.synchronize()
1511
+ _lazy_init(fsdp_model, fsdp_model)
1512
+
1513
+ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1514
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
1515
+ return fsdp_model._handles # type: ignore
1516
+ elif version.parse(torch.__version__) < version.parse("2.3.0"):
1517
+ # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1518
+ if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1519
+ return [fsdp_model._handle] # type: ignore
1520
+ else:
1521
+ return []
1522
+ else:
1523
+ # Need to verify FSDP internals with newer versions.
1524
+ raise NotImplementedError
1525
+
1526
+ @torch.no_grad()
1527
+ def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1528
+ self._prepare_fsdp_model(fsdp_model)
1529
+ module_data = []
1530
+ for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1531
+ handle_data = []
1532
+ for handle in self._fsdp_handles(fsdp_module):
1533
+ data: Dict[str, Any] = {}
1534
+ # This is a `FlatParameter` instance.
1535
+ # See `torch.distributed.fsdp.flat_param` for the API.
1536
+ flat_param = handle.flat_param
1537
+ data["flat_param.data"] = flat_param.detach()
1538
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1539
+ if hasattr(flat_param, key):
1540
+ data[f"flat_param.{key}"] = getattr(flat_param, key)
1541
+ handle_data.append(data)
1542
+ module_data.append({"handles": handle_data, "name": module_fqn})
1543
+ return {"modules": module_data}
1544
+
1545
+ @torch.no_grad()
1546
+ def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1547
+ """Load the state produced from `self._get_flat_param_state_to_save()`."""
1548
+ self._prepare_fsdp_model(fsdp_model)
1549
+ fsdp_modules = self._fsdp_modules(fsdp_model)
1550
+ assert len(model_state["modules"]) == len(fsdp_modules)
1551
+ for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1552
+ handles = self._fsdp_handles(fsdp_module)
1553
+ assert len(handles) == len(module_data["handles"])
1554
+ for handle, data in zip(handles, module_data["handles"]):
1555
+ flat_param = handle.flat_param
1556
+ # Make sure metadata matches.
1557
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1558
+ if hasattr(flat_param, key):
1559
+ assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1560
+ # Load the flat sharded data.
1561
+ flat_param.copy_(data["flat_param.data"])
1562
+
1563
+ def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1564
+ if get_fs_local_rank() == 0:
1565
+ log.info("Saving metadata...")
1566
+ metadata = _LocalShardedCheckpointerMetadata()
1567
+ metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1568
+ if upload_to is not None and get_global_rank() == 0:
1569
+ upload_target = f"{upload_to}/metadata.yaml"
1570
+ log.info(f"Uploading {metadata_path} to {upload_target}")
1571
+ upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1572
+
1573
+ def _load_metadata(
1574
+ self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1575
+ ) -> _LocalShardedCheckpointerMetadata:
1576
+ metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1577
+ return _LocalShardedCheckpointerMetadata.load(metadata_path)
1578
+
1579
+ def save_checkpoint(
1580
+ self,
1581
+ dir: PathOrStr,
1582
+ dist_model: nn.Module,
1583
+ optim: Optimizer,
1584
+ trainer_state: Dict[str, Any],
1585
+ *,
1586
+ upload_to: Optional[str] = None,
1587
+ ) -> None:
1588
+ assert isinstance(
1589
+ dist_model, FSDP
1590
+ ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
1591
+
1592
+ with self._temporary_wd(dir) as checkpoint_dir:
1593
+ # Gather local FSDP flat params data to save.
1594
+ # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1595
+ # of each original parameter so we can validate that the sharding is the same when loading
1596
+ # one of these checkpoints.
1597
+ log.info("Saving local FSDP flat params data...")
1598
+ save_state_dict(
1599
+ checkpoint_dir,
1600
+ f"model/rank{get_global_rank()}.pt",
1601
+ self._get_flat_param_state_to_save(dist_model),
1602
+ upload_to=upload_to,
1603
+ save_overwrite=self.cfg.save_overwrite,
1604
+ )
1605
+
1606
+ # Save optimizer state.
1607
+ log.info("Saving local optimizer state...")
1608
+ save_state_dict(
1609
+ checkpoint_dir,
1610
+ f"optim/rank{get_global_rank()}.pt",
1611
+ optim.state_dict(),
1612
+ upload_to=upload_to,
1613
+ save_overwrite=self.cfg.save_overwrite,
1614
+ )
1615
+
1616
+ # Save trainer state.
1617
+ log.info("Saving trainer state...")
1618
+ save_state_dict(
1619
+ checkpoint_dir,
1620
+ f"train/rank{get_global_rank()}.pt",
1621
+ trainer_state,
1622
+ upload_to=upload_to,
1623
+ save_overwrite=self.cfg.save_overwrite,
1624
+ )
1625
+
1626
+ # Save metadata.
1627
+ self._save_metadata(checkpoint_dir, upload_to=upload_to)
1628
+
1629
+ # Save config. We do this last b/c the presence of a config in a remote checkpoint
1630
+ # "directory" indicates that the folder is valid, as a opposed to a partially
1631
+ # uploaded checkpoint directory that failed before completing.
1632
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1633
+
1634
+ def restore_checkpoint(
1635
+ self,
1636
+ load_path: PathOrStr,
1637
+ dist_model: nn.Module,
1638
+ optim: Optimizer,
1639
+ *,
1640
+ local_cache: Optional[PathOrStr] = None,
1641
+ load_optimizer_state: bool = True,
1642
+ ) -> Dict[str, Any]:
1643
+ # Load metadata and make sure checkpoint is compatible.
1644
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1645
+ assert metadata.world_size == get_world_size()
1646
+
1647
+ # Load local FSDP flat param data.
1648
+ log.info("Loading local FSDP flat params data...")
1649
+ assert isinstance(
1650
+ dist_model, FSDP
1651
+ ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
1652
+
1653
+ model_state = load_state_dict(
1654
+ load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1655
+ )
1656
+ self._load_flat_param_state(dist_model, model_state)
1657
+ del model_state
1658
+
1659
+ # Load local optim state.
1660
+ if load_optimizer_state:
1661
+ log.info("Loading local optimizer state...")
1662
+ optim_state = load_state_dict(
1663
+ load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1664
+ )
1665
+ # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1666
+ # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1667
+ # state since torch sees the state is non-empty for some params which would normally be empty,
1668
+ # and then assumes it should have all of the other state tensors for that param, which it doesn't.
1669
+ # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1670
+ # Not the end of the world but there's probably a better way around this without resetting
1671
+ # the metric.
1672
+ for param_id in list(optim_state["state"].keys()):
1673
+ state = optim_state["state"][param_id]
1674
+ if "grad_norm_exp_avg" in state:
1675
+ del state["grad_norm_exp_avg"]
1676
+ if len(state) == 0:
1677
+ del optim_state["state"][param_id]
1678
+ optim.load_state_dict(optim_state)
1679
+ del optim_state
1680
+
1681
+ # Load local trainer state.
1682
+ log.info("Loading local trainer state...")
1683
+ trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1684
+ barrier()
1685
+ return trainer_state
1686
+
1687
+ def _iter_flat_param_shards(
1688
+ self, model_state: Dict[str, Any]
1689
+ ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1690
+ for module_data in model_state["modules"]:
1691
+ module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1692
+ for handle in module_data["handles"]:
1693
+ flat_data = handle["flat_param.data"]
1694
+ if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1695
+ # If there's padding in the flat param it should be on the right.
1696
+ assert (flat_data[-num_padding:] == 0).all()
1697
+ # NOTE: this changes depending on the torch version, but we don't do a version
1698
+ # check since we might be trying to unshard an old checkpoint that was stored
1699
+ # with a different torch version than we're currently running with.
1700
+ if "flat_param._shard_indices" in handle:
1701
+ # torch <=2.0.1
1702
+ param_start = handle["flat_param._shard_indices"][0]
1703
+ current_flat_index = 0
1704
+ for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1705
+ handle["flat_param._fqns"][param_start:],
1706
+ handle["flat_param._shapes"][param_start:],
1707
+ handle["flat_param._shard_param_offsets"],
1708
+ ):
1709
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1710
+ numel_shard = offset_end - offset_start + 1
1711
+ flat_param_shard = _FlatParamShard(
1712
+ full_shape=full_shape,
1713
+ shard_offsets=(offset_start, offset_end),
1714
+ shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1715
+ )
1716
+ current_flat_index += numel_shard
1717
+ yield root_fqn, flat_param_shard
1718
+ else:
1719
+ # torch >=2.1.0
1720
+ for relative_fqn, full_shape, shard_param_info in zip(
1721
+ handle["flat_param._fqns"],
1722
+ handle["flat_param._shapes"],
1723
+ handle["flat_param._shard_param_infos"],
1724
+ ):
1725
+ if not shard_param_info.in_shard:
1726
+ continue
1727
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1728
+ flat_param_shard = _FlatParamShard(
1729
+ full_shape=full_shape,
1730
+ shard_offsets=(
1731
+ shard_param_info.intra_param_start_idx,
1732
+ shard_param_info.intra_param_end_idx,
1733
+ ),
1734
+ shard_data=flat_data[
1735
+ shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1736
+ + shard_param_info.numel_in_shard
1737
+ ],
1738
+ )
1739
+ yield root_fqn, flat_param_shard
1740
+
1741
+ def unshard_checkpoint(
1742
+ self,
1743
+ load_path: PathOrStr,
1744
+ *,
1745
+ local_cache: Optional[PathOrStr] = None,
1746
+ load_optimizer_state: bool = True,
1747
+ load_trainer_state: bool = True,
1748
+ device: Optional[torch.device] = None,
1749
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1750
+ device = device or torch.device("cpu")
1751
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1752
+
1753
+ # Gather paths model state, potentially downloading them.
1754
+ log.info("Gathering model state dicts...")
1755
+ model_state_paths = self._gather_state_dict_paths(
1756
+ load_path, "model", metadata.world_size, local_cache=local_cache
1757
+ )
1758
+
1759
+ # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1760
+ log.info("Materializing full parameters...")
1761
+ full_model_state: Dict[str, torch.Tensor] = {}
1762
+ # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1763
+ # the full optimizer state below without having to reload the model state dicts.
1764
+ flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1765
+ for rank, path in enumerate(model_state_paths):
1766
+ log.info(f"Loading shards from rank {rank}...")
1767
+ model_state = torch.load(path, map_location="cpu")
1768
+ for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1769
+ if root_fqn not in full_model_state:
1770
+ log.info(
1771
+ f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1772
+ )
1773
+ assert flat_param_shard.shard_data is not None
1774
+ full_model_state[root_fqn] = torch.empty(
1775
+ flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1776
+ )
1777
+ # Fill with NaNs so we can validate that the whole parameter has been populated
1778
+ # afterwards.
1779
+ full_model_state[root_fqn].fill_(torch.nan)
1780
+ # Copy over the local shard to the relevant part of the full parameter.
1781
+ full_param = full_model_state[root_fqn]
1782
+ log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1783
+ flat_param_shard.copy_into(full_param)
1784
+ flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1785
+
1786
+ log.info("Validating full parameters...")
1787
+ for key, tensor in full_model_state.items():
1788
+ if torch.isnan(tensor).any():
1789
+ raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1790
+
1791
+ trainer_state: Optional[Dict[str, Any]] = None
1792
+ if load_trainer_state:
1793
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1794
+
1795
+ if not load_optimizer_state:
1796
+ return full_model_state, None, trainer_state
1797
+
1798
+ log.info("Gathering optim state dicts...")
1799
+ optim_state_paths = self._gather_state_dict_paths(
1800
+ load_path, "optim", metadata.world_size, local_cache=local_cache
1801
+ )
1802
+
1803
+ log.info("Materializing full optim state...")
1804
+ full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1805
+ fqn_to_id: Dict[str, int] = {}
1806
+ id_to_fqn: Dict[int, str] = {}
1807
+ for rank, path in enumerate(optim_state_paths):
1808
+ log.info(f"Loading sharded optim state from rank {rank}...")
1809
+ optim_state = torch.load(path, map_location="cpu")
1810
+
1811
+ # Initialize param groups.
1812
+ # We assume parameter groups are the same across all ranks.
1813
+ # The only thing that differs across ranks is the state for each local sharded param.
1814
+ if "param_groups" not in full_optim_state:
1815
+ full_optim_state["param_groups"] = optim_state["param_groups"]
1816
+ else:
1817
+ assert full_optim_state["param_groups"] == optim_state["param_groups"]
1818
+
1819
+ # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1820
+ if not fqn_to_id or not id_to_fqn:
1821
+ for group in full_optim_state["param_groups"]:
1822
+ for fqn, id in zip(group["param_names"], group["params"]):
1823
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
1824
+ fqn_to_id[fqn] = id
1825
+ id_to_fqn[id] = fqn
1826
+
1827
+ # Iterate over local shard state and copy into the full state.
1828
+ for id, shard_state in optim_state["state"].items():
1829
+ fqn = id_to_fqn[id]
1830
+ flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1831
+ full_state = full_optim_state["state"][id]
1832
+ for key, shard_value in shard_state.items():
1833
+ assert isinstance(shard_value, torch.Tensor)
1834
+ if shard_value.shape == torch.Size([]):
1835
+ # Add singleton tensors directly to full state. These should be the same across
1836
+ # all ranks.
1837
+ assert key in ("step", "grad_norm_exp_avg") # sanity check
1838
+ if key not in full_state:
1839
+ full_state[key] = shard_value.to(device)
1840
+ else:
1841
+ assert full_state[key] == shard_value
1842
+ else:
1843
+ # Otherwise we have a sharded param state.
1844
+ # If the corresponding full param state hasn't been materialized yet, do so now.
1845
+ assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1846
+ if key not in full_state:
1847
+ log.info(
1848
+ f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1849
+ )
1850
+ full_state[key] = torch.empty(
1851
+ flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1852
+ )
1853
+ full_state_value = full_state[key]
1854
+
1855
+ # Copy over the local shard state to the relevant part of the full parameter state.
1856
+ log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1857
+ replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1858
+
1859
+ # Lastly, clean up the parameter names in param groups.
1860
+ for group in full_optim_state["param_groups"]:
1861
+ group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1862
+
1863
+ return full_model_state, full_optim_state, trainer_state
1864
+
1865
+ def _get_state_dict_path(
1866
+ self,
1867
+ load_path: PathOrStr,
1868
+ state_dict_type: str,
1869
+ rank: int,
1870
+ *,
1871
+ local_cache: Optional[PathOrStr] = None,
1872
+ progress=None,
1873
+ ) -> Tuple[int, Path]:
1874
+ fname = f"{state_dict_type}/rank{rank}.pt"
1875
+ return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1876
+
1877
+ def _gather_state_dict_paths(
1878
+ self,
1879
+ load_path: PathOrStr,
1880
+ state_dict_type: str,
1881
+ world_size: int,
1882
+ *,
1883
+ local_cache: Optional[PathOrStr] = None,
1884
+ ) -> List[Path]:
1885
+ progress = get_progress_bar()
1886
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1887
+ futures = []
1888
+ for rank in range(world_size):
1889
+ future = executor.submit(
1890
+ self._get_state_dict_path,
1891
+ load_path,
1892
+ state_dict_type,
1893
+ rank,
1894
+ local_cache=local_cache,
1895
+ progress=progress,
1896
+ )
1897
+ futures.append(future)
1898
+
1899
+ results: Dict[int, Path] = {}
1900
+ for future in as_completed(futures):
1901
+ rank, path = future.result()
1902
+ results[rank] = path
1903
+
1904
+ return [results[rank] for rank in range(world_size)]
1905
+
1906
+
1907
+ class OlmoCoreCheckpointer(Checkpointer):
1908
+ def save_checkpoint(
1909
+ self,
1910
+ dir: PathOrStr,
1911
+ dist_model: nn.Module,
1912
+ optim: Optimizer,
1913
+ trainer_state: Dict[str, Any],
1914
+ *,
1915
+ upload_to: Optional[str] = None,
1916
+ ) -> None:
1917
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1918
+ save_model_and_optim_state,
1919
+ )
1920
+
1921
+ with self._temporary_wd(dir) as checkpoint_dir:
1922
+ log.info("Saving model and optim state...")
1923
+ if get_fs_local_rank() == 0:
1924
+ (checkpoint_dir / "model").mkdir(exist_ok=True, parents=True)
1925
+ (checkpoint_dir / "optim").mkdir(exist_ok=True, parents=True)
1926
+ (checkpoint_dir / "train").mkdir(exist_ok=True, parents=True)
1927
+
1928
+ wait_for(
1929
+ lambda: (checkpoint_dir / "model").exists(), "Waiting for checkpoint model directory", timeout=10.0
1930
+ )
1931
+ wait_for(
1932
+ lambda: (checkpoint_dir / "optim").exists(), "Waiting for checkpoint optim directory", timeout=10.0
1933
+ )
1934
+ wait_for(
1935
+ lambda: (checkpoint_dir / "train").exists(), "Waiting for checkpoint train directory", timeout=10.0
1936
+ )
1937
+
1938
+ local_files_created = save_model_and_optim_state(checkpoint_dir, dist_model, optim)
1939
+ if upload_to is not None:
1940
+ for path in local_files_created:
1941
+ path = Path(path)
1942
+ upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
1943
+ log.info(f"Uploading {path} to {upload_target}...")
1944
+ upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)
1945
+
1946
+ log.info("Saving trainer state...")
1947
+ save_state_dict(
1948
+ checkpoint_dir,
1949
+ f"train/rank{get_global_rank()}.pt",
1950
+ trainer_state,
1951
+ upload_to=upload_to,
1952
+ save_overwrite=self.cfg.save_overwrite,
1953
+ )
1954
+
1955
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1956
+
1957
+ def restore_checkpoint(
1958
+ self,
1959
+ load_path: PathOrStr,
1960
+ dist_model: nn.Module,
1961
+ optim: Optimizer,
1962
+ *,
1963
+ local_cache: Optional[PathOrStr] = None,
1964
+ load_optimizer_state: bool = True,
1965
+ ) -> Dict[str, Any]:
1966
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1967
+ load_model_and_optim_state,
1968
+ )
1969
+
1970
+ log.info("Loading model and optim state...")
1971
+ load_model_and_optim_state(load_path, dist_model, optim if load_optimizer_state else None)
1972
+
1973
+ log.info("Loading trainer state...")
1974
+ try:
1975
+ trainer_state = load_state_dict(
1976
+ load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
1977
+ )
1978
+ except FileNotFoundError:
1979
+ # Fall back to rank 0 train state.
1980
+ # This can happen when we're restoring a checkpoint with a different world size.
1981
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1982
+
1983
+ barrier()
1984
+ return trainer_state
1985
+
1986
+ def unshard_checkpoint(
1987
+ self,
1988
+ load_path: PathOrStr,
1989
+ *,
1990
+ local_cache: Optional[PathOrStr] = None,
1991
+ load_optimizer_state: bool = True,
1992
+ load_trainer_state: bool = True,
1993
+ device: Optional[torch.device] = None,
1994
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1995
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1996
+ unshard_model_state,
1997
+ unshard_optim_state,
1998
+ )
1999
+
2000
+ model_state = unshard_model_state(load_path, device=device)
2001
+ optim_state: Optional[Dict[str, Any]] = None
2002
+ train_state: Optional[Dict[str, Any]] = None
2003
+ if load_optimizer_state:
2004
+ optim_state = cast(Dict[str, Any], unshard_optim_state(load_path, device=device))
2005
+ if load_trainer_state:
2006
+ train_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
2007
+ return model_state, optim_state, train_state
2008
+
2009
+
2010
+ def build_sharded_checkpointer(
2011
+ cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None, use_shared_mem_impl: bool = False
2012
+ ) -> Checkpointer:
2013
+ name = name or cfg.sharded_checkpointer
2014
+ if name == ShardedCheckpointerType.torch_new:
2015
+ return TorchNewStyleShardedCheckpointer(cfg)
2016
+ elif name == ShardedCheckpointerType.torch_legacy:
2017
+ return TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=use_shared_mem_impl)
2018
+ elif name == ShardedCheckpointerType.local:
2019
+ return LocalShardedCheckpointer(cfg)
2020
+ elif name == ShardedCheckpointerType.olmo_core:
2021
+ return OlmoCoreCheckpointer(cfg)
2022
+ else:
2023
+ raise NotImplementedError(name)
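For reference, a minimal sketch of driving these checkpointers offline to rebuild a full state dict from a `local`-sharded checkpoint. The package layout and paths below are assumptions for illustration, not part of this upload.

import torch
from olmo.config import TrainConfig, ShardedCheckpointerType   # assumed module paths
from olmo.checkpoint import build_sharded_checkpointer

# Load the training config saved alongside the checkpoint (path is hypothetical).
cfg = TrainConfig.load("run_dir/step1000/config.yaml", validate_paths=False)
checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)

# Reconstruct the full, unsharded model weights on CPU; skip optimizer state to save memory.
model_state, _, trainer_state = checkpointer.unshard_checkpoint(
    "run_dir/step1000",
    load_optimizer_state=False,
    device=torch.device("cpu"),
)
torch.save(model_state, "model_unsharded.pt")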
config.json CHANGED
@@ -1,12 +1,10 @@
 {
-  "_name_or_path": "/nfs100/dongyh/FANformer-1B",
   "activation_type": "swiglu",
   "alibi": false,
   "alibi_bias_max": 8.0,
   "architectures": [
     "OLMoForCausalLM"
   ],
-  "att_nolinear": false,
   "attention_activation": null,
   "attention_dropout": 0.0,
   "attention_layer_norm": false,
@@ -25,7 +23,6 @@
   "embedding_layer_norm": false,
   "embedding_size": 50304,
   "eos_token_id": 50279,
-  "ffn_activation": null,
   "flash_attention": true,
   "include_bias": false,
   "init_cutoff_factor": null,
@@ -55,17 +52,9 @@
   "rope_theta": 10000,
   "scale_emb_init": false,
   "scale_logits": false,
-  "torch_dtype": "float32",
-  "transformers_version": "4.49.0",
-  "use_A": false,
-  "use_ATF": true,
+  "transformers_version": "4.46.0",
   "use_cache": true,
-  "use_fpn": false,
-  "use_fpneq": false,
-  "use_fpnnow": false,
-  "use_fpnpn": false,
-  "use_mod": false,
-  "use_mod_ffn": 0,
+  "use_ATF": true,
   "vocab_size": 50280,
   "weight_tying": true
 }
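As a quick, hypothetical sanity check of the change above: the local `_name_or_path` and the unused experimental flags are gone from the cleaned config.json, while `use_ATF` is kept.

import json

with open("config.json") as f:
    cfg = json.load(f)

assert "_name_or_path" not in cfg and "use_mod" not in cfg and "use_fpn" not in cfg
assert cfg["use_ATF"] is True
assert cfg["transformers_version"] == "4.46.0"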
config.py ADDED
@@ -0,0 +1,1371 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import asdict, dataclass, field
5
+ from glob import glob
6
+ from pathlib import Path
7
+ from typing import (
8
+ Any,
9
+ Dict,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ cast,
18
+ )
19
+
20
+ import numpy as np
21
+ import torch
22
+ from omegaconf import DictConfig, ListConfig
23
+ from omegaconf import OmegaConf as om
24
+ from omegaconf.errors import OmegaConfBaseException
25
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
26
+
27
+ from .aliases import PathOrStr
28
+ from .exceptions import OLMoConfigurationError
29
+ from .util import StrEnum
30
+
31
+ __all__ = [
32
+ "ActivationType",
33
+ "ActivationCheckpointingStrategy",
34
+ "BlockType",
35
+ "LayerNormType",
36
+ "InitFnType",
37
+ "ModelConfig",
38
+ "OptimizerType",
39
+ "OptimizerConfig",
40
+ "SchedulerType",
41
+ "SchedulerConfig",
42
+ "DataConfig",
43
+ "InstanceFilterConfig",
44
+ "EvaluatorConfig",
45
+ "TokenizerConfig",
46
+ "TrainConfig",
47
+ "PaddingDirection",
48
+ "TruncationDirection",
49
+ "SpeedMonitorConfig",
50
+ "WandbConfig",
51
+ "CompilerConfig",
52
+ "WandbConfig",
53
+ "DDPConfig",
54
+ "DistributedStrategy",
55
+ "DDPGradSyncMode",
56
+ "FSDPPrecision",
57
+ "FSDPWrapStrategy",
58
+ "FSDPConfig",
59
+ "SingleGPUConfig",
60
+ "CheckpointType",
61
+ ]
62
+
63
+ C = TypeVar("C", bound="BaseConfig")
64
+ D = TypeVar("D", bound="DictConfig|ListConfig")
65
+
66
+
67
+ class BaseConfig:
68
+ @classmethod
69
+ def _register_resolvers(cls, validate_paths: bool = True):
70
+ # Expands path globs into a list.
71
+ def path_glob(*paths) -> List[str]:
72
+ out = []
73
+ for path in paths:
74
+ matches = sorted(glob(path))
75
+ if not matches and validate_paths:
76
+ raise FileNotFoundError(f"{path} does not match any files or dirs")
77
+ out.extend(matches)
78
+ return out
79
+
80
+ # Chooses the first path in the arguments that exists.
81
+ def path_choose(*paths) -> str:
82
+ from .util import is_url
83
+
84
+ for path in paths:
85
+ if is_url(path) or Path(path).exists():
86
+ return path
87
+ if validate_paths:
88
+ raise FileNotFoundError(", ".join(paths))
89
+ else:
90
+ return ""
91
+
92
+ # Finds the latest checkpoint in a folder.
93
+ def path_last_checkpoint(path) -> str:
94
+ from .util import find_latest_checkpoint
95
+
96
+ latest_checkpoint = find_latest_checkpoint(path)
97
+ if latest_checkpoint is None:
98
+ if validate_paths:
99
+ raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
100
+ else:
101
+ return ""
102
+ else:
103
+ return str(latest_checkpoint)
104
+
105
+ om.register_new_resolver("path.glob", path_glob, replace=True)
106
+ om.register_new_resolver("path.choose", path_choose, replace=True)
107
+ om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
108
+
109
+ @classmethod
110
+ def update_legacy_settings(cls, config: D) -> D:
111
+ """
112
+ Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
113
+ """
114
+ return config
115
+
116
+ @classmethod
117
+ def new(cls: Type[C], **kwargs) -> C:
118
+ cls._register_resolvers()
119
+ conf = om.structured(cls)
120
+ try:
121
+ if kwargs:
122
+ conf = om.merge(conf, kwargs)
123
+ return cast(C, om.to_object(conf))
124
+ except OmegaConfBaseException as e:
125
+ raise OLMoConfigurationError(str(e))
126
+
127
+ @classmethod
128
+ def load(
129
+ cls: Type[C],
130
+ path: PathOrStr,
131
+ overrides: Optional[List[str]] = None,
132
+ key: Optional[str] = None,
133
+ validate_paths: bool = True,
134
+ ) -> C:
135
+ """Load from a YAML file."""
136
+ cls._register_resolvers(validate_paths=validate_paths)
137
+ schema = om.structured(cls)
138
+ try:
139
+ raw = om.load(str(path))
140
+ if key is not None:
141
+ raw = raw[key] # type: ignore
142
+ raw = cls.update_legacy_settings(raw)
143
+ conf = om.merge(schema, raw)
144
+ if overrides:
145
+ conf = om.merge(conf, om.from_dotlist(overrides))
146
+ return cast(C, om.to_object(conf))
147
+ except OmegaConfBaseException as e:
148
+ raise OLMoConfigurationError(str(e))
149
+
150
+ def save(self, path: PathOrStr) -> None:
151
+ """Save to a YAML file."""
152
+ om.save(config=self, f=str(path))
153
+
154
+ def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
155
+ out = asdict(self) # type: ignore
156
+ if exclude is not None:
157
+ for name in exclude:
158
+ if name in out:
159
+ del out[name]
160
+ return out
161
+
162
+ def update_with(self, **kwargs):
163
+ result = deepcopy(self)
164
+ for key, value in kwargs.items():
165
+ setattr(result, key, value)
166
+ return result
167
+
168
+
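A short, hypothetical usage sketch of the BaseConfig machinery above; the YAML path and dotlist overrides are made up, and TrainConfig is defined later in this file.

cfg = TrainConfig.load(
    "configs/my-run.yaml",
    overrides=["optimizer.learning_rate=2e-4", "model.flash_attention=true"],
    validate_paths=False,   # skip glob/path resolution when just inspecting a config
)
cfg.save("run_dir/config.yaml")
resumed = cfg.update_with(run_name="my-run-2")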
169
+ class LayerNormType(StrEnum):
170
+ default = "default"
171
+ """
172
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
173
+ """
174
+
175
+ low_precision = "low_precision"
176
+ """
177
+ A low-precision version of the default LayerNorm.
178
+ """
179
+
180
+ rms = "rms"
181
+ """
182
+ An RMSNorm implementation. When using ``torch.compile`` this is
183
+ probably the fastest implementation.
184
+ """
185
+
186
+
187
+ class ActivationType(StrEnum):
188
+ gelu = "gelu"
189
+ relu = "relu"
190
+ swiglu = "swiglu"
191
+
192
+
193
+ class BlockType(StrEnum):
194
+ sequential = "sequential"
195
+
196
+ llama = "llama"
197
+ """
198
+ A block similar to the sequential block with slightly different
199
+ implementations of operations like attention to imitate the behavior of Llama.
200
+ """
201
+
202
+
203
+ class InitFnType(StrEnum):
204
+ mitchell = "mitchell"
205
+ """
206
+ The strategy suggested to us by Mitchell Wortsman from UW.
207
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
208
+ on the size of the weights as well as the depth of the layer.
209
+ """
210
+
211
+ normal = "normal"
212
+ """
213
+ All weights are initialized from the same normal distribution.
214
+ """
215
+
216
+ kaiming_normal = "kaiming_normal"
217
+ """
218
+ All weights are initialized with the Kaiming method from a normal distribution.
219
+ Note this currently won't work with FSDP.
220
+ """
221
+
222
+ fan_in = "fan_in"
223
+ """
224
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
225
+ is the input dimensionality of the kernel.
226
+ """
227
+
228
+ full_megatron = "full_megatron"
229
+ """
230
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
231
+ """
232
+
233
+
234
+ @dataclass
235
+ class ModelConfig(BaseConfig):
236
+ """
237
+ OLMo (model) configuration.
238
+ """
239
+
240
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
241
+
242
+ d_model: int = 768
243
+ """
244
+ The hidden size of the model.
245
+ """
246
+
247
+ n_heads: int = 12
248
+ """
249
+ The number of self-attention heads.
250
+ """
251
+
252
+ n_kv_heads: Optional[int] = None
253
+ """
254
+ The number of heads to use for keys and values. Defaults to `n_heads`.
255
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
256
+ Set this to 1 for multi-query attention.
257
+ Set it to some in-between value for Llama2-style grouped query attention.
258
+ """
259
+
260
+ clip_qkv: Optional[float] = None
261
+ """
262
+ Clip QKV to this value when set.
263
+ """
264
+
265
+ n_layers: int = 12
266
+ """
267
+ The number of layers/blocks.
268
+ """
269
+
270
+ mlp_ratio: int = 4
271
+ """
272
+ The ratio of the inner MLP dimensionality to ``d_model``.
273
+ This is only used when ``mlp_hidden_size`` is not set.
274
+ """
275
+
276
+ mlp_hidden_size: Optional[int] = None
277
+ """
278
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
279
+ """
280
+
281
+ activation_type: ActivationType = ActivationType.swiglu
282
+ """
283
+ The activation function to use within the MLP layers.
284
+ """
285
+
286
+ block_type: BlockType = BlockType.sequential
287
+ """
288
+ The transformer block implementation.
289
+ """
290
+
291
+ block_group_size: int = 1
292
+ """
293
+ The number of blocks to group together into a single parent block.
294
+ This has no effect on the number of parameters in the model and is only used to wrap groups
295
+ of blocks together with a single FSDP wrapper during training.
296
+ """
297
+
298
+ alibi: bool = False
299
+ """
300
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
301
+ """
302
+
303
+ alibi_bias_max: float = 8.0
304
+ """
305
+ Maximum absolute value of ALiBi bias.
306
+ """
307
+
308
+ rope: bool = False
309
+ """
310
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
311
+ """
312
+
313
+ rope_full_precision: bool = True
314
+ """
315
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
316
+ apply RoPE at the precision of the input.
317
+ """
318
+
319
+ rope_theta: int = 10_000
320
+ """
321
+ The theta setting for RoPE.
322
+ """
323
+
324
+ flash_attention: bool = False
325
+ """
326
+ If ``True``, use ``FlashAttention``.
327
+ """
328
+
329
+ attention_dropout: float = 0.1
330
+ """
331
+ The dropout probability within the attention modules.
332
+ """
333
+
334
+ multi_query_attention: Optional[bool] = None
335
+ """
336
+ Deprecated. Use n_kv_heads instead.
337
+ """
338
+
339
+ attention_layer_norm: bool = False
340
+ """
341
+ Apply layer norm to the keys and queries within the attention mechanism.
342
+ This can help stabilize training.
343
+ """
344
+
345
+ residual_dropout: float = 0.1
346
+ """
347
+ The dropout probability for the MLP and attention output within each block.
348
+ """
349
+
350
+ embedding_dropout: float = 0.1
351
+ """
352
+ The dropout probability for embeddings.
353
+ """
354
+
355
+ embedding_layer_norm: bool = False
356
+ """
357
+ Apply layer norm directly to the embeddings.
358
+ """
359
+
360
+ layer_norm_type: LayerNormType = LayerNormType.default
361
+ """
362
+ The layernorm implementation to use.
363
+ """
364
+
365
+ layer_norm_with_affine: bool = True
366
+ """
367
+ Whether to include bias and weight parameters for the layer norms.
368
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
369
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
370
+ to ``False``.
371
+ """
372
+
373
+ layer_norm_eps: float = 1e-05
374
+
375
+ attention_layer_norm_with_affine: bool = True
376
+ """
377
+ Toggle affine transform for the QK norms.
378
+ """
379
+
380
+ max_sequence_length: int = 1024
381
+ """
382
+ The maximum input sequence length supported by the model.
383
+ """
384
+
385
+ include_bias: bool = True
386
+ """
387
+ Whether or not to include bias parameters in linear layers.
388
+ In PaLM, they got rid of all bias terms because they found that large
389
+ models tend to have near 0 bias terms anyway.
390
+ """
391
+
392
+ bias_for_layer_norm: Optional[bool] = None
393
+ """
394
+ Whether or not to include bias parameters in layer norm.
395
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
396
+ layer norm.
397
+ When this is None (the default), it inherits the setting from include_bias.
398
+ """
399
+
400
+ scale_logits: bool = False
401
+ """
402
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
403
+ """
404
+
405
+ vocab_size: int = 50257
406
+ """
407
+ Vocabulary size of the model.
408
+ """
409
+
410
+ embedding_size: Optional[int] = 50304
411
+ """
412
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
413
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
414
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
415
+ substantially.
416
+ """
417
+
418
+ weight_tying: bool = True
419
+ """
420
+ Whether to tie output linear weights to the input embedding.
421
+ """
422
+
423
+ eos_token_id: int = 50256
424
+ """
425
+ The ID of the end-of-sentence special token.
426
+ """
427
+
428
+ pad_token_id: int = 50256
429
+ """
430
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
431
+ """
432
+
433
+ init_device: Optional[str] = None
434
+ """
435
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
436
+ """
437
+
438
+ init_fn: InitFnType = InitFnType.normal
439
+ """
440
+ The weight initialization strategy.
441
+ """
442
+
443
+ init_std: float = 0.02
444
+ """
445
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
446
+ as "normal".
447
+ """
448
+
449
+ init_cutoff_factor: Optional[float] = None
450
+ """
451
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
452
+ as "normal". Setting this to None means values are not cutoff.
453
+ """
454
+
455
+ precision: Optional[str] = None
456
+ """
457
+ Precision used to train/evaluate with. You shouldn't set this directly.
458
+ See :data:`TrainConfig.precision` instead.
459
+ """
460
+
461
+ scale_emb_init: bool = False
462
+ """
463
+ If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization.
464
+ Currently this is only used with `full_megatron` init when ``emb_init_std`` is unset.
465
+ """
466
+
467
+ emb_init_std: Optional[float] = None
468
+ """
469
+ Override the standard deviation to use when initializing the embedding weights.
470
+ """
471
+
472
+ norm_after: bool = False
473
+ """
474
+ Apply norm after the attention/feedforward layers rather than before, as introduced in the Swin transformer paper (Liu et al).
475
+ """
476
+
477
+ use_ATF: Optional[bool] = False
478
+
479
+ p_ratio: float = 0.25
480
+
481
+ attention_activation: Optional[str] = None
482
+
483
+ @property
484
+ def effective_n_kv_heads(self) -> int:
485
+ if self.n_kv_heads is None:
486
+ if self.multi_query_attention is True:
487
+ return 1
488
+ else:
489
+ return self.n_heads
490
+ else:
491
+ if self.multi_query_attention is None:
492
+ return self.n_kv_heads
493
+ if self.multi_query_attention:
494
+ n_kv_heads_should_be = 1
495
+ else:
496
+ n_kv_heads_should_be = self.n_heads
497
+ if self.n_kv_heads == n_kv_heads_should_be:
498
+ return n_kv_heads_should_be
499
+ else:
500
+ raise OLMoConfigurationError(
501
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
502
+ )
503
+
504
+
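A few illustrative values for how `effective_n_kv_heads` resolves the deprecated `multi_query_attention` flag against `n_kv_heads` (these follow directly from the property above):

ModelConfig(n_heads=16).effective_n_kv_heads                               # 16 -> standard multi-head attention
ModelConfig(n_heads=16, multi_query_attention=True).effective_n_kv_heads   # 1  -> multi-query attention
ModelConfig(n_heads=16, n_kv_heads=4).effective_n_kv_heads                 # 4  -> grouped-query attention
ModelConfig(n_heads=16, n_kv_heads=4, multi_query_attention=True).effective_n_kv_heads  # raises OLMoConfigurationError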
505
+ class OptimizerType(StrEnum):
506
+ lionw = "lionw"
507
+ adamw = "adamw"
508
+
509
+
510
+ @dataclass
511
+ class OptimizerConfig(BaseConfig):
512
+ name: OptimizerType = OptimizerType.lionw
513
+ learning_rate: float = 1.0e-4
514
+ weight_decay: float = 0.01
515
+ betas: Tuple[float, float] = (0.9, 0.95)
516
+ eps: float = 1e-5
517
+
518
+ no_decay_norm_and_bias: Optional[bool] = None
519
+ """
520
+ Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
521
+ """
522
+
523
+ selective_updates: bool = False
524
+ """
525
+ If ``True``, optimizer parameter and state updates are skipped when the corresponding gradient is 0.
526
+ """
527
+
528
+ decay_norm_and_bias: bool = False
529
+ decay_embeddings: bool = False
530
+ metrics_log_interval: Optional[int] = None
531
+ """
532
+ The interval with which to collect and log detailed parameter-specific metrics.
533
+ This only applies when logging to W&B, since these metrics won't be logged to the console.
534
+ If not set, defaults to the wandb `log_interval`.
535
+ """
536
+
537
+ record_update_metrics: bool = False
538
+ """
539
+ Whether to record detailed metrics about the optimizer's parameter updates, like the norm and max
540
+ of the update with AdamW.
541
+ """
542
+
543
+ def __post_init__(self):
544
+ self.betas = tuple(self.betas) # type: ignore[assignment]
545
+
546
+ @classmethod
547
+ def update_legacy_settings(cls, config: D) -> D:
548
+ new_config = config.copy()
549
+ if om.is_dict(new_config):
550
+ assert isinstance(new_config, DictConfig)
551
+
552
+ if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
553
+ new_config.name = "lionw"
554
+ if hasattr(new_config, "eps"):
555
+ del new_config.eps
556
+
557
+ return new_config
558
+
559
+
560
+ class SchedulerType(StrEnum):
561
+ cosine_with_warmup = "cosine_with_warmup"
562
+ linear_with_warmup = "linear_with_warmup"
563
+ inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
564
+ max_scheduler = "max_scheduler"
565
+ constant = "constant"
566
+ cosine_linear_envelope = "cosine_linear_envelope"
567
+ constant_with_warmup = "constant_with_warmup"
568
+
569
+
570
+ class SchedulerUnits(StrEnum):
571
+ steps = "steps"
572
+ tokens = "tokens"
573
+
574
+
575
+ @dataclass
576
+ class SchedulerConfig(BaseConfig):
577
+ name: SchedulerType = SchedulerType.cosine_with_warmup
578
+ units: SchedulerUnits = SchedulerUnits.steps
579
+ t_warmup: Union[int, float] = 100
580
+ t_max: Optional[Union[int, float]] = None
581
+ alpha_f: float = 0.1
582
+
583
+ grad_clip_warmup_steps: Optional[Union[int, float]] = None
584
+ """
585
+ The warmup period for which the max grad norm (or norm ratio) will be set to its
586
+ warmup value of `max_grad_norm * grad_clip_warmup_factor`.
587
+ """
588
+
589
+ grad_clip_warmup_factor: Optional[float] = None
590
+ """
591
+ The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
592
+ vs after the warmup period.
593
+ """
594
+
595
+ warmup_min_lr: Optional[float] = None
596
+ """
597
+ The starting LR during the warmup period. If not set this defaults to 10% of
598
+ the target LR.
599
+ """
600
+
601
+
602
+ class PaddingDirection(StrEnum):
603
+ right = "right"
604
+ left = "left"
605
+
606
+
607
+ @dataclass
608
+ class InstanceFilterConfig(BaseConfig):
609
+ repetition_max_period: int = 13
610
+ repetition_min_period: int = 1
611
+ repetition_max_count: int = 32
612
+
613
+
614
+ @dataclass
615
+ class DataConfig(BaseConfig):
616
+ paths: Optional[List[str]] = None
617
+ memmap_dtype: str = "uint16"
618
+ datasets: Optional[Dict[str, List[str]]] = None
619
+ label_mask_paths: Optional[List[str]] = None
620
+ pad_direction: PaddingDirection = PaddingDirection.right
621
+ generate_attention_mask: bool = False
622
+ generate_doc_lengths: bool = False
623
+ num_workers: int = 0
624
+ drop_last: bool = False
625
+ pin_memory: bool = False
626
+ prefetch_factor: Optional[int] = None
627
+ persistent_workers: bool = False
628
+ timeout: int = 0
629
+ seed: Optional[int] = None
630
+ instance_filter: Optional[InstanceFilterConfig] = None
631
+ custom_dataset: Optional[CustomDatasetConfig] = None
632
+
633
+ @property
634
+ def effective_memmap_dtype(self):
635
+ try:
636
+ # getattr will check this is part of numpy module, while np.dtype will check
637
+ # if this is a valid numpy dtype.
638
+ np.dtype(dtype := getattr(np, self.memmap_dtype))
639
+ except (AttributeError, TypeError) as e:
640
+ raise TypeError(f"Value {self.memmap_dtype} is not a valid numpy type") from e
641
+ return dtype
642
+
643
+
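Illustrative behavior of `effective_memmap_dtype` (the dtype strings below are made-up examples):

DataConfig(memmap_dtype="uint16").effective_memmap_dtype   # numpy.uint16 (the default)
DataConfig(memmap_dtype="int64").effective_memmap_dtype    # numpy.int64
DataConfig(memmap_dtype="float8").effective_memmap_dtype   # raises TypeError: not a valid numpy type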
644
+ @dataclass
645
+ class CustomDatasetCollatorConfig(BaseConfig):
646
+ input_id_field: str = "input_ids" #: The field in the dataset items that contains the input token IDs.
647
+ attention_mask_field: Optional[str] = None #: The field in the dataset items that contains the attention mask.
648
+ attention_bias_field: Optional[str] = None #: The field in the dataset items that contains the attention bias.
649
+ label_mask_field: Optional[str] = None #: The field in the dataset items that contains the label mask.
650
+ index_field: Optional[str] = None #: The field in the dataset items that contains the index of the item.
651
+ instance_mask_field: Optional[str] = None #: The field in the dataset items that contains the instance mask.
652
+ doc_lens_field: Optional[str] = None #: The field in the dataset items that contains the document lengths.
653
+ metadata_field: Optional[str] = None #: The field in the dataset items that contains the metadata.
654
+
655
+
656
+ @dataclass
657
+ class CustomDatasetConfig(BaseConfig):
658
+ name: str #: The name of the custom dataset class or function that will be used to load the dataset.
659
+ module: Optional[
660
+ str
661
+ ] = None #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
662
+ args: Optional[Dict[str, Any]] = None #: The arguments to pass to the custom dataset class or function
663
+ collate_fn: Optional[
664
+ str
665
+ ] = None #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
666
+ token_field: Optional[str] = None #: The field in the dataset items that contains the tokenized text.
667
+ collate_config: Optional[CustomDatasetCollatorConfig] = field(
668
+ default_factory=CustomDatasetCollatorConfig
669
+ ) #: The configuration for the collate function to use for the custom dataset.
670
+
671
+
672
+ class EvaluatorType(StrEnum):
673
+ downstream = "downstream"
674
+ lm = "lm"
675
+
676
+
677
+ @dataclass
678
+ class EvaluatorConfig(BaseConfig):
679
+ label: str
680
+ type: EvaluatorType = EvaluatorType.lm
681
+ data: DataConfig = field(default_factory=DataConfig)
682
+ device_eval_batch_size: Optional[int] = None
683
+ subset_num_batches: Optional[int] = None
684
+
685
+
686
+ class TruncationDirection(StrEnum):
687
+ right = "right"
688
+ left = "left"
689
+
690
+
691
+ @dataclass
692
+ class TokenizerConfig(BaseConfig):
693
+ identifier: str = "gpt2"
694
+ truncate_direction: TruncationDirection = TruncationDirection.right
695
+
696
+
697
+ @dataclass
698
+ class WandbConfig(BaseConfig):
699
+ project: Optional[str] = None
700
+ entity: Optional[str] = "ai2-llm"
701
+ group: Optional[str] = None
702
+ name: Optional[str] = None
703
+ tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
704
+ log_artifacts: bool = False
705
+ rank_zero_only: bool = True
706
+ log_interval: int = 1
707
+
708
+
709
+ @dataclass
710
+ class SpeedMonitorConfig(BaseConfig):
711
+ window_size: int = 100
712
+ gpu_flops_available: Optional[Union[float, int]] = None
713
+
714
+
715
+ @dataclass
716
+ class CompilerConfig(BaseConfig):
717
+ mode: Optional[str] = None
718
+ """
719
+ The mode to compile the model in. At the moment this can be "default",
720
+ "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
721
+ (the fastest for larger models, but takes a long time to compile).
722
+ """
723
+
724
+ fullgraph: bool = False
725
+ """
726
+ Whether it is OK to break model into several subgraphs when compiling.
727
+ Note that this is not compatible with FSDP.
728
+ """
729
+
730
+ backend: str = "inductor"
731
+ """
732
+ The backend to use.
733
+ """
734
+
735
+ dynamic: Optional[bool] = None
736
+ """
737
+ From the torch docs:
738
+
739
+ Use dynamic shape tracing. When this is True, we will up-front attempt to generate a kernel that is as dynamic
740
+ as possible to avoid recompilations when sizes change. This may not always work as some
741
+ operations/optimizations will force specialization; use TORCH_LOGS=dynamic to debug overspecialization. When
742
+ this is False, we will NEVER generate dynamic kernels, we will always specialize. By default (None), we
743
+ automatically detect if dynamism has occurred and compile a more dynamic kernel upon recompile.
744
+ """
745
+
746
+
747
+ class DistributedStrategy(StrEnum):
748
+ ddp = "ddp"
749
+ """
750
+ Wrap OLMo in torch.nn.parallel.DistributedDataParallel to train across ranks.
751
+ """
752
+
753
+ fsdp = "fsdp"
754
+ """
755
+ Wrap OLMo in torch.distributed.fsdp.FullyShardedDataParallel to train across ranks.
756
+ """
757
+
758
+ single = "single"
759
+ """
760
+ Train on a single device, i.e., do not distribute training. For development and debugging.
761
+ """
762
+
763
+
764
+ class DDPGradSyncMode(StrEnum):
765
+ batch = "batch"
766
+ """
767
+ Synchronize gradients after computation at each bucket only at the last micro-batch.
768
+ This is slightly faster than gradient syncs across each micro-batch but will consume more memory.
769
+ This mode can only be used when `find_unused_params` is set to False.
770
+ """
771
+
772
+ micro_batch = "micro_batch"
773
+ """
774
+ Synchronize gradients after computation at each bucket per micro-batch.
775
+ This will be slightly slower than gradient sync at the last micro-batch, but will consume less memory.
776
+ This mode can be used with either setting of `find_unused_params`, but it is specifically recommended when `find_unused_params` is
777
+ set to True, to prevent errors.
778
+ """
779
+
780
+
781
+ @dataclass
782
+ class DDPConfig(BaseConfig):
783
+ grad_sync_mode: DDPGradSyncMode = DDPGradSyncMode.batch
784
+ """
785
+ Gradient sync mode for DDP
786
+
787
+ Note: When `find_unused_params` is set, set `grad_sync_mode` to `micro_batch` as different micro-batches might activate
788
+ different parts of the model, e.g. MoEs.
789
+ """
790
+
791
+ find_unused_params: bool = False
792
+ """
793
+ (from torch documentation)
794
+
795
+ This mode allows running backward on a subgraph of the model, and DDP finds out which parameters
796
+ are involved in the backward pass by traversing the autograd graph from the model output and marking
797
+ all unused parameters as ready for reduction. Note that traversing the autograd graph introduces extra overheads,
798
+ so applications should only set find_unused_parameters to True when necessary.
799
+ """
800
+
801
+
802
+ class FSDPWrapStrategy(StrEnum):
803
+ by_block = "by_block"
804
+ """
805
+ Wrap each OLMo block with its own FSDP instance.
806
+ """
807
+
808
+ by_block_and_size = "by_block_and_size"
809
+ """
810
+ Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
811
+ """
812
+
813
+ by_block_group = "by_block_group"
814
+ """
815
+ Wrap each block group together into its own FSDP instance.
816
+ This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
817
+ """
818
+
819
+ by_block_group_and_size = "by_block_group_and_size"
820
+ """
821
+ Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
822
+ """
823
+
824
+ size_based = "size_based"
825
+ """
826
+ Use PyTorch's default size-based auto wrap policy.
827
+ """
828
+
829
+ one_in_two = "one_in_two"
830
+ one_in_three = "one_in_three"
831
+ one_in_four = "one_in_four"
832
+ one_in_five = "one_in_five"
833
+
834
+
835
+ class FSDPPrecision(StrEnum):
836
+ pure = "pure"
837
+ """
838
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
839
+ and ``buffer_dtype`` all set to the autocast precision data type.
840
+ """
841
+
842
+ mixed = "mixed"
843
+ """
844
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
845
+ set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
846
+ """
847
+
848
+
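Per the two docstrings above, and assuming an autocast precision of torch.bfloat16, the options roughly correspond to the following MixedPrecision settings (a sketch, not code from this file; MixedPrecision and torch are already imported at the top of this module):

pure_mp  = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)
mixed_mp = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.bfloat16)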
849
+ @dataclass
850
+ class FSDPConfig(BaseConfig):
851
+ use_orig_params: bool = True
852
+ """
853
+ This must be ``True`` if using ``compile`` or if you want to track the parameter norm during training.
854
+ """
855
+
856
+ sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
857
+
858
+ wrapping_strategy: Optional[FSDPWrapStrategy] = None
859
+ """
860
+ The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
861
+ FSDP instance.
862
+ """
863
+
864
+ precision: Optional[FSDPPrecision] = FSDPPrecision.pure
865
+
866
+ hybrid_sharding_num_model_replicas: Optional[int] = None
867
+ """
868
+ The number of model instances, when using a hybrid sharding strategy.
869
+ If not ``None``, this must divide the total number of nodes. If ``None``, the default,
870
+ a model instance is used per node (as determined by ``get_world_size() // get_local_world_size()``).
871
+ PyTorch's default HSDP behavior matches this default behavior.
872
+ """
873
+
874
+
875
+ @dataclass
876
+ class SingleGPUConfig(BaseConfig):
877
+ device: str = "auto"
878
+ """
879
+ Device to run single-device training.
880
+ """
881
+
882
+ def get_device(self):
883
+ if self.device == "auto":
884
+ if torch.backends.mps.is_available():
885
+ return torch.device("mps")
886
+ elif torch.cuda.is_available():
887
+ return torch.device("cuda")
888
+ else:
889
+ return torch.device("cpu")
890
+ elif self.device == "mps" and not torch.backends.mps.is_available():
891
+ raise OLMoConfigurationError("MPS not available.")
892
+ elif self.device == "cuda" and not torch.cuda.is_available():
893
+ raise OLMoConfigurationError("CUDA not available.")
894
+ else:
895
+ return torch.device(self.device)
896
+
897
+
898
+ class CheckpointType(StrEnum):
899
+ sharded = "sharded"
900
+ unsharded = "unsharded"
901
+ sharded_ephemeral = "sharded_ephemeral"
902
+
903
+
904
+ class ShardedCheckpointerType(StrEnum):
905
+ torch_new = "torch_new"
906
+ torch_legacy = "torch_legacy"
907
+ local = "local"
908
+ olmo_core = "olmo_core"
909
+
910
+
911
+ class ActivationCheckpointingStrategy(StrEnum):
912
+ whole_layer = "whole_layer"
913
+ """
914
+ Checkpoint every transformer layer.
915
+ """
916
+
917
+ one_in_two = "one_in_two"
918
+ """
919
+ Checkpoint one in two transformer layers.
920
+ """
921
+
922
+ one_in_three = "one_in_three"
923
+ """
924
+ Checkpoint one in three transformer layers.
925
+ """
926
+
927
+ one_in_four = "one_in_four"
928
+ """
929
+ Checkpoint one in four transformer layers.
930
+ """
931
+
932
+ one_in_eight = "one_in_eight"
933
+ """
934
+ Checkpoint one in eight transformer layers.
935
+ """
936
+
937
+ two_in_three = "two_in_three"
938
+ """
939
+ Checkpoint two out of every three transformer layers.
940
+ """
941
+
942
+ three_in_four = "three_in_four"
943
+ """
944
+ Checkpoint three out of every four transformer layers.
945
+ """
946
+
947
+ fine_grained = "fine_grained"
948
+ """
949
+ Focus checkpointing where recomputation is cheap and the memory savings are largest.
950
+ """
951
+
952
+
953
+ @dataclass
954
+ class TrainConfig(BaseConfig):
955
+ """
956
+ OLMo training configuration.
957
+ """
958
+
959
+ run_name: Optional[str] = None
960
+ """
961
+ The name of the run.
962
+ """
963
+
964
+ seed: int = 6198
965
+ """
966
+ Used to seed all initial RNG states.
967
+ """
968
+
969
+ epoch: Optional[int] = None
970
+ """
971
+ Increment this when starting a new epoch.
972
+ """
973
+
974
+ dry_run: bool = False
975
+ """
976
+ If ``True``, don't actually train.
977
+ """
978
+
979
+ model: ModelConfig = field(default_factory=ModelConfig)
980
+ """
981
+ OLMo Model configuration.
982
+ """
983
+
984
+ optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
985
+ """
986
+ Optimizer configuration.
987
+ """
988
+
989
+ scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
990
+ """
991
+ Learning rate scheduler configuration.
992
+ """
993
+
994
+ data: DataConfig = field(default_factory=DataConfig)
995
+ """
996
+ Training data configuration.
997
+ """
998
+
999
+ restore_dataloader: bool = True
1000
+ """
1001
+ When restarting, restore the data loader to where it left off.
1002
+ If you are restarting in order to train on a different dataset, set this to ``False``.
1003
+ """
1004
+
1005
+ fast_forward_batches: Optional[int] = None
1006
+ """
1007
+ When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
1008
+ This can be useful when restarting due to a loss spike in order to skip the data that
1009
+ corresponded to the spike.
1010
+ """
1011
+
1012
+ evaluators: List[EvaluatorConfig] = field(default_factory=list)
1013
+ """
1014
+ Evaluation configurations.
1015
+ """
1016
+
1017
+ eval_interval: int = 1000
1018
+ """
1019
+ How often (in terms of batches) to run evaluations.
1020
+ """
1021
+
1022
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
1023
+ """
1024
+ Tokenizer configuration.
1025
+ """
1026
+
1027
+ save_folder: str = "./"
1028
+ """
1029
+ The directory to save checkpoints to.
1030
+ """
1031
+
1032
+ remote_save_folder: Optional[str] = None
1033
+ """
1034
+ A folder in a cloud bucket to upload saved checkpoints to.
1035
+ """
1036
+
1037
+ canceled_check_interval: int = 50
1038
+ """
1039
+ How often (in batches) to check if the run has been canceled or reached its time limit.
1040
+ """
1041
+
1042
+ save_interval: Optional[int] = 1000
1043
+ """
1044
+ How often (in terms of steps) to save sharded training state checkpoints.
1045
+ """
1046
+
1047
+ save_interval_unsharded: Optional[int] = None
1048
+ """
1049
+ How often (if at all) to save unsharded training state checkpoint.
1050
+ For large models it can be costly to save these, so it usually makes sense to save
1051
+ these less often than regular (sharded) training checkpoints.
1052
+ """
1053
+
1054
+ save_interval_ephemeral: Optional[int] = None
1055
+ """
1056
+ How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
1057
+ as those saved every `save_interval`, except that only the most recent one of them is kept.
1058
+ This is useful when you want to checkpoint often for restarts in case of failures, but don't
1059
+ want to keep the majority of these checkpoints.
1060
+
1061
+ For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
1062
+ a temporary checkpoint every 100 steps in case your job fails. In that case you would
1063
+ set `save_interval=1000` and `save_interval_ephemeral=100`.
1064
+ """
1065
+
1066
+ save_num_checkpoints_to_keep: int = -1
1067
+ """
1068
+ How many sharded checkpoints to keep.
1069
+ """
1070
+
1071
+ save_num_unsharded_checkpoints_to_keep: int = -1
1072
+ """
1073
+ How many unsharded checkpoints to keep.
1074
+ """
1075
+
1076
+ save_overwrite: bool = False
1077
+ """
1078
+ If ``True``, overwrite any conflicting checkpoint files.
1079
+ """
1080
+
1081
+ force_save_unsharded: bool = False
1082
+ """
1083
+ Save an unsharded checkpoint before training (even during a dry run).
1084
+ Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
1085
+ checkpoint into an unsharded checkpoint.
1086
+ """
1087
+
1088
+ no_pre_train_checkpoint: bool = False
1089
+ """
1090
+ Skip saving pre-train checkpoint.
1091
+ """
1092
+
1093
+ load_path: Optional[str] = None
1094
+ """
1095
+ The path to a training checkpoint to restore/resume from. If not set, then training begins from scratch.
1096
+
1097
+ Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
1098
+ a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
1099
+ For example,
1100
+
1101
+ ```bash
1102
+ --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
1103
+ ```
1104
+
1105
+ If `try_load_latest_save` is set and saved checkpoints exist, then `load_path` will be overridden
1106
+ by the latest saved checkpoint.
1107
+ """
1108
+
1109
+ load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
1110
+ """
1111
+ The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
1112
+ """
1113
+
1114
+ try_load_latest_save: bool = False
1115
+ """
1116
+ If set, then training will be resumed from the latest checkpoint in the local save folder, falling
1117
+ back to the latest checkpoint in the remote save folder if none exists. If there are no checkpoints
1118
+ in the local and remote save folders, then checkpoint loading will fall back to `load_path`.
1119
+ """
1120
+
1121
+ reset_optimizer_state: bool = False
1122
+ """
1123
+ When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
1124
+ We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
1125
+ curve (according to the current learning rate schedule settings), and continues from there.
1126
+ """
1127
+
1128
+ reset_trainer_state: bool = False
1129
+ """
1130
+ When this is set we don't restore the trainer state from a checkpoint.
1131
+ """
1132
+
1133
+ sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
1134
+ """
1135
+ The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
1136
+ """
1137
+
1138
+ new_style_checkpoints: Optional[bool] = None
1139
+ """
1140
+ Deprecated. Use ``sharded_checkpointer`` instead.
1141
+ """
1142
+
1143
+ max_duration: Union[int, str] = 10000
1144
+ """
1145
+ How long to train for.
1146
+
1147
+ If specified without a unit (the default), the units are assumed to be steps.
1148
+ You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
1149
+ 2 trillion tokens.
1150
+ """
1151
+
1152
+ global_train_batch_size: int = 512
1153
+ """
1154
+ The effective global batch size.
1155
+ """
1156
+
1157
+ device_train_batch_size: Optional[int] = None # calculated automatically
1158
+ """
1159
+ Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
1160
+ """
1161
+
1162
+ device_train_microbatch_size: int = 16
1163
+ """
1164
+ The number of instances passed to the model in a single forward-backward pass. You should set
1165
+ this as large as you can based on available GPU memory.
1166
+ """
1167
+
1168
+ device_eval_batch_size: int = 16
1169
+ """
1170
+ The number of evaluation instances passed to the model in a single forward pass on each device.
1171
+ """
1172
+
1173
+ eval_subset_num_batches: int = -1
1174
+ """
1175
+ The number of batches to use for downstream evaluation from each dataset.
1176
+ """
1177
+
1178
+ eval_on_load: bool = False
1179
+ """
1180
+ When resuming from a checkpoint, run the evaluation loop right away.
1181
+ """
1182
+
1183
+ device_train_grad_accum: Optional[int] = None # calculated automatically
1184
+ """
1185
+ Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
1186
+ """
1187
+
1188
+ max_grad_norm: Optional[float] = None
1189
+ """
1190
+ Clip gradient norms to this value if set.
1191
+ """
1192
+
1193
+ max_grad_norm_ratio: Optional[float] = None
1194
+ """
1195
+ If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
1196
+ This takes priority over `max_grad_norm` when set.
1197
+ """
1198
+
1199
+ precision: Optional[str] = None
1200
+ """
1201
+ Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
1202
+ """
1203
+
1204
+ wandb: Optional[WandbConfig] = None
1205
+ """
1206
+ Weights & Biases configuration.
1207
+ """
1208
+
1209
+ speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
1210
+ """
1211
+ Speed monitor configuration.
1212
+ """
1213
+
1214
+ console_log_interval: int = 1
1215
+ """
1216
+ How often to log to the console.
1217
+ """
1218
+
1219
+ gen1_gc_interval: Optional[int] = 1
1220
+ """
1221
+ How often (in steps) to run generation 1 garbage collection.
1222
+ Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it).
1223
+ """
1224
+
1225
+ compile: Optional[CompilerConfig] = None
1226
+ """
1227
+ Settings for compiling the model with ``torch.compile()``.
1228
+ """
1229
+
1230
+ distributed_strategy: Optional[DistributedStrategy] = DistributedStrategy.fsdp
1231
+ """
1232
+ Distributed strategy for the OLMo model (e.g. single GPU, DDP, FSDP).
1233
+ """
1234
+
1235
+ fsdp: Optional[FSDPConfig] = field(default_factory=FSDPConfig)
1236
+ """
1237
+ Fully sharded data parallel settings.
1238
+ """
1239
+
1240
+ ddp: Optional[DDPConfig] = None
1241
+ """
1242
+ DDP settings.
1243
+ """
1244
+
1245
+ single: SingleGPUConfig = field(default_factory=lambda: SingleGPUConfig(device="auto"))
1246
+ """
1247
+ Single device settings for GPU/CPU/MPS. Defaults to auto-detect the best device.
1248
+ """
1249
+
1250
+ softmax_auxiliary_loss: bool = False
1251
+ """
1252
+ If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
1253
+ normalizing term to be close to 0.
1254
+ """
1255
+
1256
+ auxiliary_loss_multiplier: Optional[float] = 1e-4
1257
+ """
1258
+ Used with `softmax_auxiliary_loss`. PaLM uses 1e-4, Chameleon uses 1e-5.
1259
+ """
1260
+
1261
+ time_limit: Optional[float] = None
1262
+ """
1263
+ The maximum amount of time to train for before saving a checkpoint and ending early.
1264
+ """
1265
+
1266
+ extra_steps_after_cancel: int = 10
1267
+ """
1268
+ Under certain conditions when a run is canceled we train for a few extra steps after saving
1269
+ the final checkpoint so that when the run is restarted from the latest checkpoint we have some
1270
+ overlap in metrics.
1271
+ """
1272
+
1273
+ early_stopping_factor: Optional[float] = None
1274
+
1275
+ save_data_indices: bool = True
1276
+ """
1277
+ Save training data indices from each batch for each worker.
1278
+ """
1279
+
1280
+ python_profiling: bool = False
1281
+ """
1282
+ Whether to run the Python profiler on batches 6, 7, and 8.
1283
+ """
1284
+
1285
+ torch_profiling: bool = False
1286
+ """
1287
+ Whether to run the PyTorch profiler on batches 6, 7, and 8.
1288
+ """
1289
+
1290
+ stop_at: Optional[int] = None
1291
+ """
1292
+ Stop at a specific step.
1293
+ """
1294
+
1295
+ stop_after: Optional[int] = None
1296
+ """
1297
+ Stop after a specific number of steps.
1298
+ """
1299
+
1300
+ activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
1301
+ """
1302
+ The activation checkpointing strategy to use.
1303
+ """
1304
+
1305
+ fused_loss: Optional[bool] = None
1306
+ """
1307
+ Whether to use the fused CE loss function from `flash-attn`.
1308
+ """
1309
+
1310
+ hf_datasets_cache_dir: Optional[str] = None
1311
+ """
1312
+ Deprecated, HF datasets are now stored in `olmo_data.hf_datasets`.
1313
+
1314
+ Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
1315
+ """
1316
+
1317
+ module_outputs_save_steps: Optional[List[int]] = None
1318
+ """
1319
+ Outputs of model submodules are saved during the provided steps. Submodule outputs
1320
+ can be compared using `scripts/compare_module_outputs.py`.
1321
+ """
1322
+
1323
+ @property
1324
+ def autocast_precision(self) -> torch.dtype:
1325
+ if self.precision == "amp_bf16":
1326
+ return torch.bfloat16
1327
+ elif self.precision == "amp_fp16":
1328
+ return torch.float16
1329
+ elif self.precision == "fp32":
1330
+ return torch.float32
1331
+ else:
1332
+ raise ValueError(f"Unexpected precision type '{self.precision}'")
1333
+
1334
+ @property
1335
+ def fsdp_precision(self) -> Optional[MixedPrecision]:
1336
+ if self.fsdp is not None:
1337
+ if self.fsdp.precision is None:
1338
+ return None
1339
+ elif self.fsdp.precision == FSDPPrecision.pure:
1340
+ return MixedPrecision(
1341
+ param_dtype=self.autocast_precision,
1342
+ reduce_dtype=self.autocast_precision,
1343
+ buffer_dtype=self.autocast_precision,
1344
+ )
1345
+ elif self.fsdp.precision == FSDPPrecision.mixed:
1346
+ return MixedPrecision(
1347
+ param_dtype=self.autocast_precision,
1348
+ reduce_dtype=torch.float32,
1349
+ buffer_dtype=self.autocast_precision,
1350
+ )
1351
+ else:
1352
+ raise NotImplementedError(f"{self.fsdp.precision}")
1353
+ else:
1354
+ raise ValueError("self.fsdp is None!")
1355
+
1356
+ @classmethod
1357
+ def update_legacy_settings(cls, config: D) -> D:
1358
+ new_config = config.copy()
1359
+ if om.is_dict(new_config):
1360
+ assert isinstance(new_config, DictConfig)
1361
+
1362
+ if hasattr(new_config, "activation_checkpointing"):
1363
+ if new_config.activation_checkpointing is False:
1364
+ new_config.activation_checkpointing = None
1365
+ if new_config.activation_checkpointing is True:
1366
+ new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
1367
+
1368
+ if hasattr(new_config, "optimizer"):
1369
+ new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
1370
+
1371
+ return new_config
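For readers skimming the diff, here is a minimal usage sketch of the configuration classes above. It is illustrative only and not part of the uploaded files; the flat `config` import path and default-constructible sub-configs are assumptions.

```python
import torch

from config import SingleGPUConfig, TrainConfig  # assumed flat module layout

# Build a config from defaults and inspect the derived properties defined above
# (assumes the nested sub-configs are constructible with their defaults).
cfg = TrainConfig(run_name="debug-run", precision="amp_bf16")
assert cfg.autocast_precision == torch.bfloat16

# SingleGPUConfig.get_device() prefers MPS, then CUDA, then falls back to CPU.
print(SingleGPUConfig(device="auto").get_device())
```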
exceptions.py ADDED
@@ -0,0 +1,50 @@
1
+ __all__ = [
2
+ "OLMoError",
3
+ "OLMoConfigurationError",
4
+ "OLMoCliError",
5
+ "OLMoEnvironmentError",
6
+ "OLMoNetworkError",
7
+ "OLMoCheckpointError",
8
+ ]
9
+
10
+
11
+ class OLMoError(Exception):
12
+ """
13
+ Base class for all custom OLMo exceptions.
14
+ """
15
+
16
+
17
+ class OLMoConfigurationError(OLMoError):
18
+ """
19
+ An error with a configuration file.
20
+ """
21
+
22
+
23
+ class OLMoCliError(OLMoError):
24
+ """
25
+ An error from incorrect CLI usage.
26
+ """
27
+
28
+
29
+ class OLMoEnvironmentError(OLMoError):
30
+ """
31
+ An error from incorrect environment variables.
32
+ """
33
+
34
+
35
+ class OLMoNetworkError(OLMoError):
36
+ """
37
+ An error with a network request.
38
+ """
39
+
40
+
41
+ class OLMoCheckpointError(OLMoError):
42
+ """
43
+ An error occurred reading or writing from a checkpoint.
44
+ """
45
+
46
+
47
+ class OLMoThreadError(Exception):
48
+ """
49
+ Raised when a thread fails.
50
+ """
initialization.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch.nn as nn
4
+
5
+ __all__ = ["init_normal"]
6
+
7
+
8
+ def init_normal(
9
+ module: Union[nn.Linear, nn.Embedding],
10
+ std: float,
11
+ init_cutoff_factor: Optional[float] = None,
12
+ ):
13
+ # weights
14
+ if init_cutoff_factor is not None:
15
+ cutoff_value = init_cutoff_factor * std
16
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
17
+ else:
18
+ nn.init.normal_(module.weight, mean=0.0, std=std)
19
+
20
+ # biases
21
+ if isinstance(module, nn.Linear) and module.bias is not None:
22
+ nn.init.zeros_(module.bias)
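A small sketch of what `init_normal` does in the truncated case. This is illustrative only and not part of the uploaded files; the import path is an assumption.

```python
import torch.nn as nn

from initialization import init_normal  # assumed flat module layout

linear = nn.Linear(128, 256)
# With a cutoff factor, weights are drawn from a normal truncated to
# +/- (init_cutoff_factor * std); biases (when present) are always zeroed.
init_normal(linear, std=0.02, init_cutoff_factor=3.0)
assert float(linear.weight.abs().max()) <= 3.0 * 0.02
assert float(linear.bias.abs().sum()) == 0.0
```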
model.py ADDED
@@ -0,0 +1,1959 @@
1
+ """
2
+ Adapted from
3
+ [MosaiclML](https://github.com/mosaicml/examples.git) and
4
+ [minGPT](https://github.com/karpathy/minGPT.git)
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import math
11
+ import sys
12
+ from abc import abstractmethod
13
+ from collections import defaultdict
14
+ from functools import partial
15
+ from typing import (
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ NamedTuple,
21
+ Optional,
22
+ Sequence,
23
+ Set,
24
+ Tuple,
25
+ cast,
26
+ )
27
+
28
+ import torch
29
+ import torch.backends.cuda
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch import einsum
33
+
34
+ from .aliases import PathOrStr
35
+ from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
36
+ from .config import (
37
+ ActivationCheckpointingStrategy,
38
+ ActivationType,
39
+ BlockType,
40
+ CheckpointType,
41
+ FSDPWrapStrategy,
42
+ InitFnType,
43
+ LayerNormType,
44
+ ModelConfig,
45
+ ShardedCheckpointerType,
46
+ TrainConfig,
47
+ )
48
+ from .exceptions import OLMoConfigurationError
49
+ from .initialization import init_normal
50
+ from .torch_util import ensure_finite_, get_cumulative_document_lengths
51
+
52
+ if sys.version_info.minor > 8:
53
+ from collections.abc import MutableMapping
54
+ elif sys.version_info.minor == 8:
55
+ from typing import MutableMapping
56
+ else:
57
+ raise SystemExit("This script supports Python 3.8 or higher")
58
+
59
+ __all__ = [
60
+ "LayerNormBase",
61
+ "LayerNorm",
62
+ "RMSLayerNorm",
63
+ "RotaryEmbedding",
64
+ "Activation",
65
+ "GELU",
66
+ "ReLU",
67
+ "SwiGLU",
68
+ "OLMoBlock",
69
+ "OLMoSequentialBlock",
70
+ "OLMo",
71
+ "OLMoOutput",
72
+ "OLMoGenerateOutput",
73
+ ]
74
+
75
+ log = logging.getLogger(__name__)
76
+
77
+ class FANLayer(nn.Module):
78
+ """
79
+ FANLayer: The layer used in FAN (https://arxiv.org/abs/2410.02675).
80
+
81
+ Args:
82
+ input_dim (int): The number of input features.
83
+ output_dim (int): The number of output features.
84
+ p_ratio (float): The ratio of output dimensions used for cosine and sine parts (default: 0.25).
85
+ activation (str or callable): The activation function to apply to the g component. If a string is passed,
86
+ the corresponding activation from torch.nn.functional is used (default: 'gelu').
87
+ use_p_bias (bool): If True, include a bias in the linear transformation of the p component (default: True).
88
+ In our experiments, there is almost no difference between using and omitting this bias.
89
+ """
90
+
91
+ def __init__(self, input_dim, output_dim, p_ratio=0.25, activation='gelu', use_p_bias=True):
92
+ super(FANLayer, self).__init__()
93
+
94
+ # Ensure the p_ratio is within a valid range
95
+ assert 0 <= p_ratio <= 0.5, "p_ratio must be between 0 and 0.5"
96
+
97
+ self.p_ratio = p_ratio
98
+ p_output_dim = int(output_dim * self.p_ratio)
99
+ g_output_dim = output_dim - p_output_dim * 2 # Account for cosine and sine terms
100
+
101
+
102
+ self.input_linear = nn.Linear(input_dim, p_output_dim+g_output_dim, bias=use_p_bias)
103
+
104
+ self.fused_dims = (p_output_dim, g_output_dim)
105
+
106
+ # Set the activation function
107
+ if isinstance(activation, str):
108
+ self.activation = getattr(F, activation)
109
+ else:
110
+ self.activation = activation if activation else lambda x: x
111
+
112
+ def forward(self, src, norm_g=None):
113
+ """
114
+ Args:
115
+ src (Tensor): Input tensor of shape (batch_size, input_dim).
116
+
117
+ Returns:
118
+ Tensor: Output tensor of shape (batch_size, output_dim), after applying the FAN layer.
119
+ """
120
+ pg = self.input_linear(src)
121
+
122
+ p, g = pg.split(self.fused_dims, dim=-1)
123
+
124
+ # Concatenate cos(p), sin(p), and activated g along the last dimension
125
+ output = torch.cat((torch.cos(p), torch.sin(p), self.activation(g)), dim=-1)
126
+
127
+ return output
128
+
129
+ class FAN(nn.Module):
130
+ def __init__(self, input_dim, output_dim, config, activation='gelu'):
131
+ super(FAN, self).__init__()
132
+
133
+ self.fanlayer = FANLayer(input_dim, input_dim, config.p_ratio, activation)
134
+ self.linear = nn.Linear(input_dim, output_dim, bias=config.include_bias, device=config.init_device)
135
+
136
+ def forward(self, src):
137
+ return self.linear(self.fanlayer(src))
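To make the dimension bookkeeping in `FANLayer` concrete, here is a small sketch (illustrative only, not part of the uploaded file):

```python
import torch

# With output_dim=8 and the default p_ratio=0.25, FANLayer allocates
# p_output_dim = int(8 * 0.25) = 2 periodic features and
# g_output_dim = 8 - 2 * 2 = 4 gated features, so the concatenation
# [cos(p), sin(p), activation(g)] has 2 + 2 + 4 = 8 features again.
layer = FANLayer(input_dim=16, output_dim=8)
out = layer(torch.randn(4, 16))
assert out.shape == (4, 8)
```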
138
+
139
+
140
+ def activation_checkpoint_function(cfg: ModelConfig):
141
+ preserve_rng_state = not (
142
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
143
+ )
144
+ from torch.utils.checkpoint import checkpoint
145
+
146
+ return partial(
147
+ checkpoint,
148
+ preserve_rng_state=preserve_rng_state,
149
+ use_reentrant=False,
150
+ )
151
+
152
+
153
+ def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
154
+ if strategy is None:
155
+ return False
156
+ elif (
157
+ (strategy == ActivationCheckpointingStrategy.whole_layer)
158
+ or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
159
+ or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
160
+ or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
161
+ or (strategy == ActivationCheckpointingStrategy.one_in_eight and block_idx % 8 == 0)
162
+ or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
163
+ or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
164
+ ):
165
+ return True
166
+ else:
167
+ return False
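As a concrete reading of the modular arithmetic above (illustrative only, not part of the uploaded file):

```python
# one_in_four checkpoints block indices where block_idx % 4 == 0; three_in_four
# checkpoints the complementary set of indices.
one_in_four = [
    i for i in range(8) if should_checkpoint_block(ActivationCheckpointingStrategy.one_in_four, i)
]
three_in_four = [
    i for i in range(8) if should_checkpoint_block(ActivationCheckpointingStrategy.three_in_four, i)
]
assert one_in_four == [0, 4]
assert three_in_four == [1, 2, 3, 5, 6, 7]
```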
168
+
169
+
170
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
171
+ """
172
+ Cache for attention biases and other things that would normally be stored as buffers.
173
+ We avoid using buffers because we've run into various issues doing so with FSDP.
174
+ In general it appears the way FSDP handles buffers is not well-defined.
175
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
176
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
177
+ NaNs when they're synchronized due to casting or some other issue.
178
+ """
179
+
180
+
181
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
182
+ if config.init_device is not None and config.init_device != "meta":
183
+ return torch.device(config.init_device)
184
+ else:
185
+ if torch.backends.mps.is_available():
186
+ return torch.device("mps")
187
+ elif torch.cuda.is_available():
188
+ return torch.device("cuda")
189
+ else:
190
+ return torch.device("cpu")
191
+
192
+
193
+ class Dropout(nn.Dropout):
194
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
195
+ if self.p == 0.0:
196
+ return input
197
+ else:
198
+ return F.dropout(input, self.p, self.training, self.inplace)
199
+
200
+
201
+ class LayerNormBase(nn.Module):
202
+ def __init__(
203
+ self,
204
+ config: ModelConfig,
205
+ *,
206
+ size: Optional[int] = None,
207
+ elementwise_affine: Optional[bool] = True,
208
+ ):
209
+ super().__init__()
210
+ self.config = config
211
+ self.eps = config.layer_norm_eps
212
+ self.normalized_shape = (size or config.d_model,)
213
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
214
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
215
+ use_bias = self.config.bias_for_layer_norm
216
+ if use_bias is None:
217
+ use_bias = self.config.include_bias
218
+ if use_bias:
219
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
220
+ else:
221
+ self.register_parameter("bias", None)
222
+ else:
223
+ self.register_parameter("bias", None)
224
+ self.register_parameter("weight", None)
225
+
226
+ @abstractmethod
227
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
228
+ raise NotImplementedError
229
+
230
+ @classmethod
231
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
232
+ if config.layer_norm_type == LayerNormType.default:
233
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
234
+ elif config.layer_norm_type == LayerNormType.low_precision:
235
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
236
+ elif config.layer_norm_type == LayerNormType.rms:
237
+ return RMSLayerNorm(config, size=size, **kwargs)
238
+ else:
239
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
240
+
241
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
242
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
243
+ # `is_autocast_cpu_enabled()` for CPU autocast.
244
+ # See https://github.com/pytorch/pytorch/issues/110966.
245
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
246
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
247
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
248
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
249
+ else:
250
+ return tensor
251
+
252
+ def reset_parameters(self):
253
+ if self.weight is not None:
254
+ torch.nn.init.ones_(self.weight) # type: ignore
255
+ if self.bias is not None:
256
+ torch.nn.init.zeros_(self.bias) # type: ignore
257
+
258
+
259
+ class LayerNorm(LayerNormBase):
260
+ """
261
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ config: ModelConfig,
267
+ size: Optional[int] = None,
268
+ low_precision: bool = False,
269
+ elementwise_affine: Optional[bool] = None,
270
+ ):
271
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine)
272
+ self.low_precision = low_precision
273
+
274
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
275
+ if self.low_precision:
276
+ module_device = x.device
277
+ downcast_x = self._cast_if_autocast_enabled(x)
278
+ downcast_weight = (
279
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
280
+ )
281
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
282
+ with torch.autocast(enabled=False, device_type=module_device.type):
283
+ return F.layer_norm(
284
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
285
+ )
286
+ else:
287
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
288
+
289
+
290
+ class RMSLayerNorm(LayerNormBase):
291
+ """
292
+ RMS layer norm, a simplified :class:`LayerNorm` implementation.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ config: ModelConfig,
298
+ size: Optional[int] = None,
299
+ elementwise_affine: Optional[bool] = None,
300
+ ):
301
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine)
302
+
303
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
304
+ with torch.autocast(enabled=False, device_type=x.device.type):
305
+ og_dtype = x.dtype
306
+ x = x.to(torch.float32)
307
+ variance = x.pow(2).mean(-1, keepdim=True)
308
+ x = x * torch.rsqrt(variance + self.eps)
309
+ x = x.to(og_dtype)
310
+
311
+ if self.weight is not None:
312
+ if self.bias is not None:
313
+ return self.weight * x + self.bias
314
+ else:
315
+ return self.weight * x
316
+ else:
317
+ return x
318
+
319
+
320
+ class RotaryEmbedding(nn.Module):
321
+ """
322
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
323
+ """
324
+
325
+ def __init__(self, config: ModelConfig, cache: BufferCache):
326
+ super().__init__()
327
+ self.config = config
328
+ self.__cache = cache
329
+ # Warm up cache.
330
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
331
+
332
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
333
+ if (
334
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
335
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
336
+ and pos_sin.shape[-2] >= seq_len
337
+ and pos_cos.shape[-2] >= seq_len
338
+ ):
339
+ if pos_sin.device != device:
340
+ pos_sin = pos_sin.to(device)
341
+ self.__cache["rope_pos_sin"] = pos_sin
342
+ if pos_cos.device != device:
343
+ pos_cos = pos_cos.to(device)
344
+ self.__cache["rope_pos_cos"] = pos_cos
345
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
346
+
347
+ with torch.autocast(device.type, enabled=False):
348
+ dim = self.config.d_model // self.config.n_heads
349
+ inv_freq = 1.0 / (
350
+ self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
351
+ )
352
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
353
+ freqs = einsum("i , j -> i j", seq, inv_freq)
354
+ positions = torch.cat((freqs, freqs), dim=-1)
355
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
356
+ self.__cache["rope_pos_sin"] = pos_sin
357
+ self.__cache["rope_pos_cos"] = pos_cos
358
+ return pos_sin, pos_cos
359
+
360
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
361
+ B, nh, T, hs = x.size()
362
+ x = x.view(B, nh, T, 2, hs // 2)
363
+ x1, x2 = x.unbind(dim=-2)
364
+ return torch.cat((-x2, x1), dim=-1)
365
+
366
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
367
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
368
+
369
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
370
+ if self.config.rope_full_precision:
371
+ q_, k_ = q.float(), k.float()
372
+ else:
373
+ q_, k_ = q, k
374
+
375
+ with torch.autocast(q.device.type, enabled=False):
376
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
377
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
378
+ pos_sin = pos_sin.type_as(q_)
379
+ pos_cos = pos_cos.type_as(q_)
380
+ q_ = self.apply_rotary_pos_emb(
381
+ pos_sin[:, :, key_len - query_len : key_len, :],
382
+ pos_cos[:, :, key_len - query_len : key_len, :],
383
+ q_,
384
+ )
385
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
386
+ return q_.type_as(q), k_.type_as(k)
387
+
388
+
389
+ class Activation(nn.Module):
390
+ def __init__(self, config: ModelConfig):
391
+ super().__init__()
392
+ self.config = config
393
+
394
+ @abstractmethod
395
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
396
+ raise NotImplementedError
397
+
398
+ @property
399
+ @abstractmethod
400
+ def output_multiplier(self) -> float:
401
+ raise NotImplementedError
402
+
403
+ @classmethod
404
+ def build(cls, config: ModelConfig) -> Activation:
405
+ if config.activation_type == ActivationType.gelu:
406
+ return cast(Activation, GELU(approximate="none"))
407
+ elif config.activation_type == ActivationType.relu:
408
+ return cast(Activation, ReLU(inplace=False))
409
+ elif config.activation_type == ActivationType.swiglu:
410
+ return SwiGLU(config)
411
+ else:
412
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
413
+
414
+
415
+ class GELU(nn.GELU):
416
+ @property
417
+ def output_multiplier(self) -> float:
418
+ return 1.0
419
+
420
+
421
+ class ReLU(nn.ReLU):
422
+ @property
423
+ def output_multiplier(self) -> float:
424
+ return 1.0
425
+
426
+
427
+ class SwiGLU(Activation):
428
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
429
+ x, gate = x.chunk(2, dim=-1)
430
+ return F.silu(gate) * x
431
+
432
+ @property
433
+ def output_multiplier(self) -> float:
434
+ return 0.5
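A brief note on `output_multiplier` (illustrative only, not part of the uploaded file): SwiGLU halves the feature dimension because it chunks its input into a value and a gate, and `OLMoBlock` sizes `ff_out` with `int(act.output_multiplier * hidden_size)` to account for this.

```python
import torch

# SwiGLU chunks its input into (value, gate) halves, so a projection of width
# hidden_size yields hidden_size // 2 activated features; GELU/ReLU (multiplier
# 1.0) keep the width unchanged.
h = torch.randn(2, 5, 1024)          # pretend ff_proj output with hidden_size=1024
x, gate = h.chunk(2, dim=-1)
assert (torch.nn.functional.silu(gate) * x).shape == (2, 5, 512)
```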
435
+
436
+
437
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
438
+ att_bias = torch.triu(
439
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
440
+ diagonal=1,
441
+ )
442
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
443
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
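For reference, the shape of the mask this produces (illustrative only, not part of the uploaded file):

```python
import torch

# For seq_len=3 the bias is (1, 1, 3, 3) with the strictly upper triangle set to
# the most negative finite float (an effective -inf before softmax), so query i
# can only attend to keys j <= i:
#   [[0, m, m],
#    [0, 0, m],
#    [0, 0, 0]]   with m = torch.finfo(torch.float).min
bias = causal_attention_bias(3, torch.device("cpu"))
assert bias.shape == (1, 1, 3, 3)
assert bias[0, 0, 0, 1] == torch.finfo(torch.float).min and bias[0, 0, 1, 0] == 0
```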
444
+
445
+
446
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
447
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
448
+ if causal_bias.device != device:
449
+ causal_bias = causal_bias.to(device)
450
+ cache["causal_attention_bias"] = causal_bias
451
+ return causal_bias
452
+ with torch.autocast(device.type, enabled=False):
453
+ causal_bias = causal_attention_bias(seq_len, device)
454
+ cache["causal_attention_bias"] = causal_bias
455
+ return causal_bias
456
+
457
+
458
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
459
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
460
+
461
+ # shape: (1, 1, seq_len, seq_len)
462
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
463
+ alibi_bias.abs_().mul_(-1)
464
+
465
+ # shape: (n_heads,)
466
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
467
+ m.mul_(config.alibi_bias_max / config.n_heads)
468
+
469
+ # shape: (1, n_heads, seq_len, seq_len)
470
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
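And a note on the ALiBi slopes computed above (illustrative only, not part of the uploaded file):

```python
import torch

# Each head h (1-indexed) is scaled by 1 / 2 ** (h * alibi_bias_max / n_heads),
# so the bias grows more negative linearly with query/key distance at a per-head
# rate. With n_heads=8 and alibi_bias_max=8 the slopes are 1/2, 1/4, ..., 1/256,
# i.e. the geometric sequence used in the ALiBi paper.
m = torch.arange(1, 9, dtype=torch.float) * (8 / 8)
slopes = 1.0 / (2 ** m)
assert torch.allclose(slopes, torch.tensor([2.0 ** -i for i in range(1, 9)]))
```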
471
+
472
+
473
+ class OLMoBlock(nn.Module):
474
+ """
475
+ A base class for transformer block implementations.
476
+ """
477
+
478
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
479
+ super().__init__()
480
+ self.layer_id = layer_id
481
+ self.config = config
482
+ self.hidden_size = (
483
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
484
+ )
485
+ self.__cache = cache
486
+ assert config.d_model % config.n_heads == 0
487
+
488
+ self._activation_checkpoint_fn: Optional[Callable] = None
489
+
490
+ # Dropout.
491
+ self.dropout = Dropout(config.residual_dropout)
492
+
493
+ # Layer norms.
494
+ self.k_norm: Optional[LayerNormBase] = None
495
+ self.q_norm: Optional[LayerNormBase] = None
496
+ if config.attention_layer_norm:
497
+ assert config.effective_n_kv_heads is not None
498
+ self.k_norm = LayerNormBase.build(
499
+ config,
500
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
501
+ elementwise_affine=config.attention_layer_norm_with_affine,
502
+ )
503
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
504
+
505
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
506
+ if config.clip_qkv is not None:
507
+ assert config.clip_qkv > 0
508
+
509
+ # Activation function.
510
+ self.act = Activation.build(config)
511
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
512
+
513
+ # Attention output projection.
514
+ self.attn_out = nn.Linear(
515
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
516
+ )
517
+
518
+ # Feed-forward output projection.
519
+ self.ff_out = nn.Linear(
520
+ int(self.act.output_multiplier * self.hidden_size),
521
+ config.d_model,
522
+ bias=config.include_bias,
523
+ device=config.init_device,
524
+ )
525
+ self.ff_out._is_residual = True # type: ignore
526
+
527
+ # Rotary embeddings.
528
+ if self.config.rope:
529
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
530
+
531
+ self.flash_attn_func = None
532
+ self.flash_attn_varlen_func = None
533
+ if config.flash_attention:
534
+ try:
535
+ from flash_attn import ( # type: ignore
536
+ flash_attn_func,
537
+ flash_attn_varlen_func,
538
+ )
539
+
540
+ self.flash_attn_func = flash_attn_func
541
+ self.flash_attn_varlen_func = flash_attn_varlen_func
542
+ except ModuleNotFoundError:
543
+ pass
544
+
545
+ def reset_parameters(self):
546
+ if self.k_norm is not None:
547
+ self.k_norm.reset_parameters()
548
+ if self.q_norm is not None:
549
+ self.q_norm.reset_parameters()
550
+
551
+ if self.config.init_fn == InitFnType.normal:
552
+ attn_out_std = ff_out_std = self.config.init_std
553
+ cutoff_factor = self.config.init_cutoff_factor
554
+
555
+ elif self.config.init_fn == InitFnType.mitchell:
556
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
557
+ ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
558
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
559
+
560
+ elif self.config.init_fn == InitFnType.full_megatron:
561
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
562
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
563
+
564
+ else:
565
+ raise NotImplementedError(self.config.init_fn)
566
+
567
+ init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
568
+ init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
569
+
570
+ def set_activation_checkpointing(
571
+ self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
572
+ ):
573
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
574
+ self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
575
+ else:
576
+ self._activation_checkpoint_fn = None
577
+
578
+ @classmethod
579
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
580
+ target_dtype = input_dtype
581
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
582
+ # `is_autocast_cpu_enabled()` for CPU autocast.
583
+ # See https://github.com/pytorch/pytorch/issues/110966.
584
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
585
+ target_dtype = torch.get_autocast_gpu_dtype()
586
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
587
+ target_dtype = torch.get_autocast_cpu_dtype()
588
+ elif bias.device.type == "mps":
589
+ target_dtype = torch.get_autocast_dtype("mps")
590
+ if bias.dtype != target_dtype:
591
+ bias = bias.to(target_dtype)
592
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
593
+ return bias
594
+
595
+ def _scaled_dot_product_attention(
596
+ self,
597
+ q: torch.Tensor,
598
+ k: torch.Tensor,
599
+ v: torch.Tensor,
600
+ attn_mask: Optional[torch.Tensor] = None,
601
+ dropout_p: float = 0.0,
602
+ is_causal: bool = False,
603
+ max_doc_len: Optional[int] = None,
604
+ cu_doc_lens: Optional[torch.Tensor] = None,
605
+ ) -> torch.Tensor:
606
+ """
607
+ Computes scaled dot product attention on query, key and value tensors, using an optional
608
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
609
+ """
610
+ if max_doc_len is not None and cu_doc_lens is not None:
611
+ assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
612
+ assert attn_mask is None, "attn-mask is currently not supported with document masking"
613
+ B, T, D = q.size(0), q.size(2), q.size(3)
614
+ r = self.flash_attn_varlen_func(
615
+ q.transpose(1, 2).view(B * T, -1, D),
616
+ k.transpose(1, 2).view(B * T, -1, D),
617
+ v.transpose(1, 2).view(B * T, -1, D),
618
+ cu_doc_lens,
619
+ cu_doc_lens,
620
+ max_doc_len,
621
+ max_doc_len,
622
+ dropout_p=dropout_p,
623
+ causal=is_causal,
624
+ )
625
+ return r.view(B, T, -1, D).transpose(1, 2)
626
+ elif self.flash_attn_func is not None and attn_mask is None:
627
+ r = self.flash_attn_func(
628
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
629
+ )
630
+ return r.transpose(1, 2)
631
+ else:
632
+ # torch's sdpa doesn't support GQA, so we're doing this
633
+ assert k.size(1) == v.size(1)
634
+ num_kv_heads = k.size(1)
635
+ num_q_heads = q.size(1)
636
+ if num_q_heads != num_kv_heads:
637
+ assert num_q_heads % num_kv_heads == 0
638
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
639
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
640
+
641
+ return F.scaled_dot_product_attention(
642
+ q,
643
+ k,
644
+ v,
645
+ attn_mask=attn_mask,
646
+ dropout_p=dropout_p,
647
+ is_causal=is_causal,
648
+ )
649
+
650
+ def attention(
651
+ self,
652
+ q: torch.Tensor,
653
+ k: torch.Tensor,
654
+ v: torch.Tensor,
655
+ attention_bias: Optional[torch.Tensor] = None,
656
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
657
+ use_cache: bool = False,
658
+ max_doc_len: Optional[int] = None,
659
+ cu_doc_lens: Optional[torch.Tensor] = None,
660
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
661
+ B, T, C = q.size() # batch size, sequence length, d_model
662
+ dtype = k.dtype
663
+
664
+ # Optionally apply layer norm to keys and queries.
665
+ if self.q_norm is not None and self.k_norm is not None:
666
+ q = self.q_norm(q).to(dtype=dtype)
667
+ k = self.k_norm(k).to(dtype=dtype)
668
+
669
+ # Move head forward to be next to the batch dim.
670
+ # shape: (B, nh, T, hs)
671
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
672
+ # shape: (B, n_kv_h, T, hs)
673
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
674
+ # shape: (B, n_kv_h, T, hs)
675
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
676
+
677
+ if layer_past is not None:
678
+ past_key, past_value = layer_past
679
+ k = torch.cat((past_key, k), dim=-2)
680
+ v = torch.cat((past_value, v), dim=-2)
681
+
682
+ present = (k, v) if use_cache else None
683
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
684
+
685
+ if self.config.rope:
686
+ # Apply rotary embeddings.
687
+ q, k = self.rotary_emb(q, k)
688
+
689
+ if attention_bias is not None:
690
+ # Resize and cast attention bias.
691
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
692
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
693
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
694
+ # cause the SDP attn function to produce NaNs.
695
+ attention_bias = self._cast_attn_bias(
696
+ attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
697
+ )
698
+
699
+ # Get the attention scores.
700
+ # shape: (B, nh, T, hs)
701
+ att = self._scaled_dot_product_attention(
702
+ q,
703
+ k,
704
+ v,
705
+ attn_mask=attention_bias,
706
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
707
+ is_causal=attention_bias is None,
708
+ max_doc_len=max_doc_len,
709
+ cu_doc_lens=cu_doc_lens,
710
+ )
711
+
712
+ # Re-assemble all head outputs side-by-side.
713
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
714
+
715
+ # Apply output projection.
716
+ return self.attn_out(att), present
717
+
718
+ @abstractmethod
719
+ def forward(
720
+ self,
721
+ x: torch.Tensor,
722
+ attention_bias: Optional[torch.FloatTensor] = None,
723
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
724
+ use_cache: bool = False,
725
+ max_doc_len: Optional[int] = None,
726
+ cu_doc_lens: Optional[torch.Tensor] = None,
727
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
728
+ raise NotImplementedError
729
+
730
+ @classmethod
731
+ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
732
+ if config.block_type == BlockType.sequential:
733
+ return OLMoSequentialBlock(layer_id, config, cache)
734
+ elif config.block_type == BlockType.llama:
735
+ return OLMoLlamaBlock(layer_id, config, cache)
736
+ else:
737
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
738
+
739
+
740
+ class OLMoSequentialBlock(OLMoBlock):
741
+ """
742
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
743
+ (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``,
744
+ use the flag `norm_after`.
745
+ """
746
+
747
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
748
+ super().__init__(layer_id, config, cache)
749
+ # Attention input projection. Projects x -> (q, k, v)
750
+ self.use_ATF = config.use_ATF
751
+
752
+ head_dim = config.d_model // config.n_heads
753
+ self.fused_dims = (
754
+ config.d_model,
755
+ config.effective_n_kv_heads * head_dim,
756
+ config.effective_n_kv_heads * head_dim,
757
+ )
758
+
759
+
760
+ if self.use_ATF:
761
+ self.att_proj = FAN(config.d_model, sum(self.fused_dims), config, activation=config.attention_activation)
762
+ else:
763
+ self.att_proj = nn.Linear(
764
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
765
+ )
766
+
767
+ # Feed-forward input projection.
768
+ self.ff_proj = nn.Linear(
769
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
770
+ )
771
+
772
+ # Layer norms.
773
+ self.attn_norm = LayerNorm.build(config, size=config.d_model)
774
+ self.ff_norm = LayerNorm.build(config, size=config.d_model)
775
+
776
+ def reset_parameters(self):
777
+ super().reset_parameters()
778
+ self.attn_norm.reset_parameters()
779
+ self.ff_norm.reset_parameters()
780
+ # NOTE: the standard deviation for these weights does not depend on the layer.
781
+
782
+ if self.config.init_fn == InitFnType.normal:
783
+ std = self.config.init_std
784
+ cutoff_factor = self.config.init_cutoff_factor
785
+ elif self.config.init_fn == InitFnType.mitchell:
786
+ std = 1 / math.sqrt(self.config.d_model)
787
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
788
+ elif self.config.init_fn == InitFnType.full_megatron:
789
+ std = self.config.init_std
790
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
791
+ else:
792
+ raise NotImplementedError(self.config.init_fn)
793
+
794
+ if self.use_ATF:
795
+ init_normal(self.att_proj.fanlayer.input_linear, std, cutoff_factor)
796
+ init_normal(self.att_proj.linear, std, cutoff_factor)
797
+ else:
798
+ init_normal(self.att_proj, std, cutoff_factor)
799
+
800
+ init_normal(self.ff_proj, std, cutoff_factor)
801
+
802
+ def forward(
803
+ self,
804
+ x: torch.Tensor,
805
+ attention_bias: Optional[torch.Tensor] = None,
806
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
807
+ use_cache: bool = False,
808
+ max_doc_len: Optional[int] = None,
809
+ cu_doc_lens: Optional[torch.Tensor] = None,
810
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
811
+ # Get query, key, value projections.
812
+ # shape:
813
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
814
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
815
+ # k, v: (batch_size, seq_len, d_model // n_heads)
816
+ # - for group query attn q: (batch_size, seq_len, d_model)
817
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
818
+
819
+ # apply norm before
820
+ if not self.config.norm_after:
821
+ if self._activation_checkpoint_fn is not None:
822
+ h = self._activation_checkpoint_fn(self.attn_norm, x)
823
+ else:
824
+ h = self.attn_norm(x)
825
+ else:
826
+ h = x
827
+
828
+ qkv = self.att_proj(h)
829
+
830
+ if self.config.clip_qkv is not None:
831
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
832
+
833
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
834
+
835
+ # Get attention scores.
836
+ if self._activation_checkpoint_fn is not None:
837
+ att, cache = self._activation_checkpoint_fn( # type: ignore
838
+ self.attention,
839
+ q,
840
+ k,
841
+ v,
842
+ attention_bias,
843
+ layer_past=layer_past,
844
+ use_cache=use_cache,
845
+ max_doc_len=max_doc_len,
846
+ cu_doc_lens=cu_doc_lens,
847
+ )
848
+ else:
849
+ att, cache = self.attention(
850
+ q,
851
+ k,
852
+ v,
853
+ attention_bias,
854
+ layer_past=layer_past,
855
+ use_cache=use_cache,
856
+ max_doc_len=max_doc_len,
857
+ cu_doc_lens=cu_doc_lens,
858
+ )
859
+
860
+ if self.config.norm_after:
861
+ if self._activation_checkpoint_fn is not None:
862
+ att = self._activation_checkpoint_fn(self.attn_norm, att)
863
+ else:
864
+ att = self.attn_norm(att)
865
+
866
+ # Add attention scores.
867
+ # shape: (B, T, C)
868
+ x = x + self.dropout(att)
869
+
870
+ # Add feed-forward projection.
871
+ # shape: (batch_size, seq_len, d_model)
872
+ og_x = x
873
+
874
+ if not self.config.norm_after:
875
+ if self._activation_checkpoint_fn is not None:
876
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
877
+ else:
878
+ x = self.ff_norm(x)
879
+
880
+ x = self.ff_proj(x)
881
+
882
+ if self._activation_checkpoint_fn is not None:
883
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
884
+ else:
885
+ x = self.act(x)
886
+ x = self.ff_out(x)
887
+
888
+ if self.config.norm_after:
889
+ if self._activation_checkpoint_fn is not None:
890
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
891
+ else:
892
+ x = self.ff_norm(x)
893
+
894
+ x = self.dropout(x)
895
+ x = og_x + x
896
+
897
+ return x, cache
898
+
899
+
900
+ class OLMoLlamaBlock(OLMoBlock):
901
+ """
902
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
903
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
904
+ but some operations have slightly different implementations to imitate the
905
+ behavior of Llama.
906
+ """
907
+
908
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
909
+ super().__init__(layer_id, config, cache)
910
+ # Layer norms.
911
+ self.use_ATF = config.use_ATF
912
+ self.attn_norm = LayerNorm.build(config)
913
+ self.ff_norm = LayerNorm.build(config)
914
+ self.__cache = cache
915
+
916
+ # Attention input projection. Projects x -> (q, k, v)
917
+ if config.multi_query_attention:
918
+ q_proj_out_dim = config.d_model
919
+ k_proj_out_dim = config.d_model // config.n_heads
920
+ v_proj_out_dim = config.d_model // config.n_heads
921
+ else:
922
+ q_proj_out_dim = config.d_model
923
+ k_proj_out_dim = config.d_model
924
+ v_proj_out_dim = config.d_model
925
+
926
+ if self.use_ATF:
927
+ self.q_proj = FAN(config.d_model, q_proj_out_dim, config)
928
+ self.k_proj = FAN(config.d_model, k_proj_out_dim, config)
929
+ self.v_proj = FAN(config.d_model, v_proj_out_dim, config)
930
+ else:
931
+ self.q_proj = nn.Linear(
932
+ config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
933
+ )
934
+ self.k_proj = nn.Linear(
935
+ config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
936
+ )
937
+ self.v_proj = nn.Linear(
938
+ config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
939
+ )
940
+
941
+ # Feed-forward input projection.
942
+ self.ff_proj = nn.Linear(
943
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
944
+ )
945
+
946
+ def reset_parameters(self):
947
+ super().reset_parameters()
948
+ self.attn_norm.reset_parameters()
949
+ self.ff_norm.reset_parameters()
950
+ # NOTE: the standard deviation for these weights does not depend on the layer.
951
+
952
+ if self.config.init_fn == InitFnType.normal:
953
+ std = self.config.init_std
954
+ cutoff_factor = self.config.init_cutoff_factor
955
+ elif self.config.init_fn == InitFnType.mitchell:
956
+ std = 1 / math.sqrt(self.config.d_model)
957
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
958
+ elif self.config.init_fn == InitFnType.full_megatron:
959
+ std = self.config.init_std
960
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
961
+ else:
962
+ raise NotImplementedError(self.config.init_fn)
963
+
964
+ init_normal(self.q_proj, std, cutoff_factor)
965
+ init_normal(self.k_proj, std, cutoff_factor)
966
+ init_normal(self.v_proj, std, cutoff_factor)
967
+ init_normal(self.ff_proj, std, cutoff_factor)
968
+
969
+ def _scaled_dot_product_attention(
970
+ self,
971
+ q: torch.Tensor,
972
+ k: torch.Tensor,
973
+ v: torch.Tensor,
974
+ attn_mask: Optional[torch.Tensor] = None,
975
+ dropout_p: float = 0.0,
976
+ is_causal: bool = False,
977
+ max_doc_len: Optional[int] = None,
978
+ cu_doc_lens: Optional[torch.Tensor] = None,
979
+ ) -> torch.Tensor:
980
+ if max_doc_len is not None or cu_doc_lens is not None:
981
+ raise NotImplementedError(
982
+ f"attention document masking is not implemented for {self.__class__.__name__}"
983
+ )
984
+
985
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
986
+
987
+ if is_causal:
988
+ assert attn_mask is None
989
+
990
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
991
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
992
+ elif attn_mask is not None:
993
+ attn_bias = attn_mask.to(q.dtype)
994
+ else:
995
+ attn_bias = torch.zeros_like(attn_weights)
996
+
997
+ attn_weights += attn_bias
998
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
999
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
1000
+ return torch.matmul(attn_weights, v)
1001
+
1002
+ def forward(
1003
+ self,
1004
+ x: torch.Tensor,
1005
+ attention_bias: Optional[torch.Tensor] = None,
1006
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1007
+ use_cache: bool = False,
1008
+ max_doc_len: Optional[int] = None,
1009
+ cu_doc_lens: Optional[torch.Tensor] = None,
1010
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
1011
+ # Get query, key, value projections.
1012
+ # shape:
1013
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
1014
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
1015
+ # k, v: (batch_size, seq_len, d_model // n_heads)
1016
+ x_normed = self.attn_norm(x)
1017
+ q = self.q_proj(x_normed)
1018
+ k = self.k_proj(x_normed)
1019
+ v = self.v_proj(x_normed)
1020
+
1021
+ if self.config.clip_qkv is not None:
1022
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1023
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1024
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
1025
+
1026
+ # Get attention scores.
1027
+ att, cache = self.attention(
1028
+ q,
1029
+ k,
1030
+ v,
1031
+ attention_bias,
1032
+ layer_past=layer_past,
1033
+ use_cache=use_cache,
1034
+ max_doc_len=max_doc_len,
1035
+ cu_doc_lens=cu_doc_lens,
1036
+ )
1037
+
1038
+ # Add attention scores.
1039
+ # shape: (B, T, C)
1040
+ x = x + self.dropout(att)
1041
+
1042
+ # Add feed-forward projection.
1043
+ # shape: (batch_size, seq_len, d_model)
1044
+ og_x = x
1045
+ if self._activation_checkpoint_fn is not None:
1046
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
1047
+ else:
1048
+ x = self.ff_norm(x)
1049
+ x = self.ff_proj(x)
1050
+ if self._activation_checkpoint_fn is not None:
1051
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
1052
+ else:
1053
+ x = self.act(x)
1054
+ x = self.ff_out(x)
1055
+ x = self.dropout(x)
1056
+ x = og_x + x
1057
+
1058
+ return x, cache
1059
+
1060
+
1061
+ class OLMoOutput(NamedTuple):
1062
+ logits: torch.FloatTensor
1063
+ """
1064
+ A tensor of shape `(batch_size, seq_len, vocab_size)` containing the un-normalized scores (logits)
1065
+ for the next token. Apply a (log) softmax over the last dimension to obtain (log) probabilities.
1066
+ """
1067
+
1068
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
1069
+ """
1070
+ Attention keys and values from each block.
1071
+ """
1072
+
1073
+ hidden_states: Optional[Tuple[torch.Tensor, ...]]
1074
+ """
1075
+ Hidden states from each block.
1076
+ """
1077
+
1078
+
1079
+ class OLMoGenerateOutput(NamedTuple):
1080
+ token_ids: torch.LongTensor
1081
+ """
1082
+ The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
1083
+ These do *not* include the original input IDs.
1084
+ """
1085
+
1086
+ scores: torch.FloatTensor
1087
+ """
1088
+ The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
1089
+ """
1090
+
1091
+
1092
+ class OLMoBlockGroup(nn.ModuleList):
1093
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
1094
+ super().__init__(modules)
1095
+ self.config = config
1096
+ self.layer_offset = layer_offset
1097
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1098
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
1099
+
1100
+ def forward(
1101
+ self,
1102
+ x: torch.Tensor,
1103
+ attention_bias: Optional[torch.FloatTensor] = None,
1104
+ layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
1105
+ use_cache: bool = False,
1106
+ max_doc_len: Optional[int] = None,
1107
+ cu_doc_lens: Optional[torch.Tensor] = None,
1108
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
1109
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1110
+ for block_idx, block in enumerate(self):
1111
+ layer_past = None if layers_past is None else layers_past[block_idx]
1112
+ block_idx += self.layer_offset
1113
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
1114
+ # shape: (batch_size, seq_len, d_model)
1115
+ x, cache = self._activation_checkpoint_fn( # type: ignore
1116
+ block,
1117
+ x,
1118
+ attention_bias=attention_bias,
1119
+ layer_past=layer_past,
1120
+ use_cache=use_cache,
1121
+ max_doc_len=max_doc_len,
1122
+ cu_doc_lens=cu_doc_lens,
1123
+ )
1124
+ else:
1125
+ # shape: (batch_size, seq_len, d_model)
1126
+ x, cache = block(
1127
+ x,
1128
+ attention_bias=attention_bias,
1129
+ layer_past=layer_past,
1130
+ use_cache=use_cache,
1131
+ max_doc_len=max_doc_len,
1132
+ cu_doc_lens=cu_doc_lens,
1133
+ )
1134
+ if attn_key_values is not None:
1135
+ assert cache is not None
1136
+ attn_key_values.append(cache)
1137
+ return x, attn_key_values
1138
+
1139
+ def reset_parameters(self):
1140
+ for block in self:
1141
+ block.reset_parameters()
1142
+
1143
+ def set_activation_checkpointing(
1144
+ self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
1145
+ ):
1146
+ self.activation_checkpointing_strategy = strategy
1147
+ for block in self:
1148
+ block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
1149
+
1150
+
1151
+ class OLMo(nn.Module):
1152
+ def __init__(self, config: ModelConfig, init_params: bool = True):
1153
+ super().__init__()
1154
+ self.config = config
1155
+ self.__cache = BufferCache()
1156
+
1157
+ # Validate config.
1158
+ if self.config.alibi and self.config.flash_attention:
1159
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
1160
+
1161
+ if self.config.alibi and self.config.rope:
1162
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
1163
+
1164
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1165
+ if self.config.embedding_size < self.config.vocab_size:
1166
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
1167
+ elif self.config.embedding_size % 128 != 0:
1168
+ import warnings
1169
+
1170
+ warnings.warn(
1171
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1172
+ )
1173
+
1174
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1175
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1176
+
1177
+ if not (
1178
+ 0 < self.config.block_group_size <= self.config.n_layers
1179
+ and self.config.n_layers % self.config.block_group_size == 0
1180
+ ):
1181
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
1182
+
1183
+ torch.backends.cuda.enable_flash_sdp(True)
1184
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1185
+
1186
+ self.transformer = nn.ModuleDict(
1187
+ dict(
1188
+ wte=nn.Embedding(
1189
+ config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1190
+ ),
1191
+ emb_drop=Dropout(config.embedding_dropout),
1192
+ ln_f=LayerNorm.build(config),
1193
+ )
1194
+ )
1195
+
1196
+ blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
1197
+ if self.config.block_group_size > 1:
1198
+ block_groups = [
1199
+ OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
1200
+ for i in range(0, config.n_layers, config.block_group_size)
1201
+ ]
1202
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1203
+ else:
1204
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
1205
+
1206
+ if not (self.config.alibi or self.config.rope):
1207
+ self.transformer.update(
1208
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1209
+ )
1210
+ if not config.weight_tying:
1211
+ self.transformer.update(
1212
+ {
1213
+ "ff_out": nn.Linear(
1214
+ config.d_model,
1215
+ config.embedding_size or config.vocab_size,
1216
+ bias=config.include_bias,
1217
+ device=config.init_device,
1218
+ )
1219
+ }
1220
+ )
1221
+ if config.embedding_layer_norm:
1222
+ self.transformer.update({"emb_norm": LayerNorm.build(config)})
1223
+
1224
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1225
+ if init_params and self.config.init_device != "meta":
1226
+ self.reset_parameters()
1227
+ self.__num_fwd_flops: Optional[int] = None
1228
+ self.__num_bck_flops: Optional[int] = None
1229
+
1230
+ # Warm up cache.
1231
+ if self.config.alibi:
1232
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1233
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1234
+
1235
+ def set_activation_checkpointing(
1236
+ self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
1237
+ ):
1238
+ self.activation_checkpointing_strategy = strategy
1239
+ if self.config.block_group_size != 1:
1240
+ for block_group in self.transformer.block_groups:
1241
+ block_group.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
1242
+ else:
1243
+ for block in self.transformer.blocks:
1244
+ block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
1245
+
1246
+ @property
1247
+ def device(self) -> torch.device:
1248
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1249
+ if device.type == "meta":
1250
+ return _non_meta_init_device(self.config)
1251
+ else:
1252
+ return device
1253
+
1254
+ def reset_parameters(self):
1255
+ log.info("Initializing model parameters...")
1256
+ # Top-level embeddings / linear layers.
1257
+
1258
+ if self.config.init_fn == InitFnType.normal:
1259
+ # Note: We may potentially want to multiply the std by a factor of sqrt(d) in case of `scale_logits`
1260
+ # and `weight_tying`. However, we are currently not using either, and may need to rethink the init logic
1261
+ # if/when we do want it.
1262
+ wte_std = self.config.emb_init_std or self.config.init_std
1263
+ wte_cutoff_factor = self.config.init_cutoff_factor
1264
+ elif self.config.init_fn == InitFnType.mitchell:
1265
+ wte_std = self.config.emb_init_std or 1.0 / math.sqrt(self.config.d_model)
1266
+ wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
1267
+ elif self.config.init_fn == InitFnType.full_megatron:
1268
+ wte_std = self.config.init_std
1269
+ if self.config.emb_init_std is not None:
1270
+ wte_std = self.config.emb_init_std
1271
+ elif self.config.scale_emb_init:
1272
+ wte_std *= math.sqrt(self.config.d_model)
1273
+ wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
1274
+ else:
1275
+ raise NotImplementedError(self.config.init_fn)
1276
+
1277
+ init_normal(self.transformer.wte, std=wte_std, init_cutoff_factor=wte_cutoff_factor)
1278
+
1279
+ if hasattr(self.transformer, "wpe"):
1280
+ if self.config.init_fn == InitFnType.normal:
1281
+ wpe_std = self.config.init_std
1282
+ wpe_cutoff_factor = self.config.init_cutoff_factor
1283
+ elif self.config.init_fn == InitFnType.mitchell:
1284
+ wpe_std = 1 / math.sqrt(self.config.d_model)
1285
+ wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0
1286
+ elif self.config.init_fn == InitFnType.full_megatron:
1287
+ wpe_std = self.config.init_std
1288
+ wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0
1289
+ else:
1290
+ raise NotImplementedError(self.config.init_fn)
1291
+
1292
+ init_normal(self.transformer.wpe, std=wpe_std, init_cutoff_factor=wpe_cutoff_factor)
1293
+
1294
+ # Top-level layer norm.
1295
+ self.transformer.ln_f.reset_parameters() # type: ignore
1296
+
1297
+ # Output weights.
1298
+ if hasattr(self.transformer, "ff_out"):
1299
+ if self.config.init_fn == InitFnType.normal:
1300
+ ff_out_std = self.config.init_std
1301
+ ff_out_cutoff_factor = self.config.init_cutoff_factor
1302
+ elif self.config.init_fn == InitFnType.mitchell:
1303
+ ff_out_std = 1 / math.sqrt(self.config.d_model)
1304
+ ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0
1305
+ elif self.config.init_fn == InitFnType.full_megatron:
1306
+ ff_out_std = 1 / math.sqrt(self.config.d_model)
1307
+ ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0
1308
+ else:
1309
+ raise NotImplementedError(self.config.init_fn)
1310
+
1311
+ init_normal(self.transformer.ff_out, ff_out_std, ff_out_cutoff_factor)
1312
+
1313
+ # Let the blocks handle themselves.
1314
+ if self.config.block_group_size == 1:
1315
+ for block in self.transformer.blocks:
1316
+ block.reset_parameters()
1317
+ else:
1318
+ for block_group in self.transformer.block_groups:
1319
+ block_group.reset_parameters()
1320
+
1321
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1322
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1323
+ -1
1324
+ ] >= seq_len:
1325
+ if alibi_bias.device != device:
1326
+ alibi_bias = alibi_bias.to(device)
1327
+ self.__cache["alibi_attention_bias"] = alibi_bias
1328
+ return alibi_bias
1329
+ with torch.autocast(device.type, enabled=False):
1330
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1331
+ self.__cache["alibi_attention_bias"] = alibi_bias
1332
+ return alibi_bias
1333
+
1334
+ def forward(
1335
+ self,
1336
+ input_ids: torch.LongTensor,
1337
+ input_embeddings: Optional[torch.FloatTensor] = None,
1338
+ attention_mask: Optional[torch.Tensor] = None,
1339
+ attention_bias: Optional[torch.Tensor] = None,
1340
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1341
+ use_cache: bool = False,
1342
+ last_logits_only: bool = False,
1343
+ output_hidden_states: Optional[bool] = None,
1344
+ doc_lens: Optional[torch.Tensor] = None,
1345
+ max_doc_lens: Optional[Sequence[int]] = None,
1346
+ ) -> OLMoOutput:
1347
+ """
1348
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1349
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1350
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1351
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1352
+ which input IDs are masked. A `1` value in the mask means that
1353
+ the corresponding input ID should *not* be ignored. A `0` means
1354
+ that the corresponding input ID is masked.
1355
+
1356
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1357
+ library.
1358
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1359
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1360
+ to introduce causal or other biases.
1361
+
1362
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1363
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1364
+ element in the sequence.
1365
+
1366
+ If the tensor is a float tensor, it will just be added to the attention
1367
+ scores before the softmax.
1368
+
1369
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1370
+ :param past_key_values: Pre-computed keys and values for each attention block.
1371
+ Can be used to speed up sequential decoding. The `input_ids` which have
1372
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1373
+ :param use_cache: If `True`, return key and value tensors for each block.
1374
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1375
+ This can speed up decoding when you only care about the next token.
1376
+ :param doc_lens: Document lengths to use in attention for intra-document masking.
1377
+ Shape `(batch_size, max_docs)`.
1378
+ :param max_doc_lens: Maximum document length for each instance in the batch.
1379
+ """
1380
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1381
+
1382
+ if past_key_values:
1383
+ assert len(past_key_values) == self.config.n_layers
1384
+
1385
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1386
+ if past_key_values is None:
1387
+ past_length = 0
1388
+ else:
1389
+ past_length = past_key_values[0][0].size(-2)
1390
+
1391
+ max_doc_len: Optional[int] = None
1392
+ cu_doc_lens: Optional[torch.Tensor] = None
1393
+ if doc_lens is not None and max_doc_lens is not None:
1394
+ max_doc_len = max(max_doc_lens)
1395
+ cu_doc_lens = get_cumulative_document_lengths(doc_lens)
1396
+
1397
+ # Get embeddings of input.
1398
+ # shape: (batch_size, seq_len, d_model)
1399
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
1400
+
1401
+ # Apply embedding layer norm.
1402
+ if self.config.embedding_layer_norm:
1403
+ x = self.transformer.emb_norm(x)
1404
+
1405
+ if not (self.config.alibi or self.config.rope):
1406
+ # Get positional embeddings.
1407
+ # shape: (1, seq_len)
1408
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1409
+ # shape: (1, seq_len, d_model)
1410
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1411
+ x = pos_emb + x
1412
+
1413
+ # Apply dropout.
1414
+ # shape: (batch_size, seq_len, d_model)
1415
+ x = self.transformer.emb_drop(x) # type: ignore
1416
+
1417
+ # Transform the attention mask into what the blocks expect.
1418
+ if attention_mask is not None:
1419
+ # shape: (batch_size, 1, 1, seq_len)
1420
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1421
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1422
+
1423
+ # Merge attention mask with attention bias.
1424
+ if (
1425
+ attention_bias is not None
1426
+ or attention_mask is not None
1427
+ or self.config.alibi
1428
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1429
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1430
+ # scores correctly.
1431
+ or past_key_values is not None
1432
+ ):
1433
+ if attention_bias is None and self.config.alibi:
1434
+ attention_bias = get_causal_attention_bias(
1435
+ self.__cache, past_length + seq_len, x.device
1436
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1437
+ elif attention_bias is None:
1438
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1439
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1440
+ attention_bias = attention_bias.to(dtype=torch.float)
1441
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1442
+
1443
+ # Transform to the right shape and data type.
1444
+ mask_len = seq_len
1445
+ if attention_mask is not None:
1446
+ mask_len = attention_mask.shape[-1]
1447
+ elif past_key_values is not None:
1448
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1449
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1450
+
1451
+ # Add in the masking bias.
1452
+ if attention_mask is not None:
1453
+ attention_bias = attention_bias + attention_mask
1454
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1455
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1456
+ # it can produce NaNs.
1457
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1458
+
1459
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1460
+
1461
+ # decoder layers
1462
+ all_hidden_states = []
1463
+
1464
+ # Apply blocks one-by-one.
1465
+ if self.config.block_group_size == 1:
1466
+ for block_idx, block in enumerate(self.transformer.blocks):
1467
+ if output_hidden_states:
1468
+ # add hidden states
1469
+ all_hidden_states.append(x)
1470
+
1471
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1472
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
1473
+ # shape: (batch_size, seq_len, d_model)
1474
+ x, cache = self._activation_checkpoint_fn(
1475
+ block,
1476
+ x,
1477
+ attention_bias=attention_bias,
1478
+ layer_past=layer_past,
1479
+ use_cache=use_cache,
1480
+ max_doc_len=max_doc_len,
1481
+ cu_doc_lens=cu_doc_lens,
1482
+ )
1483
+ else:
1484
+ # shape: (batch_size, seq_len, d_model)
1485
+ x, cache = block(
1486
+ x,
1487
+ attention_bias=attention_bias,
1488
+ layer_past=layer_past,
1489
+ use_cache=use_cache,
1490
+ max_doc_len=max_doc_len,
1491
+ cu_doc_lens=cu_doc_lens,
1492
+ )
1493
+
1494
+ if attn_key_values is not None:
1495
+ assert cache is not None
1496
+ attn_key_values.append(cache)
1497
+ else:
1498
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1499
+ if output_hidden_states:
1500
+ # add hidden states
1501
+ all_hidden_states.append(x)
1502
+
1503
+ layers_past = (
1504
+ None
1505
+ if past_key_values is None
1506
+ else past_key_values[
1507
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1508
+ ]
1509
+ )
1510
+ x, cache = block_group(
1511
+ x,
1512
+ attention_bias=attention_bias,
1513
+ layers_past=layers_past,
1514
+ use_cache=use_cache,
1515
+ max_doc_len=max_doc_len,
1516
+ cu_doc_lens=cu_doc_lens,
1517
+ )
1518
+ if attn_key_values is not None:
1519
+ assert cache is not None
1520
+ attn_key_values.extend(cache)
1521
+
1522
+ if last_logits_only:
1523
+ # shape: (batch_size, 1, d_model)
1524
+ x = x[:, -1, :].unsqueeze(1)
1525
+
1526
+ # Apply final layer norm.
1527
+ # shape: (batch_size, seq_len or 1, d_model)
1528
+ x = self.transformer.ln_f(x) # type: ignore
1529
+ if output_hidden_states:
1530
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1531
+ all_hidden_states.append(x)
1532
+
1533
+ # Get logits.
1534
+ # shape: (batch_size, seq_len or 1, vocab_size)
1535
+ if self.config.weight_tying:
1536
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1537
+ else:
1538
+ logits = self.transformer.ff_out(x) # type: ignore
1539
+ if self.config.scale_logits:
1540
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1541
+
1542
+ return OLMoOutput(
1543
+ logits=logits,
1544
+ attn_key_values=attn_key_values,
1545
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1546
+ )
1547
+
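# A standalone sketch of the `attention_mask` handling in `forward` above: a HF-style 0/1 mask
# is turned into an additive bias, so masked positions receive a very large negative score
# before the softmax. The toy batch below is an assumption.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)  # (batch_size, seq_len)
bias = (1.0 - attention_mask)[:, None, None, :] * torch.finfo(torch.float).min
print(bias.shape)     # torch.Size([1, 1, 1, 4])
print(bias[0, 0, 0])  # tensor([0., 0., 0., -3.4028e+38]) -- the padded position is suppressed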
1548
+ def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
1549
+ if wrap_strategy is None:
1550
+ return None
1551
+
1552
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
1553
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
1554
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1555
+ # but not other linear layers within a block.
1556
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1557
+ # return True in 'recurse' mode for simplicity.
1558
+ size_based_module_to_wrap = {self.transformer.wte}
1559
+ if hasattr(self.transformer, "ff_out"):
1560
+ size_based_module_to_wrap.add(self.transformer.ff_out)
1561
+
1562
+ if wrap_strategy == FSDPWrapStrategy.by_block:
1563
+
1564
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1565
+ del nonwrapped_numel
1566
+ wrap = isinstance(module, OLMoBlock)
1567
+ if recurse:
1568
+ return True
1569
+ else:
1570
+ return wrap
1571
+
1572
+ return fsdp_wrap_fn
1573
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1574
+
1575
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1576
+ del nonwrapped_numel
1577
+ wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
1578
+ if recurse:
1579
+ return True
1580
+ else:
1581
+ return wrap
1582
+
1583
+ return fsdp_wrap_fn
1584
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1585
+ if self.config.block_group_size <= 1:
1586
+ raise OLMoConfigurationError(
1587
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1588
+ )
1589
+
1590
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1591
+ del nonwrapped_numel
1592
+ wrap = isinstance(module, OLMoBlockGroup)
1593
+ if recurse:
1594
+ return True
1595
+ else:
1596
+ return wrap
1597
+
1598
+ return fsdp_wrap_fn
1599
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1600
+ if self.config.block_group_size <= 1:
1601
+ raise OLMoConfigurationError(
1602
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1603
+ )
1604
+
1605
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1606
+ del nonwrapped_numel
1607
+ wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
1608
+ if recurse:
1609
+ return True
1610
+ else:
1611
+ return wrap
1612
+
1613
+ return fsdp_wrap_fn
1614
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
1615
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1616
+
1617
+ return size_based_auto_wrap_policy
1618
+ elif wrap_strategy in {
1619
+ FSDPWrapStrategy.one_in_two,
1620
+ FSDPWrapStrategy.one_in_three,
1621
+ FSDPWrapStrategy.one_in_four,
1622
+ FSDPWrapStrategy.one_in_five,
1623
+ }:
1624
+ c = {
1625
+ FSDPWrapStrategy.one_in_two: 2,
1626
+ FSDPWrapStrategy.one_in_three: 3,
1627
+ FSDPWrapStrategy.one_in_four: 4,
1628
+ FSDPWrapStrategy.one_in_five: 5,
1629
+ }[wrap_strategy]
1630
+
1631
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1632
+ del nonwrapped_numel
1633
+ wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
1634
+ if recurse:
1635
+ return True
1636
+ else:
1637
+ return wrap
1638
+
1639
+ return fsdp_wrap_fn
1640
+ else:
1641
+ raise NotImplementedError(wrap_strategy)
1642
+
1643
+ def num_params(self, include_embedding: bool = True) -> int:
1644
+ """
1645
+ Get the total number of parameters.
1646
+ """
1647
+ params = (np for np in self.named_parameters())
1648
+ if not include_embedding:
1649
+ params = filter( # type: ignore
1650
+ lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
1651
+ params,
1652
+ )
1653
+ return sum(p.numel() for _, p in params)
1654
+
1655
+ @property
1656
+ def num_fwd_flops(self):
1657
+ if self.__num_fwd_flops:
1658
+ return self.__num_fwd_flops
1659
+
1660
+ # embedding table is just a lookup in the forward pass
1661
+ n_params = self.num_params(include_embedding=False)
1662
+ # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
1663
+ # each MAC takes 2 FLOPs, so we multiply by 2, i.e. FLOPs per token is approximately 2 * n_params
1664
+ # this gets us FLOPs / token
1665
+ params_flops_per_token = 2 * n_params
1666
+ # attention adds two matmuls per layer (A = Q @ K^T and out = A @ V) at 2 FLOPs per MAC, hence the 2 * 2 factor below
1667
+ attn_flops_per_token = (
1668
+ self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length)
1669
+ )
1670
+ self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token
1671
+ return self.__num_fwd_flops
1672
+
1673
+ @property
1674
+ def num_bck_flops(self):
1675
+ if self.__num_bck_flops:
1676
+ return self.__num_bck_flops
1677
+
1678
+ n_params = self.num_params()
1679
+ params_flops_per_token = 4 * n_params
1680
+ attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length)
1681
+ self.__num_bck_flops = params_flops_per_token + attn_flops_per_token
1682
+ return self.__num_bck_flops
1683
+
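# A worked example of the per-token forward-FLOPs estimate above, using assumed
# 7B-scale hyperparameters rather than any particular OLMo config.
n_params = 6.9e9                      # non-embedding parameters (assumed)
n_layers, d_model, seq_len = 32, 4096, 2048

fwd_flops_per_token = 2 * n_params + n_layers * 2 * 2 * (d_model * seq_len)
print(f"{fwd_flops_per_token:.2e}")   # roughly 1.49e10 FLOPs per token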
1684
+ def generate(
1685
+ self,
1686
+ input_ids: torch.LongTensor,
1687
+ attention_mask: Optional[torch.Tensor] = None,
1688
+ attention_bias: Optional[torch.Tensor] = None,
1689
+ max_steps: int = 10,
1690
+ beam_size: int = 1,
1691
+ per_node_beam_size: Optional[int] = None,
1692
+ sampler: Optional[Sampler] = None,
1693
+ min_steps: Optional[int] = None,
1694
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
1695
+ constraints: Optional[List[Constraint]] = None,
1696
+ ) -> OLMoGenerateOutput:
1697
+ """
1698
+ Generate token IDs using beam search.
1699
+
1700
+ Note that by default ``beam_size`` is set to 1, which is greedy decoding.
1701
+
1702
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1703
+ :param attention_mask: An optional tensor of shape `(batch_size, seq_len)`, the same
1704
+ as for the forward method.
1705
+ :param attention_bias: A tensor of shape
1706
+ `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
1707
+ the same as for the forward method, except that only this one shape is accepted here.
1708
+
1709
+ For an explanation of the other arguments, see :class:`BeamSearch`.
1710
+ """
1711
+ beam_search = BeamSearch(
1712
+ self.config.eos_token_id,
1713
+ max_steps=max_steps,
1714
+ beam_size=beam_size,
1715
+ per_node_beam_size=per_node_beam_size,
1716
+ sampler=sampler,
1717
+ min_steps=min_steps,
1718
+ final_sequence_scorer=final_sequence_scorer,
1719
+ constraints=constraints,
1720
+ )
1721
+
1722
+ # Validate inputs.
1723
+ batch_size, seq_len = input_ids.shape
1724
+ if attention_mask is not None:
1725
+ assert attention_mask.shape == (batch_size, seq_len)
1726
+ if attention_bias is not None:
1727
+ assert len(attention_bias.shape) == 4
1728
+ assert attention_bias.shape[:2] == (batch_size, 1)
1729
+ assert (
1730
+ seq_len + beam_search.max_steps
1731
+ <= attention_bias.shape[2]
1732
+ == attention_bias.shape[3]
1733
+ <= self.config.max_sequence_length
1734
+ )
1735
+
1736
+ tokens_generated = 0
1737
+
1738
+ def flatten_past_key_values(
1739
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
1740
+ ) -> Dict[str, torch.Tensor]:
1741
+ out = {}
1742
+ for i, (key, value) in enumerate(past_key_values):
1743
+ out[f"past_key_{i}"] = key
1744
+ out[f"past_value_{i}"] = value
1745
+ return out
1746
+
1747
+ def unflatten_past_key_values(
1748
+ past_key_values: Dict[str, torch.Tensor],
1749
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
1750
+ out = []
1751
+ for i in range(self.config.n_layers):
1752
+ past_key = past_key_values[f"past_key_{i}"]
1753
+ past_value = past_key_values[f"past_value_{i}"]
1754
+ out.append((past_key, past_value))
1755
+ return out
1756
+
1757
+ def step(
1758
+ last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
1759
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1760
+ nonlocal tokens_generated
1761
+
1762
+ attention_mask = state.get("attention_mask")
1763
+ attention_bias = state.get("attention_bias")
1764
+
1765
+ if tokens_generated > 0:
1766
+ past_key_values = unflatten_past_key_values(state)
1767
+ input_ids = last_predictions.unsqueeze(1)
1768
+ if attention_mask is not None:
1769
+ group_size = input_ids.shape[0]
1770
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
1771
+ else:
1772
+ past_key_values = None
1773
+ input_ids = state["input_ids"]
1774
+
1775
+ tokens_generated += 1
1776
+
1777
+ # Run forward pass of model to get logits, then normalize to get log probs.
1778
+ output = self(
1779
+ input_ids,
1780
+ attention_mask=attention_mask,
1781
+ attention_bias=attention_bias,
1782
+ past_key_values=past_key_values,
1783
+ use_cache=True,
1784
+ last_logits_only=True,
1785
+ )
1786
+ log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
1787
+
1788
+ # Create new state.
1789
+ state = flatten_past_key_values(output.attn_key_values)
1790
+ if attention_mask is not None:
1791
+ state["attention_mask"] = attention_mask
1792
+ if attention_bias is not None:
1793
+ state["attention_bias"] = attention_bias
1794
+
1795
+ return log_probs, state
1796
+
1797
+ initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this.
1798
+ state: dict[str, torch.Tensor] = {"input_ids": input_ids}
1799
+ if attention_mask is not None:
1800
+ state["attention_mask"] = attention_mask
1801
+ if attention_bias is not None:
1802
+ state["attention_bias"] = attention_bias
1803
+ with torch.no_grad():
1804
+ token_ids, scores = beam_search.search(initial_preds, state, step)
1805
+
1806
+ return OLMoGenerateOutput(
1807
+ token_ids=token_ids, # type: ignore[arg-type]
1808
+ scores=scores, # type: ignore[arg-type]
1809
+ )
1810
+
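# A hedged usage sketch for `generate` above: `model` is assumed to be an
# already-instantiated OLMo, and the token ids are placeholders from a tokenizer.
import torch

input_ids = torch.tensor([[1, 2, 3, 4]])       # (batch_size=1, seq_len=4)
attention_mask = torch.ones_like(input_ids)
out = model.generate(input_ids, attention_mask=attention_mask, max_steps=16, beam_size=4)
print(out.token_ids.shape)  # (batch_size, beam_size, n_generated_steps), per the docstring
print(out.scores.shape)     # (batch_size, beam_size)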
1811
+ @classmethod
1812
+ def from_checkpoint(
1813
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
1814
+ ) -> OLMo:
1815
+ """
1816
+ Load an OLMo model from a checkpoint.
1817
+ """
1818
+ from .util import resource_path
1819
+
1820
+ # Guess checkpoint type.
1821
+ if checkpoint_type is None:
1822
+ try:
1823
+ if resource_path(checkpoint_dir, "model.pt").is_file():
1824
+ checkpoint_type = CheckpointType.unsharded
1825
+ else:
1826
+ checkpoint_type = CheckpointType.sharded
1827
+ except FileNotFoundError:
1828
+ checkpoint_type = CheckpointType.sharded
1829
+
1830
+ # Load config.
1831
+ config_path = resource_path(checkpoint_dir, "config.yaml")
1832
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1833
+
1834
+ if checkpoint_type == CheckpointType.unsharded:
1835
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1836
+ model_config.init_device = "cpu"
1837
+ model = OLMo(model_config)
1838
+
1839
+ # Load state dict directly to target device.
1840
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
1841
+ state_dict = torch.load(state_dict_path, map_location="cpu")
1842
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1843
+ model = model.to(torch.device(device))
1844
+ else:
1845
+ train_config = TrainConfig.load(config_path)
1846
+ if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
1847
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1848
+ load_model_and_optim_state,
1849
+ )
1850
+
1851
+ model_config.init_device = device
1852
+ model = OLMo(model_config)
1853
+ load_model_and_optim_state(checkpoint_dir, model)
1854
+ else:
1855
+ # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
1856
+ from .checkpoint import load_model_state
1857
+
1858
+ # Initialize model on target device. In this case the state dict is loaded in-place
1859
+ # so it's not necessary to start on CPU if the target device is a GPU.
1860
+ model_config.init_device = device
1861
+ model = OLMo(model_config)
1862
+
1863
+ # Load state dict in place.
1864
+ load_model_state(checkpoint_dir, model)
1865
+
1866
+ return model.eval()
1867
+
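# A hedged loading sketch for `from_checkpoint` above; the directory is a placeholder
# path and greedy decoding (beam_size=1) is assumed.
import torch

model = OLMo.from_checkpoint("/path/to/unsharded/checkpoint", device="cpu")  # placeholder path
input_ids = torch.tensor([[1, 2, 3, 4]])  # placeholder ids from a matching tokenizer
with torch.no_grad():
    out = model.generate(input_ids, max_steps=8, beam_size=1)
print(out.token_ids[0, 0])  # greedy continuation of the prompt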
1868
+ def _make_state_dict_compatible(
1869
+ self, state_dict: Dict[str, torch.Tensor]
1870
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
1871
+ """
1872
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
1873
+ be loaded.
1874
+
1875
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
1876
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
1877
+ to make a corresponding optimizer state dict compatible as well.
1878
+ """
1879
+ import re
1880
+ from fnmatch import fnmatch
1881
+
1882
+ new_keys_to_og_keys: Dict[str, str] = {}
1883
+
1884
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1885
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1886
+ # fine without the prefixes. This also simplifies the other steps below.
1887
+ for key in list(state_dict.keys()):
1888
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1889
+ new_keys_to_og_keys[new_key] = key
1890
+
1891
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1892
+ if self.config.block_type == BlockType.sequential:
1893
+ for key in list(state_dict.keys()):
1894
+ if fnmatch(key, "transformer.*.norm.weight"):
1895
+ tensor = state_dict.pop(key)
1896
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1897
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1898
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1899
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1900
+ del new_keys_to_og_keys[key]
1901
+ elif fnmatch(key, "transformer.*.norm.bias"):
1902
+ tensor = state_dict.pop(key)
1903
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1904
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1905
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1906
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1907
+ del new_keys_to_og_keys[key]
1908
+
1909
+ # For loading a state dict that was saved with a different `block_group_size`.
1910
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1911
+ state_dict_block_group_size = len(
1912
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1913
+ )
1914
+ else:
1915
+ state_dict_block_group_size = 1
1916
+ if self.config.block_group_size != state_dict_block_group_size:
1917
+ log.info(
1918
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1919
+ f"group size {self.config.block_group_size}"
1920
+ )
1921
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1922
+ # and then (re-)group them into the right block sizes.
1923
+ if state_dict_block_group_size > 1:
1924
+ for key in list(state_dict.keys()):
1925
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1926
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1927
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1928
+ state_dict[
1929
+ (
1930
+ new_key := key.replace(
1931
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1932
+ )
1933
+ )
1934
+ ] = state_dict.pop(key)
1935
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1936
+
1937
+ if self.config.block_group_size > 1:
1938
+ # Group the state dict blocks into the right block size.
1939
+ for key in list(state_dict.keys()):
1940
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1941
+ block_idx = int(m.group(1))
1942
+ group_idx, group_block_idx = (
1943
+ block_idx // self.config.block_group_size,
1944
+ block_idx % self.config.block_group_size,
1945
+ )
1946
+ state_dict[
1947
+ (
1948
+ new_key := key.replace(
1949
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1950
+ )
1951
+ )
1952
+ ] = state_dict.pop(key)
1953
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1954
+
1955
+ og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
1956
+ for new_key, og_key in new_keys_to_og_keys.items():
1957
+ og_keys_to_new[og_key].add(new_key)
1958
+
1959
+ return state_dict, og_keys_to_new
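# A standalone sketch of the block regrouping performed in `_make_state_dict_compatible`
# above: flat "blocks.N." keys are mapped to "block_groups.G.I." keys for an assumed
# group size of 2; the key names below are illustrative.
import re

block_group_size = 2
keys = ["transformer.blocks.0.attn_out.weight", "transformer.blocks.3.ff_out.weight"]
for key in keys:
    match = re.match(r"transformer\.blocks\.(\d+)\..*", key)
    block_idx = int(match.group(1))
    group_idx, group_block_idx = divmod(block_idx, block_group_size)
    new_key = key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
    print(key, "->", new_key)
# transformer.blocks.0.attn_out.weight -> transformer.block_groups.0.0.attn_out.weight
# transformer.blocks.3.ff_out.weight -> transformer.block_groups.1.1.ff_out.weight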
modeling_fan.py ADDED
@@ -0,0 +1,271 @@
1
+ import logging
2
+ from dataclasses import fields
3
+ from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, Union, cast
4
+
5
+ import torch
6
+ from transformers import PreTrainedModel
7
+ from transformers.cache_utils import Cache
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.models.auto import AutoModelForCausalLM
10
+
11
+ from .config import ActivationCheckpointingStrategy, ModelConfig
12
+ from .model import OLMo
13
+
14
+ from .configuration_olmo import OLMoConfig
15
27
+ log = logging.getLogger(__name__)
28
+
29
+
30
+ def create_model_config_from_pretrained_config(config: OLMoConfig):
31
+ """
32
+ Utility function that builds an OLMo `ModelConfig` from a Hugging Face `OLMoConfig` by copying the matching fields and resolving the attention implementation.
33
+ """
34
+
35
+ kwargs = {}
36
+ for field in fields(ModelConfig):
37
+ kwargs[field.name] = getattr(config, field.name)
38
+
39
+ model_config = ModelConfig(**kwargs)
40
+
41
+ # Handle flash attention settings
42
+ if config._attn_implementation == "flash_attention_2":
43
+ model_config.flash_attention = True
44
+ elif config._attn_implementation in ("eager", "sdpa"):
45
+ model_config.flash_attention = False
46
+ else:
47
+ raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")
48
+
49
+ return model_config
50
+
51
+
52
+ class OLMoForCausalLM(PreTrainedModel):
53
+ """
54
+ Extremely barebones HF model wrapper.
55
+ """
56
+
57
+ config_class = OLMoConfig
58
+ base_model_prefix = "model"
59
+ _no_split_modules = ["OLMoBlock"]
60
+ _supports_flash_attn_2 = True
61
+ _supports_sdpa = True
62
+ supports_gradient_checkpointing = True
63
+
64
+ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
65
+ super().__init__(config)
66
+ self._gradient_checkpointing_func: Optional[Callable] = None
67
+ self._gradient_checkpointing = False
68
+
69
+ if not model:
70
+ model_config = create_model_config_from_pretrained_config(config)
71
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
72
+ model_config.init_device = "cpu"
73
+ self.model = OLMo(model_config, init_params=init_params)
74
+ else:
75
+ self.model = model
76
+
77
+ @property
78
+ def gradient_checkpointing(self) -> bool:
79
+ return self._gradient_checkpointing
80
+
81
+ @gradient_checkpointing.setter
82
+ def gradient_checkpointing(self, enabled: bool):
83
+ if self._gradient_checkpointing == enabled:
84
+ return
85
+
86
+ # HF does not specify a way to pass checkpointing strategies, so we pick
87
+ # whole layer as our strategy. We can make this configurable later if needed.
88
+ checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
89
+ self.model.set_activation_checkpointing(
90
+ checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
91
+ )
92
+ self._gradient_checkpointing = enabled
93
+
94
+ def forward(
95
+ self,
96
+ input_ids: torch.LongTensor,
97
+ inputs_embeds: Optional[torch.FloatTensor] = None,
98
+ attention_mask: Optional[torch.Tensor] = None,
99
+ attention_bias: Optional[torch.Tensor] = None,
100
+ # past_key_values: Optional[List[torch.FloatTensor]] = None,
101
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
102
+ labels: Optional[torch.LongTensor] = None,
103
+ use_cache: Optional[bool] = None,
104
+ output_attentions: Optional[bool] = None,
105
+ output_hidden_states: Optional[bool] = None,
106
+ return_dict: Optional[bool] = None,
107
+ cache_position: Optional[
108
+ Cache
109
+ ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
110
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
111
+ if use_cache is None:
112
+ use_cache = self.config.use_cache
113
+
114
+ if output_attentions:
115
+ raise ValueError("output_attentions is not yet supported in OLMo")
116
+
117
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
118
+
119
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
120
+ outputs = self.model.forward(
121
+ input_ids=input_ids,
122
+ input_embeddings=inputs_embeds,
123
+ attention_mask=attention_mask,
124
+ attention_bias=attention_bias,
125
+ past_key_values=past_key_values,
126
+ use_cache=use_cache,
127
+ output_hidden_states=output_hidden_states,
128
+ )
129
+
130
+ logits = outputs.logits
131
+ hidden_states = outputs.hidden_states
132
+
133
+ loss = None
134
+ if labels is not None:
135
+ # Shift so that tokens < n predict n
136
+ shift_logits = logits[..., :-1, :].contiguous()
137
+ shift_labels = labels[..., 1:].contiguous()
138
+ # Flatten the tokens
139
+ loss_fct = torch.nn.CrossEntropyLoss()
140
+ shift_logits = shift_logits.view(-1, self.config.embedding_size)
141
+ shift_labels = shift_labels.view(-1)
142
+ # Enable model parallelism
143
+ shift_labels = shift_labels.to(shift_logits.device)
144
+ loss = loss_fct(shift_logits, shift_labels)
145
+
146
+ if not return_dict:
147
+ output = (logits,) + outputs[1:]
148
+ return (loss,) + output if loss is not None else output
149
+
150
+ return CausalLMOutputWithPast(
151
+ loss=loss,
152
+ logits=logits,
153
+ past_key_values=outputs.attn_key_values,
154
+ hidden_states=hidden_states,
155
+ )
156
+
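# A standalone sketch of the causal-LM label shift used in `forward` above: the logits
# at position t are scored against the token at position t + 1 (toy shapes, random data).
import torch

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)         # (batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (2, 5))  # (batch_size, seq_len)

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = torch.nn.CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())  # scalar cross-entropy over the 2 * 4 shifted positions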
157
+ def can_generate(self) -> bool:
158
+ return True
159
+
160
+ def prepare_inputs_for_generation(
161
+ self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
162
+ ):
163
+ if past_key_values:
164
+ # With a KV cache present, only the last generated token needs to be processed by the model.
165
+ input_ids = input_ids[:, -1:]
166
+ model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
167
+
168
+ model_inputs.update(kwargs)
169
+ model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
170
+ return model_inputs
171
+
172
+ # TODO: these are required to make the implementation complete.
173
+ # def resize_position_embeddings(self, new_num_position_embeddings: int):
174
+ # pass
175
+ #
176
+ # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
177
+ # pass
178
+ #
179
+ # def _reorder_cache(self, past_key_values, beam_idx):
180
+ # pass
181
+
182
+ def get_input_embeddings(self) -> torch.nn.Module:
183
+ return self.model.transformer.wte
184
+
185
+ def set_input_embeddings(self, value: torch.nn.Module):
186
+ self.model.transformer.wte = value
187
+
188
+ def get_output_embeddings(self):
189
+ if self.config.weight_tying:
190
+ return self.model.transformer.wte
191
+ else:
192
+ return self.model.transformer.ff_out
193
+
194
+ def set_output_embeddings(self, value: torch.nn.Module):
195
+ if self.config.weight_tying:
196
+ self.model.transformer.wte = value
197
+ else:
198
+ self.model.transformer.ff_out = value
199
+
200
+ def tie_weights(self):
201
+ """
202
+ This function is intentionally left as a no-op.
203
+
204
+ Weight tying is handled as follows:
205
+ - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
206
+ See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
207
+ - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
208
+ See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
209
+
210
+ Therefore, there is no need to explicitly tie the weights in this function.
211
+ """
212
+ pass
213
+
214
+ def resize_token_embeddings(
215
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
216
+ ) -> torch.nn.Embedding:
217
+ """
218
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
219
+
220
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
221
+
222
+ Arguments:
223
+ new_num_tokens (`int`, *optional*):
224
+ The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
225
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
226
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
227
+ pad_to_multiple_of (`int`, *optional*):
228
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
229
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
230
+
231
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
232
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
233
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
234
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
235
+
236
+ Return:
237
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
238
+
239
+ Note:
240
+ This method differs from the base class implementation by resizing the `embedding_size` attribute of the
241
+ model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
242
+ is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
243
+ embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
244
+ """
245
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
246
+ if new_num_tokens is None and pad_to_multiple_of is None:
247
+ return model_embeds
248
+
249
+ # Update base model and current model config
250
+ self.config.embedding_size = model_embeds.weight.shape[0]
251
+ self.model.config.embedding_size = model_embeds.weight.shape[0]
252
+
253
+ # Check if the embedding size is less than the vocab size
254
+ if self.config.embedding_size < self.config.vocab_size:
255
+ warning_message = (
256
+ f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
257
+ f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
258
+ "size is less than or equal to the new token embedding size."
259
+ )
260
+ log.warning(warning_message)
261
+
262
+ # Tie weights again if needed
263
+ self.tie_weights()
264
+
265
+ return model_embeds
266
+
267
+
268
+ # Register the model so that it is available for transformer pipelines, auto-loading, etc.
269
+ # OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
270
+ # may not support the newest architectures we create.
271
+ AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
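# A hedged usage sketch: because the registration call above maps OLMoConfig to
# OLMoForCausalLM, a checkpoint repo that ships this code can be loaded through the
# auto classes. The repo id below is a placeholder, not a real model name.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("your-org/your-olmo-checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("your-org/your-olmo-checkpoint", trust_remote_code=True)
inputs = tokenizer("Language modeling is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))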
optim.py ADDED
@@ -0,0 +1,1040 @@
1
+ import logging
2
+ from abc import ABCMeta, abstractmethod
3
+ from dataclasses import dataclass, replace
4
+ from math import cos, pi, sqrt
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ from torch.distributed.fsdp import FullyShardedDataParallel
11
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
12
+ from torch.optim.optimizer import Optimizer as OptimizerBase
13
+
14
+ from . import LayerNormBase
15
+ from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
16
+ from .torch_util import get_default_device, is_distributed
17
+
18
+ __all__ = [
19
+ "Optimizer",
20
+ "LionW",
21
+ "AdamW",
22
+ "Scheduler",
23
+ "CosWithWarmup",
24
+ "LinearWithWarmup",
25
+ "InvSqrtWithWarmup",
26
+ "MaxScheduler",
27
+ "ConstantScheduler",
28
+ "CosLinearEnvelope",
29
+ "BoltOnWarmupScheduler",
30
+ "build_optimizer",
31
+ "build_scheduler",
32
+ ]
33
+
34
+
35
+ log = logging.getLogger(__name__)
36
+
37
+
38
+ class Optimizer(OptimizerBase):
39
+ def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+ self._record_update_metrics = record_update_metrics
42
+ self._collecting_metrics = False
43
+ self._selective_updates = selective_updates
44
+
45
+ def _clean_param_name(self, name: str) -> str:
46
+ return name.replace("_fsdp_wrapped_module.", "")
47
+
48
+ @torch.no_grad()
49
+ def clip_grads_and_collect_metrics(
50
+ self,
51
+ global_step: int,
52
+ collect_param_metrics: bool = True,
53
+ process_group: Optional[dist.ProcessGroup] = None,
54
+ device: Optional[torch.device] = None,
55
+ ) -> Dict[str, torch.Tensor]:
56
+ """
57
+ Clips gradients for every group that has the field `max_grad_norm`.
58
+ At the same time, collects metrics for each parameter and its gradient.
59
+ """
60
+ self._collecting_metrics = collect_param_metrics
61
+ device = get_default_device() if device is None else device
62
+
63
+ # NOTE (epwalsh): during distributed training we're making an assumption that the order of
64
+ # the param groups and the params within each group are the same across all ranks.
65
+ # This is justified since we initialize the parameter groups in every rank by iterating over
66
+ # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which
67
+ # provides a consistent order.
68
+ # For each parameter (with a gradient) we'll collect:
69
+ # - min, max, avg, norm of the param itself
70
+ # - min, max, avg, norm of the param's gradient
71
+ # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from
72
+ # `self.get_state_for_param()`.
73
+ # Afterwards we'll reduce these all over all ranks.
74
+ per_param_min_metrics: List[torch.Tensor] = []
75
+ per_param_max_metrics: List[torch.Tensor] = []
76
+ per_param_sum_metrics: List[torch.Tensor] = []
77
+ per_param_norm_metrics: List[torch.Tensor] = []
78
+ per_param_numel_metrics: List[torch.Tensor] = []
79
+
80
+ per_param_min_metric_names: List[str] = []
81
+ per_param_max_metric_names: List[str] = []
82
+ per_param_avg_metric_names: List[str] = []
83
+ per_param_norm_metric_names: List[str] = []
84
+
85
+ dst_rank = 0
86
+ if process_group is not None:
87
+ dst_rank = dist.get_global_rank(process_group, 0)
88
+
89
+ #######################################################################
90
+ # part 1: collect metrics locally
91
+ #######################################################################
92
+ for group in self.param_groups:
93
+ for name, p in zip(group["param_names"], group["params"]):
94
+ name = self._clean_param_name(name)
95
+ # Always need to collect the norm of gradients for clipping, even if we're not collecting
96
+ # other metrics.
97
+ tensors: List[Optional[torch.Tensor]] = [p.grad]
98
+ prefixes: List[str] = [f"grad/{name}"]
99
+ if collect_param_metrics:
100
+ state = self.get_state_for_param(p)
101
+ sorted_state_keys = sorted([k for k in state.keys()])
102
+ tensors.extend([p] + [state[key] for key in sorted_state_keys])
103
+ prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys])
104
+ assert len(tensors) == len(prefixes)
105
+
106
+ # Get min, max, avg, and norm for all `tensors` associated with the parameter.
107
+ for x, prefix in zip(tensors, prefixes):
108
+ # grad or state tensors could be none for params that have their shards completely on
109
+ # other ranks.
110
+ if x is not None and x.numel() > 0:
111
+ if collect_param_metrics:
112
+ x_abs = x.abs()
113
+ per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
114
+ per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
115
+ per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
116
+ per_param_numel_metrics.append(
117
+ torch.tensor([x.numel()], device=device, dtype=torch.float32)
118
+ )
119
+ per_param_norm_metrics.append(
120
+ torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
121
+ )
122
+ else:
123
+ if collect_param_metrics:
124
+ per_param_min_metrics.append(
125
+ torch.tensor([float("inf")], device=device, dtype=torch.float32)
126
+ )
127
+ per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
128
+ per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
129
+ per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
130
+ per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
131
+ if collect_param_metrics:
132
+ per_param_min_metric_names.append(f"{prefix}.min")
133
+ per_param_max_metric_names.append(f"{prefix}.max")
134
+ per_param_avg_metric_names.append(f"{prefix}.avg")
135
+ per_param_norm_metric_names.append(f"{prefix}.norm")
136
+
137
+ assert (
138
+ len(per_param_min_metrics)
139
+ == len(per_param_min_metric_names)
140
+ == len(per_param_max_metrics)
141
+ == len(per_param_max_metric_names)
142
+ == len(per_param_sum_metrics)
143
+ == len(per_param_numel_metrics)
144
+ == len(per_param_avg_metric_names)
145
+ )
146
+ assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)
147
+
148
+ def is_grad_norm_metric(metric_name: str) -> bool:
149
+ return metric_name.startswith("grad/") and metric_name.endswith(".norm")
150
+
151
+ #######################################################################
152
+ # part 2: reduce metrics over ranks
153
+ #######################################################################
154
+ param_group_sharded = False
155
+ for group in self.param_groups:
156
+ param_group_sharded = param_group_sharded or group.get("sharded", False)
157
+
158
+ total_grad_norm: torch.Tensor
159
+ per_param_avg_metrics: List[torch.Tensor] = []
160
+ if is_distributed() and param_group_sharded:
161
+ # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
162
+ # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks
163
+ # get the right value for gradient norms so they can clip correctly.
164
+ # Reduce mins.
165
+ if per_param_min_metrics:
166
+ all_mins = torch.cat(per_param_min_metrics).to(device)
167
+ dist.reduce(all_mins, dst_rank, op=dist.ReduceOp.MIN, group=process_group)
168
+ per_param_min_metrics = all_mins.split(1)
169
+ # Reduce maxs.
170
+ if per_param_max_metrics:
171
+ all_maxs = torch.cat(per_param_max_metrics).to(device)
172
+ dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
173
+ per_param_max_metrics = all_maxs.split(1)
174
+ # Reduce sums or just norms.
175
+ all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
176
+ if per_param_sum_metrics and per_param_numel_metrics:
177
+ all_sums = torch.cat(per_param_sum_metrics).to(device)
178
+ all_numels = torch.cat(per_param_numel_metrics).to(device)
179
+ all_sums_norms_numels = torch.cat(
180
+ [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
181
+ )
182
+ dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM, group=process_group)
183
+ all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
184
+ # Get averages.
185
+ # NOTE: could get infs for non-rank0 processes but that's okay.
186
+ per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
187
+ else:
188
+ dist.all_reduce(all_norms, op=dist.ReduceOp.SUM, group=process_group)
189
+ grad_norm_metric_mask = torch.tensor(
190
+ [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
191
+ )
192
+ total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
193
+ per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
194
+ else:
195
+ total_grad_norm = (
196
+ torch.cat(
197
+ [
198
+ m
199
+ for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
200
+ if is_grad_norm_metric(n)
201
+ ]
202
+ )
203
+ ** 2.0
204
+ ).sum() ** 0.5
205
+ per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]
206
+
207
+ assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
208
+
209
+ # Collect all metrics into a single dict.
210
+ all_metrics: Dict[str, torch.Tensor] = {}
211
+ if collect_param_metrics:
212
+ for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics):
213
+ all_metrics[metric_name] = metric.squeeze(0)
214
+ for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics):
215
+ all_metrics[metric_name] = metric.squeeze(0)
216
+ for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics):
217
+ all_metrics[metric_name] = metric.squeeze(0)
218
+
219
+ for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
220
+ all_metrics[metric_name] = metric.squeeze(0)
221
+ all_metrics["total_grad_norm"] = total_grad_norm
222
+
223
+ #######################################################################
224
+ # part 3: clip grads
225
+ #######################################################################
226
+ num_grads_clipped = 0
227
+ num_eligible_grads = 0
228
+ for group in self.param_groups:
229
+ if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
230
+ num_clipped = self._do_adaptive_clipping(
231
+ group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
232
+ )
233
+ elif (max_norm := group.get("max_grad_norm")) is not None:
234
+ num_clipped = self._do_global_fixed_clipping(
235
+ group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
236
+ )
237
+ else:
238
+ # No clipping needed.
239
+ continue
240
+ num_eligible_grads += len(group["params"])
241
+ if num_clipped is not None:
242
+ num_grads_clipped += num_clipped
243
+
244
+ if collect_param_metrics:
245
+ if num_eligible_grads > 0:
246
+ clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
247
+ else:
248
+ clipping_rate = torch.tensor(0.0, device="cpu")
249
+ all_metrics["clipping_rate"] = clipping_rate
250
+
251
+ # total_grad_norm is computed at all steps, even when collect_param_metrics is set to False
252
+ return all_metrics
253
+
254
+ @torch.no_grad()
255
+ def _do_adaptive_clipping(
256
+ self,
257
+ group: Dict[str, Any],
258
+ max_norm_ratio: float,
259
+ global_step: int,
260
+ all_metrics: Dict[str, torch.Tensor],
261
+ collect_param_metrics: bool = True,
262
+ device: Optional[torch.device] = None,
263
+ ) -> Optional[int]:
264
+ """
265
+ Do adaptive gradient clipping on a param group.
266
+
267
+ If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
268
+ """
269
+ device = get_default_device() if device is None else device
270
+ num_grads_clipped = 0
271
+ # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
272
+ # the gradient (a scalar), not to be confused with the exponential average of the gradient.
273
+ # TODO (epwalsh): handle optimizers that don't have betas.
274
+ beta1, beta2 = group["betas"]
275
+ beta = max(beta1, beta2)
276
+ for name, p in zip(group["param_names"], group["params"]):
277
+ name = self._clean_param_name(name)
278
+ grad_norm = all_metrics.get(f"grad/{name}.norm")
279
+ if grad_norm is None:
280
+ continue
281
+
282
+ # Get or initialize the exponential average of grad norm.
283
+ # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
284
+ # even parameters for which the corresponding local shard is empty. This has the potential to
285
+ # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
286
+ # So we should consider changing how we do this at some point so that we don't add any state
287
+ # to parameters for which the local shard is empty. That would probably add extra distributed
288
+ # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
289
+ state = self.state[p]
290
+ grad_norm_exp_avg = state.get("grad_norm_exp_avg")
291
+ if grad_norm_exp_avg is None:
292
+ grad_norm_exp_avg = grad_norm.clone().to(device)
293
+ # We don't want to add anything to `state` until `state` has been initialized, otherwise
294
+ # this will crash some optimizers which rely on checking `len(state)`. The downside here
295
+ # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
296
+ if global_step > 1:
297
+ state["grad_norm_exp_avg"] = grad_norm_exp_avg
298
+
299
+ max_allowed_norm = max_norm_ratio * grad_norm_exp_avg
300
+ clip_coef = max_allowed_norm / (grad_norm + 1e-6)
301
+
302
+ # Clip the gradients and update the exponential average.
303
+ # Note that multiplying by the clamped coefficient is meaningless when it is
304
+ # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
305
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
306
+ if p.grad is not None:
307
+ # p.grad could be none for some ranks when using FSDP.
308
+ p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
309
+
310
+ # Update the exponential average of the norm of the gradient with the clipped norm of the gradient.
311
+ grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta)
312
+ # Alternative: update with the *unclipped* norm of the gradient.
313
+ # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
314
+
315
+ if collect_param_metrics:
316
+ # Can't avoid host-device sync here.
317
+ if clip_coef_clamped < 1.0:
318
+ num_grads_clipped += 1
319
+ all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
320
+ return num_grads_clipped if collect_param_metrics else None
321
+
322
+ @torch.no_grad()
323
+ def _do_global_fixed_clipping(
324
+ self,
325
+ group: Dict[str, Any],
326
+ max_norm: float,
327
+ all_metrics: Dict[str, torch.Tensor],
328
+ collect_param_metrics: bool = True,
329
+ device: Optional[torch.device] = None,
330
+ ) -> Optional[int]:
331
+ """
332
+ Do global fixed gradient clipping on a param group.
333
+
334
+ If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
335
+ """
336
+ device = get_default_device() if device is None else device
337
+ total_grad_norm = all_metrics["total_grad_norm"]
338
+ clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
339
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
340
+ num_grads_clipped: Optional[int] = None
341
+ if collect_param_metrics:
342
+ # Can't avoid host-device sync here.
343
+ if clip_coef_clamped < 1.0:
344
+ num_grads_clipped = len(group["params"])
345
+ for p in group["params"]:
346
+ # Clip the gradients.
347
+ # Note that multiplying by the clamped coefficient is meaningless when it is
348
+ # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
349
+ if p.grad is not None:
350
+ # p.grad could be none for some ranks when using FSDP.
351
+ p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
352
+ return num_grads_clipped
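# A minimal numeric sketch of the fixed-clipping rule above (illustrative only):
# with max_norm=1.0 and a total gradient norm of 4.0 the clip coefficient is
# roughly 0.25, so every gradient is scaled down by about 4x; when the total norm
# is already below max_norm the coefficient clamps to 1.0 and gradients pass
# through unchanged.
_example_total_grad_norm = torch.tensor(4.0)
_example_clip_coef = 1.0 / (_example_total_grad_norm + 1e-6)
_example_clip_coef_clamped = torch.clamp(_example_clip_coef, max=1.0)  # ~0.25 here, 1.0 when norm <= max_norm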
353
+
354
+ def get_post_step_metrics(
355
+ self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
356
+ ) -> Dict[str, torch.Tensor]:
357
+ del module, process_group
358
+ return {}
359
+
360
+ def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
361
+ del param
362
+ return {}
363
+
364
+
365
+ class LionW(Optimizer):
366
+ """
367
+ Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ params,
373
+ lr: float = 1e-4,
374
+ betas: Tuple[float, float] = (0.9, 0.99),
375
+ weight_decay: float = 0.0,
376
+ record_update_metrics: bool = False,
377
+ selective_updates: bool = False,
378
+ device: Optional[torch.device] = None,
379
+ ):
380
+ assert lr > 0.0
381
+ assert all([0.0 <= beta <= 1.0 for beta in betas])
382
+ defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
383
+ super().__init__(
384
+ params, defaults, record_update_metrics=record_update_metrics, selective_updates=selective_updates
385
+ )
386
+ for group in self.param_groups:
387
+ group["initial_lr"] = group["lr"]
388
+ self._update_total_dot_prod: Optional[torch.Tensor] = None
389
+ self._update_total_norm: Optional[torch.Tensor] = None
390
+ self._signed_update_total_norm: Optional[torch.Tensor] = None
391
+ self._device: Optional[torch.device] = device
392
+
393
+ def get_post_step_metrics(
394
+ self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
395
+ ) -> Dict[str, torch.Tensor]:
396
+ assert isinstance(
397
+ module, FSDP
398
+ ), "`get_post_step_metrics` expects module to be FSDP and will not work with other `distributed_strategy`."
399
+
400
+ update_total_dot_prod = self._update_total_dot_prod
401
+ update_total_norm = self._update_total_norm
402
+ signed_update_total_norm = self._signed_update_total_norm
403
+ if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
404
+ return {}
405
+
406
+ self._update_total_dot_prod = None
407
+ self._update_total_norm = None
408
+ self._signed_update_total_norm = None
409
+
410
+ if is_distributed() and isinstance(module, FullyShardedDataParallel):
411
+ # Reduce total dot prod and norms across all ranks.
412
+ update_total_norm = update_total_norm**2.0
413
+ signed_update_total_norm = signed_update_total_norm**2.0
414
+ # Reduce all together to avoid multiple communication calls.
415
+ all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
416
+ # Only need the final result on rank0, since that's where we log from.
417
+ dist.reduce(
418
+ all_together,
419
+ 0 if process_group is None else dist.get_global_rank(process_group, 0),
420
+ group=process_group,
421
+ )
422
+ update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
423
+ update_total_norm = update_total_norm**0.5
424
+ signed_update_total_norm = signed_update_total_norm**0.5
425
+
426
+ update_cos_sim = update_total_dot_prod / torch.max(
427
+ update_total_norm * signed_update_total_norm,
428
+ torch.tensor(1e-8, device=get_default_device() if self._device is None else self._device),
429
+ )
430
+ return {"update_cos_sim": update_cos_sim}
431
+
432
+ @torch.no_grad()
433
+ def step(self, closure=None) -> None:
434
+ if closure is not None:
435
+ with torch.enable_grad():
436
+ closure()
437
+
438
+ update_total_dot_prod: Optional[torch.Tensor] = None
439
+ update_norms: Optional[List[torch.Tensor]] = None
440
+ signed_update_norms: Optional[List[torch.Tensor]] = None
441
+ if self._collecting_metrics and self._record_update_metrics:
442
+ update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
443
+ update_norms = []
444
+ signed_update_norms = []
445
+
446
+ for group in self.param_groups:
447
+ for p in group["params"]:
448
+ grad = p.grad
449
+ if grad is None:
450
+ continue
451
+
452
+ state = self.state[p]
453
+
454
+ # Perform step weight decay
455
+ mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
456
+ p.data.mul_(1 - mask * (group["lr"] * group["weight_decay"]))
457
+
458
+ # State initialization
459
+ if len(state) == 0:
460
+ # Exponential moving average of gradient values
461
+ state["exp_avg"] = torch.zeros_like(p)
462
+
463
+ exp_avg = state["exp_avg"]
464
+ beta1, beta2 = group["betas"]
465
+
466
+ # Weight update
467
+ update = exp_avg * beta1 + grad * (1 - beta1)
468
+ if isinstance(mask, torch.Tensor):
469
+ # When mask isn't a tensor it's just a literal `1` (python int), so there's
470
+ # no point in calling this op.
471
+ update.mul_(mask)
472
+ signed_update = torch.sign(update)
473
+ p.add_(signed_update, alpha=-group["lr"])
474
+
475
+ # Decay the momentum running average coefficient
476
+ exp_avg.mul_(1 - mask * (1 - beta2)).add_(grad, alpha=1 - beta2)
477
+
478
+ # Track dot product and norms of update vs signed update in order to calculate
479
+ # their cosine similarity.
480
+ if (
481
+ update_total_dot_prod is not None
482
+ and update_norms is not None
483
+ and signed_update_norms is not None
484
+ ):
485
+ update_total_dot_prod = update_total_dot_prod.to(update.device)
486
+ update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
487
+ update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
488
+ signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
489
+
490
+ # Compute cosine similarity between update and signed update.
491
+ if update_total_dot_prod is not None and update_norms is not None and signed_update_norms is not None:
492
+ device = get_default_device() if self._device is None else self._device
493
+ self._update_total_dot_prod = update_total_dot_prod.to(device)
494
+ self._update_total_norm = torch.linalg.vector_norm(
495
+ torch.stack(update_norms),
496
+ 2.0,
497
+ dtype=torch.float32,
498
+ ).to(device)
499
+ self._signed_update_total_norm = torch.linalg.vector_norm(
500
+ torch.stack(signed_update_norms),
501
+ 2.0,
502
+ dtype=torch.float32,
503
+ ).to(device)
504
+
505
+
506
+ class AdamW(torch.optim.AdamW, Optimizer):
507
+ def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
508
+ super().__init__(*args, **kwargs)
509
+
510
+ # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
511
+ # won't be called.
512
+ self._record_update_metrics = record_update_metrics
513
+ self._collecting_metrics = False
514
+ self._selective_updates = selective_updates
515
+
516
+ self._step_size_param_names: Optional[List[str]] = None
517
+ self._step_size_norms: Optional[List[torch.Tensor]] = None
518
+ self._step_size_maxs: Optional[List[torch.Tensor]] = None
519
+
520
+ @torch.no_grad()
521
+ def step(self, closure=None) -> None:
522
+ if not (self._record_update_metrics and self._collecting_metrics) and not self._selective_updates:
523
+ return super().step(closure=closure)
524
+
525
+ device = get_default_device()
526
+ param_names = []
527
+ step_size_norms = []
528
+ step_size_maxs = []
529
+ for group in self.param_groups:
530
+ beta1, beta2 = group["betas"]
531
+ lr = group["lr"]
532
+ weight_decay = group["weight_decay"]
533
+ eps = group["eps"]
534
+ amsgrad = group["amsgrad"]
535
+ for name, param in zip(group["param_names"], group["params"]):
536
+ name = self._clean_param_name(name)
537
+ param_names.append(name)
538
+ grad = param.grad
539
+ if grad is None:
540
+ step_size_norms.append(torch.tensor([0.0], device=device))
541
+ step_size_maxs.append(torch.tensor([0.0], device=device))
542
+ continue
543
+
544
+ state = self.state[param]
545
+ # init state if needed
546
+ if len(state) == 0:
547
+ state["step"] = (
548
+ torch.zeros((), dtype=torch.float32, device=param.device)
549
+ if group["capturable"] or group["fused"]
550
+ else torch.tensor(0.0, dtype=torch.float32)
551
+ )
552
+ # Exponential moving average of gradient values
553
+ state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
554
+ # Exponential moving average of squared gradient values
555
+ state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
556
+ if amsgrad:
557
+ # Maintains max of all exp. moving avg. of sq. grad. values
558
+ state["max_exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
559
+
560
+ exp_avg = state["exp_avg"]
561
+ exp_avg_sq = state["exp_avg_sq"]
562
+ step_t = state["step"]
563
+
564
+ # Update step.
565
+ step_t += 1
566
+
567
+ # Perform step weight decay.
568
+ mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
569
+ param.mul_(1 - mask * (lr * weight_decay))
570
+
571
+ # Decay the first and second moment running average coefficient.
572
+ exp_avg.lerp_(grad, mask * (1 - beta1))
573
+ exp_avg_sq.mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)
574
+
575
+ step = step_t.item()
576
+
577
+ bias_correction1 = 1 - beta1**step
578
+ bias_correction2 = 1 - beta2**step
579
+
580
+ step_size = lr / bias_correction1
581
+
582
+ bias_correction2_sqrt = sqrt(bias_correction2)
583
+
584
+ if amsgrad:
585
+ max_exp_avg_sq = state["max_exp_avg_sq"]
586
+ # Maintains the maximum of all 2nd moment running avg. till now
587
+ torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
588
+
589
+ # Use the max. for normalizing running avg. of gradient
590
+ denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
591
+ else:
592
+ denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
593
+
594
+ update = -step_size * torch.div(exp_avg, denom)
595
+ if isinstance(mask, torch.Tensor):
596
+ # When mask isn't a tensor it's just a literal `1` (python int), so there's
597
+ # no point in calling this op.
598
+ update.mul_(mask)
599
+ param.add_(update)
600
+ step_size_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32).unsqueeze(0))
601
+ step_size_maxs.append(update.abs().max().unsqueeze(0))
602
+
603
+ self._step_size_param_names = param_names
604
+ self._step_size_norms = step_size_norms
605
+ self._step_size_maxs = step_size_maxs
606
+
607
+ def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
608
+ return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")} # type: ignore
609
+
610
+ def get_post_step_metrics(
611
+ self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
612
+ ) -> Dict[str, torch.Tensor]:
613
+ if not (self._record_update_metrics and self._collecting_metrics):
614
+ return {}
615
+ else:
616
+ device = get_default_device()
617
+ dst_rank = 0
618
+ if process_group is not None:
619
+ dst_rank = dist.get_global_rank(process_group, 0)
620
+ param_names = self._step_size_param_names
621
+ step_size_norms = self._step_size_norms
622
+ step_size_maxs = self._step_size_maxs
623
+ assert param_names is not None
624
+ assert step_size_norms is not None
625
+ assert step_size_maxs is not None
626
+
627
+ # Reduce metrics if needed.
628
+ if is_distributed() and isinstance(module, FullyShardedDataParallel):
629
+ # Reduce norms.
630
+ all_norms = torch.cat(step_size_norms).to(device) ** 2.0
631
+ dist.reduce(all_norms, dst_rank, op=dist.ReduceOp.SUM, group=process_group)
632
+ step_size_norms = (all_norms ** (0.5)).squeeze(0).split(1)
633
+
634
+ # Reduce maxs.
635
+ all_maxs = torch.cat(step_size_maxs).to(device)
636
+ dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
637
+ step_size_maxs = all_maxs.split(1)
638
+
639
+ metrics = {}
640
+ for param_name, step_size_norm, step_size_max in zip(param_names, step_size_norms, step_size_maxs): # type: ignore[arg-type]
641
+ metrics[f"step/{param_name}.norm"] = step_size_norm.squeeze(0)
642
+ metrics[f"step/{param_name}.max"] = step_size_max.squeeze(0)
643
+
644
+ self._step_size_param_names = None
645
+ self._step_size_norms = None
646
+ self._step_size_maxs = None
647
+ return metrics
648
+
649
+
650
+ @dataclass
651
+ class Scheduler(metaclass=ABCMeta):
652
+ # NOTE: these fields are not given default values because otherwise dataclasses complains
653
+ # about how the scheduler subclasses are defined.
654
+ grad_clip_warmup_steps: Optional[int]
655
+ grad_clip_warmup_factor: Optional[float]
656
+ warmup_min_lr: Optional[float]
657
+
658
+ @abstractmethod
659
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
660
+ raise NotImplementedError
661
+
662
+ def _get_max_grad_norm_coeff(
663
+ self, initial_value: Optional[float], step: int, max_steps: int
664
+ ) -> Optional[float]:
665
+ del max_steps # might need this in the future, but for now I just wanted to match the API of `get_lr()`.
666
+ if initial_value is None:
667
+ return None
668
+ elif (
669
+ self.grad_clip_warmup_steps is None
670
+ or self.grad_clip_warmup_factor is None
671
+ or step > self.grad_clip_warmup_steps
672
+ ):
673
+ return initial_value
674
+ else:
675
+ return self.grad_clip_warmup_factor * initial_value
676
+
677
+ def get_max_grad_norm(
678
+ self, initial_max_grad_norm: Optional[float], step: int, max_steps: int
679
+ ) -> Optional[float]:
680
+ return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps)
681
+
682
+ def get_max_grad_norm_ratio(
683
+ self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int
684
+ ) -> Optional[float]:
685
+ return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps)
686
+
687
+ def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float:
688
+ warmup_min_lr = self.warmup_min_lr if self.warmup_min_lr is not None else initial_lr * 0.10
689
+ assert 0 <= warmup_min_lr < initial_lr
690
+ return warmup_min_lr + (initial_lr - warmup_min_lr) * min(step, warmup_steps) / warmup_steps
691
+
692
+
693
+ @dataclass
694
+ class CosWithWarmup(Scheduler):
695
+ warmup_steps: int
696
+ alpha_f: float = 0.1
697
+ t_max: Optional[int] = None
698
+
699
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
700
+ max_steps = max_steps if self.t_max is None else self.t_max
701
+ eta_min = initial_lr * self.alpha_f
702
+ if step < self.warmup_steps:
703
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
704
+ elif step >= max_steps:
705
+ return eta_min
706
+ else:
707
+ step = step - self.warmup_steps
708
+ max_steps = max_steps - self.warmup_steps
709
+ return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
710
+
711
+
712
+ @dataclass
713
+ class LinearWithWarmup(Scheduler):
714
+ warmup_steps: int
715
+ alpha_f: float = 0.1
716
+ t_max: Optional[int] = None
717
+
718
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
719
+ max_steps = max_steps if self.t_max is None else self.t_max
720
+ eta_min = initial_lr * self.alpha_f
721
+ if step < self.warmup_steps:
722
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
723
+ elif step >= max_steps:
724
+ return eta_min
725
+ else:
726
+ step = step - self.warmup_steps
727
+ max_steps = max_steps - self.warmup_steps
728
+ return initial_lr - (initial_lr - eta_min) * (step / max_steps)
729
+
730
+
731
+ @dataclass
732
+ class InvSqrtWithWarmup(Scheduler):
733
+ warmup_steps: int
734
+
735
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
736
+ if step < self.warmup_steps:
737
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
738
+ del max_steps
739
+ return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step))
740
+
741
+
742
+ @dataclass
743
+ class MaxScheduler(Scheduler):
744
+ sched1: Scheduler
745
+ sched2: Scheduler
746
+
747
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
748
+ return max(
749
+ self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps)
750
+ )
751
+
752
+
753
+ @dataclass
754
+ class BoltOnWarmupScheduler(Scheduler):
755
+ inner: Scheduler
756
+ warmup_start: int
757
+ warmup_end: int
758
+
759
+ @classmethod
760
+ def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler":
761
+ return cls(
762
+ grad_clip_warmup_steps=None,
763
+ grad_clip_warmup_factor=None,
764
+ inner=scheduler,
765
+ warmup_start=warmup_start,
766
+ warmup_end=warmup_end,
767
+ warmup_min_lr=None,
768
+ )
769
+
770
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
771
+ if step < self.warmup_start:
772
+ return 0.0
773
+ if step < self.warmup_end:
774
+ lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
775
+ return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
776
+ else:
777
+ return self.inner.get_lr(initial_lr, step, max_steps)
778
+
779
+ def _get_max_grad_norm_coeff(
780
+ self, initial_value: Optional[float], step: int, max_steps: int
781
+ ) -> Optional[float]:
782
+ return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps)
783
+
784
+
785
+ @dataclass
786
+ class ConstantScheduler(Scheduler):
787
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
788
+ del step, max_steps
789
+ return initial_lr
790
+
791
+
792
+ @dataclass
793
+ class CosLinearEnvelope(Scheduler):
794
+ "Pointwise product of cosine schedule and linear decay; useful during annealing."
795
+ warmup_steps: int
796
+ alpha_f: float = 0.1
797
+ t_max: Optional[int] = None
798
+
799
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
800
+ max_steps = max_steps if self.t_max is None else self.t_max
801
+ eta_min = initial_lr * self.alpha_f
802
+
803
+ if step < self.warmup_steps:
804
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
805
+ if step >= max_steps:
806
+ return eta_min
807
+ else:
808
+ step = step - self.warmup_steps
809
+ max_steps = max_steps - self.warmup_steps
810
+ linear_envelope = 1 - (step / max_steps)
811
+ cosine_schedule = (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
812
+ return eta_min + linear_envelope * cosine_schedule
813
+
814
+
815
+ @dataclass
816
+ class ConstantWithWarmupScheduler(Scheduler):
817
+ warmup_steps: int
818
+
819
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
820
+ if step < self.warmup_steps:
821
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
822
+ del max_steps
823
+ return initial_lr
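# A quick sketch of how the warmup/decay schedulers above behave (illustrative values only):
_example_sched = CosWithWarmup(
    grad_clip_warmup_steps=None,
    grad_clip_warmup_factor=None,
    warmup_min_lr=None,
    warmup_steps=100,
    alpha_f=0.1,
)
# The LR ramps linearly from 10% of the peak LR to the peak over the first 100 steps
# (since warmup_min_lr is None), then follows a cosine decay down to alpha_f * initial_lr
# at max_steps.
_example_lrs = [_example_sched.get_lr(initial_lr=3e-4, step=s, max_steps=1000) for s in (0, 50, 100, 550, 1000)]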
824
+
825
+
826
+ PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
827
+
828
+
829
+ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
830
+ """
831
+ Separate parameters into weight decay and non weight decay groups.
832
+ """
833
+ param_groups: List[Dict[str, Any]]
834
+ param_group_defaults = {
835
+ "sharded": isinstance(model, FullyShardedDataParallel),
836
+ "max_grad_norm": cfg.max_grad_norm,
837
+ "max_grad_norm_ratio": cfg.max_grad_norm_ratio,
838
+ }
839
+
840
+ # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
841
+ decay = set()
842
+ no_decay = set()
843
+ all_params = {}
844
+ for mn, m in model.named_modules():
845
+ for pn, p in m.named_parameters():
846
+ # NOTE: because named_modules and named_parameters are recursive
847
+ # we will see the same tensors p many times, but doing it this way
848
+ # allows us to know which parent module any tensor p belongs to...
849
+ if not p.requires_grad:
850
+ continue
851
+
852
+ fpn = f"{mn}.{pn}" if mn else pn
853
+ all_params[fpn] = p
854
+
855
+ if pn.endswith("bias"):
856
+ if cfg.optimizer.decay_norm_and_bias:
857
+ decay.add(fpn)
858
+ else:
859
+ no_decay.add(fpn)
860
+ elif pn.endswith("weight") and isinstance(m, nn.Linear):
861
+ decay.add(fpn)
862
+ elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)):
863
+ if cfg.optimizer.decay_norm_and_bias:
864
+ decay.add(fpn)
865
+ else:
866
+ no_decay.add(fpn)
867
+ elif pn.endswith("weight") and isinstance(m, nn.Embedding):
868
+ if cfg.optimizer.decay_embeddings:
869
+ decay.add(fpn)
870
+ else:
871
+ no_decay.add(fpn)
872
+
873
+ # Validate that we've considered every parameter
874
+ inter_params = decay & no_decay
875
+ union_params = decay | no_decay
876
+ assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
877
+ assert (
878
+ len(all_params.keys() - union_params) == 0
879
+ ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
880
+
881
+ # Create the pytorch optimizer groups.
882
+ decay_sorted = sorted(list(decay))
883
+ no_decay_sorted = sorted(list(no_decay))
884
+ param_groups = []
885
+ if len(decay_sorted) > 0:
886
+ param_groups.append(
887
+ {
888
+ "params": [all_params[pn] for pn in decay_sorted],
889
+ "param_names": decay_sorted,
890
+ **param_group_defaults,
891
+ }
892
+ )
893
+ if len(no_decay_sorted) > 0:
894
+ param_groups.append(
895
+ {
896
+ "params": [all_params[pn] for pn in no_decay_sorted],
897
+ "param_names": no_decay_sorted,
898
+ "weight_decay": 0.0,
899
+ **param_group_defaults,
900
+ }
901
+ )
902
+
903
+ # Validate fields.
904
+ for group in param_groups:
905
+ for key in PARAM_GROUP_FIELDS:
906
+ assert key in group
907
+
908
+ return param_groups
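# A minimal sketch of how the grouping above plays out on a toy model. The
# SimpleNamespace config below is a hypothetical stand-in for a real TrainConfig,
# so this is kept as an illustrative comment rather than module code:
#
#   from types import SimpleNamespace
#   _cfg = SimpleNamespace(
#       max_grad_norm=1.0,
#       max_grad_norm_ratio=None,
#       optimizer=SimpleNamespace(decay_norm_and_bias=False, decay_embeddings=False),
#   )
#   _model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))
#   _groups = get_param_groups(_cfg, _model)  # type: ignore[arg-type]
#   # -> two groups: the Linear weight lands in the weight-decay group; the Embedding
#   #    and LayerNorm weights plus all biases land in the group with weight_decay=0.0.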
909
+
910
+
911
+ def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
912
+ """
913
+ Make sure old optim state dicts are compatible with new versions.
914
+ """
915
+ if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2:
916
+ assert optimizer.param_groups[1]["weight_decay"] == 0.0
917
+
918
+ # Decay
919
+ decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
920
+ decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"]
921
+
922
+ # No decay.
923
+ no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
924
+ no_decay_param_group["weight_decay"] = 0.0
925
+ no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"]
926
+
927
+ state_dict["param_groups"] = [decay_param_group, no_decay_param_group]
928
+
929
+ assert len(optimizer.param_groups) == len(state_dict["param_groups"])
930
+
931
+ # Make sure:
932
+ # - All required fields are included in the state dict,
933
+ # - And that the values of those fields don't change from what's currently set in the optimizer,
934
+ # since we might have changed those fields on purpose after a restart.
935
+ for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]):
936
+ for key in PARAM_GROUP_FIELDS:
937
+ sd_group[key] = group[key]
938
+
939
+ return state_dict
940
+
941
+
942
+ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
943
+ param_groups = get_param_groups(cfg, model)
944
+ log.info(f"Constructing optimizer with {len(param_groups)} param groups")
945
+ if cfg.optimizer.name == OptimizerType.lionw:
946
+ return LionW(
947
+ param_groups,
948
+ lr=cfg.optimizer.learning_rate,
949
+ betas=cfg.optimizer.betas,
950
+ weight_decay=cfg.optimizer.weight_decay,
951
+ record_update_metrics=cfg.optimizer.record_update_metrics,
952
+ selective_updates=cfg.optimizer.selective_updates,
953
+ )
954
+ elif cfg.optimizer.name == OptimizerType.adamw:
955
+ return AdamW(
956
+ param_groups,
957
+ lr=cfg.optimizer.learning_rate,
958
+ betas=cfg.optimizer.betas,
959
+ weight_decay=cfg.optimizer.weight_decay,
960
+ record_update_metrics=cfg.optimizer.record_update_metrics,
961
+ selective_updates=cfg.optimizer.selective_updates,
962
+ eps=cfg.optimizer.eps,
963
+ )
964
+ else:
965
+ raise NotImplementedError
966
+
967
+
968
+ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler:
969
+ sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
970
+ if sched_cfg.name == SchedulerType.cosine_with_warmup:
971
+ return CosWithWarmup(
972
+ grad_clip_warmup_steps=(
973
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
974
+ ),
975
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
976
+ warmup_steps=int(sched_cfg.t_warmup),
977
+ alpha_f=sched_cfg.alpha_f,
978
+ t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
979
+ warmup_min_lr=sched_cfg.warmup_min_lr,
980
+ )
981
+ elif sched_cfg.name == SchedulerType.linear_with_warmup:
982
+ return LinearWithWarmup(
983
+ grad_clip_warmup_steps=(
984
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
985
+ ),
986
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
987
+ warmup_steps=int(sched_cfg.t_warmup),
988
+ alpha_f=sched_cfg.alpha_f,
989
+ t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
990
+ warmup_min_lr=sched_cfg.warmup_min_lr,
991
+ )
992
+ elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
993
+ return InvSqrtWithWarmup(
994
+ grad_clip_warmup_steps=(
995
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
996
+ ),
997
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
998
+ warmup_steps=int(sched_cfg.t_warmup),
999
+ warmup_min_lr=sched_cfg.warmup_min_lr,
1000
+ )
1001
+ elif sched_cfg.name == SchedulerType.max_scheduler:
1002
+ return MaxScheduler(
1003
+ grad_clip_warmup_steps=(
1004
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
1005
+ ),
1006
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
1007
+ sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
1008
+ sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
1009
+ warmup_min_lr=sched_cfg.warmup_min_lr,
1010
+ )
1011
+ elif sched_cfg.name == SchedulerType.constant:
1012
+ return ConstantScheduler(
1013
+ grad_clip_warmup_steps=(
1014
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
1015
+ ),
1016
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
1017
+ warmup_min_lr=sched_cfg.warmup_min_lr,
1018
+ )
1019
+ elif sched_cfg.name == SchedulerType.cosine_linear_envelope:
1020
+ return CosLinearEnvelope(
1021
+ grad_clip_warmup_steps=(
1022
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
1023
+ ),
1024
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
1025
+ warmup_steps=int(sched_cfg.t_warmup),
1026
+ alpha_f=sched_cfg.alpha_f,
1027
+ t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
1028
+ warmup_min_lr=sched_cfg.warmup_min_lr,
1029
+ )
1030
+ elif sched_cfg.name == SchedulerType.constant_with_warmup:
1031
+ return ConstantWithWarmupScheduler(
1032
+ grad_clip_warmup_steps=(
1033
+ None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
1034
+ ),
1035
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
1036
+ warmup_min_lr=sched_cfg.warmup_min_lr,
1037
+ warmup_steps=int(sched_cfg.t_warmup),
1038
+ )
1039
+ else:
1040
+ raise NotImplementedError
safetensors_util.py ADDED
@@ -0,0 +1,81 @@
1
+ import base64
2
+ import pickle
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ import safetensors.torch
7
+ import torch
8
+
9
+ from olmo.aliases import PathOrStr
10
+
11
+ __all__ = [
12
+ "state_dict_to_safetensors_file",
13
+ "safetensors_file_to_state_dict",
14
+ ]
15
+
16
+
17
+ @dataclass(eq=True, frozen=True)
18
+ class STKey:
19
+ keys: Tuple
20
+ value_is_pickled: bool
21
+
22
+
23
+ def encode_key(key: STKey) -> str:
24
+ b = pickle.dumps((key.keys, key.value_is_pickled))
25
+ b = base64.urlsafe_b64encode(b)
26
+ return str(b, "ASCII")
27
+
28
+
29
+ def decode_key(key: str) -> STKey:
30
+ b = base64.urlsafe_b64decode(key)
31
+ keys, value_is_pickled = pickle.loads(b)
32
+ return STKey(keys, value_is_pickled)
33
+
34
+
35
+ def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
36
+ result = {}
37
+ for key, value in d.items():
38
+ if isinstance(value, torch.Tensor):
39
+ result[STKey((key,), False)] = value
40
+ elif isinstance(value, dict):
41
+ value = flatten_dict(value)
42
+ for inner_key, inner_value in value.items():
43
+ result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
44
+ else:
45
+ pickled = bytearray(pickle.dumps(value))
46
+ pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
47
+ result[STKey((key,), True)] = pickled_tensor
48
+ return result
49
+
50
+
51
+ def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
52
+ result: Dict = {}
53
+
54
+ for key, value in d.items():
55
+ if key.value_is_pickled:
56
+ value = pickle.loads(value.numpy().data)
57
+
58
+ target_dict = result
59
+ for k in key.keys[:-1]:
60
+ new_target_dict = target_dict.get(k)
61
+ if new_target_dict is None:
62
+ new_target_dict = {}
63
+ target_dict[k] = new_target_dict
64
+ target_dict = new_target_dict
65
+ target_dict[key.keys[-1]] = value
66
+
67
+ return result
68
+
69
+
70
+ def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
71
+ state_dict = flatten_dict(state_dict)
72
+ state_dict = {encode_key(k): v for k, v in state_dict.items()}
73
+ safetensors.torch.save_file(state_dict, filename)
74
+
75
+
76
+ def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
77
+ if map_location is None:
78
+ map_location = "cpu"
79
+ state_dict = safetensors.torch.load_file(filename, device=map_location)
80
+ state_dict = {decode_key(k): v for k, v in state_dict.items()}
81
+ return unflatten_dict(state_dict)
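A short round-trip sketch for the two helpers above (assumes they are imported from this module; the file name is arbitrary):

import torch

nested = {"model": {"w": torch.zeros(2, 2)}, "step": 7}  # non-tensor leaves get pickled into uint8 tensors
state_dict_to_safetensors_file(nested, "example.safetensors")
restored = safetensors_file_to_state_dict("example.safetensors")
assert torch.equal(restored["model"]["w"], nested["model"]["w"]) and restored["step"] == 7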
torch_util.py ADDED
@@ -0,0 +1,158 @@
1
+ import gc
2
+ import os
3
+ from typing import Optional, TypeVar
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ def seed_all(seed: int):
12
+ """Seed all rng objects."""
13
+ import random
14
+
15
+ import numpy as np
16
+
17
+ if seed < 0 or seed > 2**32 - 1:
18
+ raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
19
+ random.seed(seed)
20
+ np.random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ # torch.manual_seed may call manual_seed_all but calling it again here
23
+ # to make sure it gets called at least once
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ def is_distributed() -> bool:
28
+ return dist.is_available() and dist.is_initialized()
29
+
30
+
31
+ def get_node_rank() -> int:
32
+ return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
33
+
34
+
35
+ def get_world_size() -> int:
36
+ if is_distributed():
37
+ return dist.get_world_size()
38
+ else:
39
+ return 1
40
+
41
+
42
+ def get_local_world_size() -> int:
43
+ return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
44
+
45
+
46
+ def get_global_rank() -> int:
47
+ if is_distributed():
48
+ return int(os.environ.get("RANK") or dist.get_rank())
49
+ else:
50
+ return 0
51
+
52
+
53
+ def get_local_rank() -> int:
54
+ return int(os.environ.get("LOCAL_RANK") or 0)
55
+
56
+
57
+ def get_fs_local_rank() -> int:
58
+ """Get the local rank per filesystem, meaning that, regardless of the number of nodes,
59
+ if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
60
+ but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
61
+ """
62
+ if os.environ.get("OLMO_SHARED_FS"):
63
+ return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
64
+ else:
65
+ return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
66
+
67
+
68
+ def move_to_device(o: T, device: torch.device) -> T:
69
+ if isinstance(o, torch.Tensor):
70
+ return o.to(device) # type: ignore[return-value]
71
+ elif isinstance(o, dict):
72
+ return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
73
+ elif isinstance(o, list):
74
+ return [move_to_device(x, device) for x in o] # type: ignore[return-value]
75
+ elif isinstance(o, tuple):
76
+ return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
77
+ else:
78
+ return o
79
+
80
+
81
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
82
+ """
83
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
84
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
85
+ """
86
+ if check_neg_inf:
87
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
88
+ if check_pos_inf:
89
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
90
+
91
+
92
+ def get_default_device() -> torch.device:
93
+ if torch.cuda.is_available() and torch.cuda.is_initialized():
94
+ return torch.device("cuda")
95
+ else:
96
+ return torch.device("cpu")
97
+
98
+
99
+ def barrier() -> None:
100
+ if is_distributed():
101
+ dist.barrier()
102
+
103
+
104
+ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
105
+ """
106
+ Get the peak GPU memory usage in MB across all ranks.
107
+ Only rank 0 will get the final result.
108
+ """
109
+ if not torch.cuda.is_available():
110
+ return None
111
+
112
+ device = torch.device("cuda")
113
+ peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
114
+ if is_distributed():
115
+ peak_mb_tensor = torch.tensor(peak_mb, device=device)
116
+ dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
117
+ peak_mb = peak_mb_tensor.item()
118
+
119
+ if reset:
120
+ # Reset peak stats.
121
+ torch.cuda.reset_max_memory_allocated(device)
122
+
123
+ return peak_mb
124
+
125
+
126
+ V = TypeVar("V", bool, int, float)
127
+
128
+
129
+ def synchronize_value(value: V, device: torch.device) -> V:
130
+ if dist.is_available() and dist.is_initialized():
131
+ value_tensor = torch.tensor(value, device=device)
132
+ dist.broadcast(value_tensor, 0)
133
+ return value_tensor.item() # type: ignore
134
+ else:
135
+ return value
136
+
137
+
138
+ def synchronize_flag(flag: bool, device: torch.device) -> bool:
139
+ return synchronize_value(flag, device)
140
+
141
+
142
+ def gc_cuda():
143
+ gc.collect()
144
+ if torch.cuda.is_available():
145
+ torch.cuda.empty_cache()
146
+
147
+
148
+ def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor:
149
+ """
150
+ Transform a batched tensor of document lengths into a 1D tensor of cumulative document
151
+ lengths for the whole batch.
152
+ """
153
+ return torch.cat(
154
+ [
155
+ torch.tensor([0], dtype=torch.int32, device=doc_lens.device),
156
+ torch.cumsum(doc_lens.masked_select(doc_lens != 0), 0, dtype=torch.int32),
157
+ ]
158
+ )
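A small usage sketch for `get_cumulative_document_lengths` (illustrative values):

import torch

doc_lens = torch.tensor([[3, 5, 0], [2, 4, 6]])  # zero entries are padding and get dropped
get_cumulative_document_lengths(doc_lens)
# -> tensor([ 0,  3,  8, 10, 14, 20], dtype=torch.int32)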
train.py ADDED
@@ -0,0 +1,1384 @@
1
+ from __future__ import annotations
2
+
3
+ import cProfile
4
+ import functools
5
+ import gc
6
+ import logging
7
+ import math
8
+ import os
9
+ import random
10
+ import shutil
11
+ import time
12
+ from collections import deque
13
+ from contextlib import nullcontext
14
+ from dataclasses import dataclass, field
15
+ from itertools import islice
16
+ from pathlib import Path
17
+ from pstats import SortKey
18
+ from typing import Any, Callable, Deque, Dict, List, Optional, TextIO, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch.nn.functional as F
24
+ import torch.utils
25
+ import torch.utils.hooks
26
+ import wandb
27
+ from packaging import version
28
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+ from torch.utils.data import DataLoader
31
+
32
+ from .aliases import PathOrStr
33
+ from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
34
+ from .config import (
35
+ CheckpointType,
36
+ DDPGradSyncMode,
37
+ DistributedStrategy,
38
+ SchedulerUnits,
39
+ ShardedCheckpointerType,
40
+ SpeedMonitorConfig,
41
+ TrainConfig,
42
+ )
43
+ from .data import IterableDataset
44
+ from .eval import Evaluator
45
+ from .exceptions import OLMoConfigurationError
46
+ from .model import OLMo
47
+ from .optim import Optimizer, Scheduler
48
+ from .torch_util import (
49
+ barrier,
50
+ gc_cuda,
51
+ get_fs_local_rank,
52
+ get_global_rank,
53
+ get_world_size,
54
+ move_to_device,
55
+ peak_gpu_memory,
56
+ synchronize_flag,
57
+ synchronize_value,
58
+ )
59
+ from .util import upload
60
+
61
+ __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
62
+
63
+ log = logging.getLogger(__name__)
64
+
65
+
66
+ @dataclass
67
+ class SpeedMonitor:
68
+ cfg: SpeedMonitorConfig
69
+ start_times: Deque[float] = field(default_factory=lambda: deque([]))
70
+ global_total_tokens: int = 0
71
+ total_training_Gflops: float = 0
72
+ device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))
73
+
74
+ def batch_start(
75
+ self,
76
+ global_total_tokens: int,
77
+ device_batch_num_tokens: int,
78
+ num_fwd_flops: int,
79
+ num_bck_flops: int,
80
+ record: bool = True,
81
+ ) -> None:
82
+ self.global_total_tokens = global_total_tokens
83
+ # num_fwd_flops and num_bck_flops from the OLMo model compute FLOPs per token
84
+ # converting to GFLOPs here prevents numerical issues while logging
85
+ self.total_training_Gflops = (num_fwd_flops + num_bck_flops) * global_total_tokens / 1e9
86
+
87
+ if record:
88
+ if len(self.start_times) >= self.cfg.window_size:
89
+ self.start_times.popleft()
90
+ self.device_interval_tokens.popleft()
91
+ self.start_times.append(time.monotonic())
92
+ self.device_interval_tokens.append(device_batch_num_tokens)
93
+
94
+ def reset(self) -> None:
95
+ self.start_times.clear()
96
+ self.device_interval_tokens.clear()
97
+
98
+ def check(self) -> Dict[str, float]:
99
+ metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
100
+
101
+ # plot flops related metrics
102
+ metrics["throughput/total_training_Gflops"] = self.total_training_Gflops
103
+ metrics["throughput/total_training_log_Gflops"] = math.log(self.total_training_Gflops)
104
+
105
+ if self.start_times:
106
+ interval_seconds = time.monotonic() - self.start_times[0]
107
+ interval_batches = len(self.start_times)
108
+ interval_tokens = sum(self.device_interval_tokens)
109
+ metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
110
+ metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
111
+ return metrics
112
+
113
+
114
+ @dataclass
115
+ class LRMonitor:
116
+ optim: torch.optim.Optimizer
117
+
118
+ def check(self) -> Dict[str, float]:
119
+ lrs = [group["lr"] for group in self.optim.param_groups]
120
+ return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}
121
+
122
+
123
+ def cross_entropy_loss(
124
+ logits,
125
+ labels,
126
+ ignore_index: int = -100,
127
+ reduction: str = "mean",
128
+ compute_z_loss: bool = False,
129
+ z_loss_multiplier: float = 1e-4,
130
+ ):
131
+ loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
132
+
133
+ if not compute_z_loss:
134
+ return loss, None
135
+
136
+ z_squared = logits.logsumexp(-1).pow(2)
137
+ if reduction == "mean":
138
+ z_squared = (z_squared * (labels != ignore_index)).mean()
139
+ elif reduction == "sum":
140
+ z_squared = (z_squared * (labels != ignore_index)).sum()
141
+
142
+ z_loss = z_loss_multiplier * z_squared
143
+
144
+ return loss, z_loss
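# A tiny illustration of the z-loss term above (hypothetical shapes): the auxiliary
# loss penalizes large log-partition values, nudging the logits toward staying
# normalized; the combined objective is the CE loss plus the z-loss when enabled.
_example_logits = torch.randn(4, 32)           # 4 positions over a 32-way vocabulary
_example_labels = torch.randint(0, 32, (4,))
_example_ce, _example_z = cross_entropy_loss(_example_logits, _example_labels, compute_z_loss=True)
_example_total = _example_ce + _example_z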
145
+
146
+
147
+ fused_loss_fn: Optional[Callable]
148
+
149
+ try:
150
+ import flash_attn
151
+ from flash_attn.ops.triton.cross_entropy import (
152
+ cross_entropy_loss as flash_cross_entropy_loss, # type: ignore
153
+ )
154
+
155
+ def fused_loss_fn(
156
+ logits,
157
+ labels,
158
+ ignore_index: int = -100,
159
+ reduction: str = "mean",
160
+ compute_z_loss: bool = False,
161
+ z_loss_multiplier: float = 1e-4,
162
+ ):
163
+ # The `ignored_index` parameter of `cross_entropy_loss` was changed to `ignore_index` in v2.5.8 with commit https://github.com/Dao-AILab/flash-attention/commit/ec6d22143b5d375e253b2ebfc563b26a43f43684
164
+ ce_loss_use_ignore_index_param = version.parse(flash_attn.__version__) >= version.parse("2.5.8")
165
+
166
+ if ce_loss_use_ignore_index_param:
167
+ ignore_index_kwarg = {"ignore_index": ignore_index}
168
+ else:
169
+ ignore_index_kwarg = {"ignored_index": ignore_index}
170
+
171
+ loss, z_loss = flash_cross_entropy_loss(
172
+ logits,
173
+ labels,
174
+ label_smoothing=0.0,
175
+ logit_scale=1.0,
176
+ lse_square_scale=z_loss_multiplier,
177
+ inplace_backward=False,
178
+ process_group=None,
179
+ **ignore_index_kwarg,
180
+ )
181
+
182
+ mask = labels != ignore_index
183
+
184
+ if reduction == "mean":
185
+ loss = loss.sum() / mask.sum()
186
+ elif reduction == "sum":
187
+ loss = loss.sum()
188
+ else:
189
+ loss = loss
190
+
191
+ if not compute_z_loss:
192
+ return loss, None
193
+
194
+ if reduction == "mean":
195
+ z_loss = z_loss.sum() / mask.sum()
196
+ elif reduction == "sum":
197
+ z_loss = z_loss.sum()
198
+ else:
199
+ z_loss = z_loss
200
+
201
+ return loss, z_loss
202
+
203
+ except ImportError:
204
+ fused_loss_fn = None
205
+
206
+
207
+ @dataclass
208
+ class Trainer:
209
+ cfg: TrainConfig
210
+ model: OLMo
211
+ dist_model: Union[DDP, FSDP]
212
+ optim: Optimizer
213
+ scheduler: Scheduler
214
+ train_loader: DataLoader
215
+ device: torch.device
216
+ evaluators: List[Evaluator]
217
+ epoch: Optional[int] = None
218
+ global_step: int = 0
219
+ global_train_examples_seen_this_epoch: int = 0
220
+ """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
221
+ the data loader position on restarts."""
222
+ global_train_tokens_seen: int = 0
223
+ """Tracks the global total number of tokens trained on."""
224
+ checkpoints: List[Path] = field(default_factory=list)
225
+ unsharded_checkpoints: List[Path] = field(default_factory=list)
226
+ ephemeral_checkpoints: List[Path] = field(default_factory=list)
227
+ min_train_loss: float = float("inf")
228
+ cur_train_loss: float = float("inf")
229
+ indices_file: Optional[TextIO] = None
230
+ _start_time: float = 0.0
231
+ _gc_init_state: bool = True
232
+ loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
233
+ last_sharded_checkpoint_step: Optional[int] = None
234
+ last_unsharded_checkpoint_step: Optional[int] = None
235
+
236
+ def __post_init__(self):
237
+ if self.cfg.fused_loss:
238
+ if fused_loss_fn is not None:
239
+ self.loss_fn = fused_loss_fn
240
+ else:
241
+ raise NameError("`fused_loss_fn` is not defined. Please ensure that `flash_attn` is installed.")
242
+
243
+ @property
244
+ def dataset(self) -> IterableDataset:
245
+ assert isinstance(self.train_loader.dataset, IterableDataset)
246
+ return self.train_loader.dataset
247
+
248
+ @property
249
+ def tokens_per_batch(self) -> int:
250
+ return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
251
+
252
+ @property
253
+ def batches_per_epoch(self) -> int:
254
+ return self.dataset.total_size // self.cfg.global_train_batch_size
255
+
256
+ @property
257
+ def max_epochs(self) -> int:
258
+ return math.ceil(self.max_steps / self.batches_per_epoch)
259
+
260
+ @property
261
+ def max_steps(self) -> int:
262
+ if isinstance(self.cfg.max_duration, int):
263
+ return self.cfg.max_duration
264
+ elif isinstance(self.cfg.max_duration, str):
265
+ if self.cfg.max_duration.endswith("T"):
266
+ # convert to float *first* to handle scientific notation
267
+ max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
268
+ tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
269
+ steps_remaining = math.ceil(tokens_remaining / self.tokens_per_batch)
270
+ return self.global_step + steps_remaining
271
+ elif self.cfg.max_duration.endswith("ep"):
272
+ max_epochs = int(self.cfg.max_duration[:-2].strip())
273
+ return max_epochs * self.batches_per_epoch
274
+ else:
275
+ # convert to float *first* to handle scientific notation
276
+ return int(float(self.cfg.max_duration))
277
+ else:
278
+ raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
279
+
280
+ @property
281
+ def max_tokens(self) -> int:
282
+ if isinstance(self.cfg.max_duration, int):
283
+ return (
284
+ self.global_train_tokens_seen
285
+ + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
286
+ )
287
+ elif isinstance(self.cfg.max_duration, str):
288
+ if self.cfg.max_duration.endswith("T"):
289
+ # convert to float *first* to handle scientific notation
290
+ return int(float(self.cfg.max_duration[:-1].strip()))
291
+ elif self.cfg.max_duration.endswith("ep"):
292
+ max_epochs = int(self.cfg.max_duration[:-2].strip())
293
+ return max_epochs * self.batches_per_epoch * self.tokens_per_batch
294
+ else:
295
+ # convert to float *first* to handle scientific notation
296
+ return (
297
+ self.global_train_tokens_seen
298
+ + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
299
+ )
300
+ else:
301
+ raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
302
+
303
+ @property
304
+ def scheduler_current(self) -> int:
305
+ if self.cfg.scheduler.units == SchedulerUnits.steps:
306
+ return self.global_step
307
+ elif self.cfg.scheduler.units == SchedulerUnits.tokens:
308
+ return self.global_train_tokens_seen
309
+ else:
310
+ raise NotImplementedError(self.cfg.scheduler.units)
311
+
312
+ @property
313
+ def scheduler_max(self) -> int:
314
+ if self.cfg.scheduler.units == SchedulerUnits.steps:
315
+ return self.max_steps
316
+ elif self.cfg.scheduler.units == SchedulerUnits.tokens:
317
+ return self.max_tokens
318
+ else:
319
+ raise NotImplementedError(self.cfg.scheduler.units)
320
+
321
+ def trainer_state_dict(self) -> Dict[str, Any]:
322
+ return {
323
+ "epoch": self.epoch or 0,
324
+ "global_step": self.global_step,
325
+ "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
326
+ "global_train_tokens_seen": self.global_train_tokens_seen,
327
+ "world_size": get_world_size(),
328
+ "checkpoints": self.checkpoints,
329
+ "unsharded_checkpoints": self.unsharded_checkpoints,
330
+ "ephemeral_checkpoints": self.ephemeral_checkpoints,
331
+ "rng": {
332
+ "python": random.getstate(),
333
+ "numpy": np.random.get_state(),
334
+ "torch": torch.random.get_rng_state(),
335
+ "cuda": torch.cuda.get_rng_state(),
336
+ },
337
+ }
338
+
339
+ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
340
+ # Checkpoint paths.
341
+ self.checkpoints = [
342
+ path
343
+ for path in state_dict["checkpoints"]
344
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
345
+ ]
346
+ self.unsharded_checkpoints = [
347
+ path
348
+ for path in state_dict["unsharded_checkpoints"]
349
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
350
+ ]
351
+ self.ephemeral_checkpoints = [
352
+ path
353
+ for path in state_dict.get("ephemeral_checkpoints", [])
354
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
355
+ ]
356
+
357
+ # Dataset / dataloader position.
358
+ checkpoint_epoch = state_dict.get("epoch") or 0
359
+ self.global_step = state_dict["global_step"]
360
+ self.global_train_examples_seen_this_epoch = state_dict.get(
361
+ "global_train_examples_seen_this_epoch",
362
+ state_dict.get( # for backwards compatibility
363
+ "global_train_examples_seen",
364
+ state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
365
+ ),
366
+ )
367
+ self.global_train_tokens_seen = state_dict.get(
368
+ "global_train_tokens_seen",
369
+ state_dict.get("global_data_step", self.global_step) # for backwards compatibility
370
+ * self.cfg.global_train_batch_size
371
+ * self.cfg.model.max_sequence_length,
372
+ )
373
+
374
+ if not self.cfg.restore_dataloader:
375
+ self.epoch = 0
376
+ self.global_step = 0
377
+ self.global_train_tokens_seen = 0
378
+ self.global_train_examples_seen_this_epoch = 0
379
+ elif self.epoch is None:
380
+ self.epoch = checkpoint_epoch
381
+ elif checkpoint_epoch != self.epoch:
382
+ log.info(f"Starting new epoch (epoch = {self.epoch})")
383
+ self.global_train_examples_seen_this_epoch = 0
384
+
385
+ assert self.epoch is not None
386
+ # Reshuffle dataset if needed.
387
+ if self.dataset.epoch != self.epoch:
388
+ log.info(f"Reshuffling data loader for epoch {self.epoch}...")
389
+ self.dataset.reshuffle(self.epoch)
390
+
391
+ if self.cfg.fast_forward_batches:
392
+ log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
393
+ # Technically we don't "see" these batches that we fast-forward through, but we use
394
+ # this variable to update the position of the dataset so we need to include them here.
395
+ self.global_train_examples_seen_this_epoch += (
396
+ self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
397
+ )
398
+ # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
399
+ # that variable is meant to track the actual number of tokens trained on.
400
+
401
+ if self.global_train_examples_seen_this_epoch > 0:
402
+ assert isinstance(self.dataset, IterableDataset)
403
+ log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
404
+ self.dataset.start_index = self.global_train_examples_seen_this_epoch
405
+
406
+ # Reset learning rate and weight decay to the values from the config, not the checkpoint.
407
+ log.info("Resetting learning rate...")
408
+ new_learning_rate = self.scheduler.get_lr(
409
+ self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
410
+ )
411
+ for group in self.optim.param_groups:
412
+ group["lr"] = new_learning_rate
413
+ group["initial_lr"] = self.cfg.optimizer.learning_rate
414
+ if "weight_decay" in group and group["weight_decay"] > 0.0:
415
+ group["weight_decay"] = self.cfg.optimizer.weight_decay
416
+
417
+ # RNG states.
418
+ if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
419
+ log.info("Restoring RNG states...")
420
+ rng_state = state_dict["rng"]
421
+ self.restore_rng_state(rng_state)
422
+ else:
423
+ log.warning(
424
+ "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
425
+ "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
426
+ "with a different world size. If that's the case you can safely ignore this warning."
427
+ )
428
+
429
+ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
430
+ random.setstate(rng_state["python"])
431
+ np.random.set_state(rng_state["numpy"])
432
+ torch.set_rng_state(rng_state["torch"])
433
+ torch.cuda.set_rng_state(rng_state["cuda"])
434
+
435
+ def _save_checkpoint(
436
+ self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
437
+ ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
438
+ if checkpoint_type == CheckpointType.sharded:
439
+ suffix = ""
440
+ current_checkpoints = self.checkpoints
441
+ link_latest = get_fs_local_rank() == 0
442
+ num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
443
+ elif checkpoint_type == CheckpointType.unsharded:
444
+ suffix = "-unsharded"
445
+ current_checkpoints = self.unsharded_checkpoints
446
+ link_latest = get_global_rank() == 0
447
+ num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
448
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
449
+ suffix = ""
450
+ current_checkpoints = self.ephemeral_checkpoints
451
+ link_latest = get_fs_local_rank() == 0
452
+ num_checkpoints_to_keep = 1
453
+ else:
454
+ raise NotImplementedError(checkpoint_type)
455
+
456
+ # Zero-gradients to avoid gathering them.
457
+ self.optim.zero_grad(set_to_none=True)
458
+
459
+ # Flush data indices file.
460
+ # TODO: upload the indices files?
461
+ if self.indices_file is not None:
462
+ self.indices_file.flush()
463
+
464
+ checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
465
+ remote_checkpoint_dir: Optional[str] = None
466
+ if self.cfg.remote_save_folder is not None:
467
+ remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
468
+ current_checkpoints.append(checkpoint_dir)
469
+
470
+ # Save the checkpoint.
471
+ try:
472
+ checkpointer.save_checkpoint(
473
+ checkpoint_dir,
474
+ self.dist_model,
475
+ self.optim,
476
+ self.trainer_state_dict(),
477
+ upload_to=remote_checkpoint_dir,
478
+ )
479
+ except FileExistsError:
480
+ raise OLMoConfigurationError(
481
+ f"Checkpoint for step {self.global_step} already exists, use --save_overwrite to overwrite it"
482
+ )
483
+
484
+ if link_latest:
485
+ # Link to 'latest'.
486
+ latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
487
+ latest_path.unlink(missing_ok=True)
488
+ try:
489
+ latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
490
+ except FileExistsError:
491
+ # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
492
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
493
+ # file-systems.
494
+ if latest_path.resolve().name != checkpoint_dir.name:
495
+ raise
496
+
497
+ # Remove old checkpoints.
498
+ # For DDP, checkpoint_type being passed to remove_checkpoint is always `unsharded`.
499
+ if num_checkpoints_to_keep > 0:
500
+ while len(current_checkpoints) > num_checkpoints_to_keep:
501
+ self.remove_checkpoint(0, checkpoint_type)
502
+
503
+ barrier()
504
+
505
+ if remote_checkpoint_dir is not None:
506
+ return remote_checkpoint_dir, checkpoint_dir
507
+ else:
508
+ return checkpoint_dir, None
509
+
510
+ def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
511
+ checkpointer = build_sharded_checkpointer(self.cfg)
512
+ result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
513
+ self.last_sharded_checkpoint_step = self.global_step
514
+ return result
515
+
516
+ def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
517
+ checkpointer = build_sharded_checkpointer(self.cfg)
518
+ result = self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)
519
+ self.last_sharded_checkpoint_step = self.global_step
520
+ return result
521
+
522
+ def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
523
+ oldest_checkpoint = checkpoints.pop(idx)
524
+ barrier()
525
+ if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
526
+ shutil.rmtree(oldest_checkpoint, ignore_errors=True)
527
+ latest_path = Path(self.cfg.save_folder) / "latest"
528
+ if latest_path.resolve() == oldest_checkpoint.resolve():
529
+ latest_path.unlink()
530
+ barrier()
531
+
532
+ def remove_sharded_checkpoint(self, idx: int = 0):
533
+ self._remove_sharded_checkpoint(idx, self.checkpoints)
534
+
535
+ def remove_ephemeral_checkpoint(self, idx: int = 0):
536
+ self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)
537
+
538
+ def restore_sharded_checkpoint(
539
+ self,
540
+ load_path: PathOrStr,
541
+ local_cache: Optional[PathOrStr] = None,
542
+ *,
543
+ load_optimizer_state: bool = True,
544
+ load_trainer_state: bool = True,
545
+ sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
546
+ ):
547
+ # Zero-gradients to avoid gathering them.
548
+ self.optim.zero_grad(set_to_none=True)
549
+ checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
550
+ trainer_state = checkpointer.restore_checkpoint(
551
+ load_path,
552
+ self.dist_model,
553
+ self.optim,
554
+ local_cache=local_cache,
555
+ load_optimizer_state=load_optimizer_state,
556
+ )
557
+ if load_trainer_state:
558
+ self.load_trainer_state_dict(trainer_state)
559
+ barrier()
560
+
561
+ def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
562
+ checkpointer = FullCheckpointer(self.cfg)
563
+ result = self._save_checkpoint(checkpointer, CheckpointType.unsharded)
564
+ self.last_unsharded_checkpoint_step = self.global_step
565
+ return result
566
+
567
+ def remove_unsharded_checkpoint(self, idx: int = 0):
568
+ barrier()
569
+ oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
570
+ if get_global_rank() == 0 and oldest_checkpoint.is_dir():
571
+ shutil.rmtree(oldest_checkpoint, ignore_errors=True)
572
+ latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
573
+ if latest_path.resolve() == oldest_checkpoint.resolve():
574
+ latest_path.unlink()
575
+ barrier()
576
+
577
+ def restore_unsharded_checkpoint(
578
+ self,
579
+ load_path: PathOrStr,
580
+ local_cache: Optional[PathOrStr] = None,
581
+ *,
582
+ load_optimizer_state: bool = True,
583
+ load_trainer_state: bool = True,
584
+ ):
585
+ # Zero-gradients to avoid gathering them.
586
+ self.optim.zero_grad(set_to_none=True)
587
+ checkpointer = FullCheckpointer(self.cfg)
588
+ trainer_state = checkpointer.restore_checkpoint(
589
+ load_path,
590
+ self.dist_model,
591
+ self.optim,
592
+ local_cache=local_cache,
593
+ load_optimizer_state=load_optimizer_state,
594
+ )
595
+ if load_trainer_state:
596
+ self.load_trainer_state_dict(trainer_state)
597
+ barrier()
598
+
599
+ def save_checkpoint(
600
+ self, checkpoint_type: CheckpointType = CheckpointType.sharded
601
+ ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
602
+ result: Tuple[PathOrStr, Optional[PathOrStr]]
603
+ if checkpoint_type == CheckpointType.sharded:
604
+ result = self.save_sharded_checkpoint()
605
+ elif checkpoint_type == CheckpointType.unsharded:
606
+ result = self.save_unsharded_checkpoint()
607
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
608
+ result = self.save_ephemeral_checkpoint()
609
+ else:
610
+ raise NotImplementedError(checkpoint_type)
611
+
612
+ gc_cuda()
613
+ return result
614
+
615
+ def restore_checkpoint(
616
+ self,
617
+ load_path: PathOrStr,
618
+ *,
619
+ checkpoint_type: Optional[CheckpointType] = None,
620
+ local_cache: Optional[PathOrStr] = None,
621
+ load_optimizer_state: bool = True,
622
+ load_trainer_state: bool = True,
623
+ sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
624
+ ):
625
+ if checkpoint_type == CheckpointType.unsharded or (
626
+ checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
627
+ ):
628
+ self.restore_unsharded_checkpoint(
629
+ load_path,
630
+ local_cache=local_cache,
631
+ load_optimizer_state=load_optimizer_state,
632
+ load_trainer_state=load_trainer_state,
633
+ )
634
+ elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
635
+ self.restore_sharded_checkpoint(
636
+ load_path,
637
+ local_cache=local_cache,
638
+ load_optimizer_state=load_optimizer_state,
639
+ load_trainer_state=load_trainer_state,
640
+ sharded_checkpointer=sharded_checkpointer,
641
+ )
642
+ elif checkpoint_type is not None:
643
+ raise NotImplementedError(checkpoint_type)
644
+
645
+ gc_cuda()
646
+
647
+ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
648
+ if checkpoint_type == CheckpointType.sharded:
649
+ self.remove_sharded_checkpoint(idx=idx)
650
+ elif checkpoint_type == CheckpointType.unsharded:
651
+ self.remove_unsharded_checkpoint(idx=idx)
652
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
653
+ self.remove_ephemeral_checkpoint(idx=idx)
654
+ else:
655
+ raise NotImplementedError(checkpoint_type)
656
+
657
+ def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
658
+ if (
659
+ self.cfg.module_outputs_save_steps is None
660
+ or self.global_step not in self.cfg.module_outputs_save_steps
661
+ ):
662
+ return []
663
+
664
+ if micro_batch_idx != 0 or get_global_rank() != 0:
665
+ # Hook is currently only used on the first microbatch of rank 0
666
+ return []
667
+
668
+ trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
669
+ if trace_save_folder.exists():
670
+ if self.cfg.save_overwrite:
671
+ shutil.rmtree(trace_save_folder)
672
+ else:
673
+ raise OLMoConfigurationError(
674
+ f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
675
+ )
676
+ trace_save_folder.mkdir(parents=True)
677
+
678
+ def trace_outputs_hook(
679
+ module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
680
+ ) -> None:
681
+ if len(args) == 0:
682
+ log.info("No input args for module %s, output %s", module_name, output)
683
+
684
+ module_input = args[0] if len(args) > 0 else torch.tensor(())
685
+ trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
686
+ trace_save_folder.mkdir(parents=True, exist_ok=True)
687
+
688
+ module_occurence_num = 0
689
+ while (
690
+ module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
691
+ ).exists():
692
+ module_occurence_num += 1
693
+ torch.save(module_input, module_input_filepath)
694
+
695
+ module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
696
+ torch.save(output, module_output_filepath)
697
+
698
+ output_hooks = []
699
+ for module_name, module in self.model.named_modules(prefix="model"):
700
+ output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
701
+
702
+ return output_hooks
703
+
704
+ def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
705
+ # Labels are just input IDs shifted to the left (first item is ignored).
706
+ labels, label_mask, attention_mask, instance_mask = (
707
+ batch["input_ids"].clone(),
708
+ batch.get("label_mask"),
709
+ batch.get("attention_mask"),
710
+ batch.get("instance_mask"),
711
+ )
712
+ if label_mask is not None:
713
+ labels.masked_fill_(~label_mask, -100)
714
+ if attention_mask is not None:
715
+ labels.masked_fill_(attention_mask == 0.0, -100)
716
+ if instance_mask is not None:
717
+ labels.masked_fill_(~instance_mask.unsqueeze(-1), value=-100)
718
+ return labels[..., 1:].contiguous()
719
+
720
+ def model_forward(
721
+ self, batch: Dict[str, Any], loss_reduction: str = "mean", compute_z_loss: bool = False
722
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
723
+ # shape: (batch_size, seq_len, vocab_size)
724
+ logits = self.dist_model(
725
+ input_ids=batch["input_ids"],
726
+ attention_mask=batch.get("attention_mask"),
727
+ attention_bias=batch.get("attention_bias"),
728
+ doc_lens=batch.get("doc_lens"),
729
+ max_doc_lens=batch.get("max_doc_lens"),
730
+ ).logits
731
+ logits_for_loss = logits[..., :-1, :].contiguous()
732
+ # shape: (batch_size * seq_len, vocab_size)
733
+ logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
734
+ # shape: (batch_size, seq_len)
735
+ labels = self.get_labels(batch)
736
+ # shape: (batch_size * seq_len,)
737
+ labels = labels.view(-1)
738
+ ce_loss, z_loss = self.loss_fn(
739
+ logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction, compute_z_loss=compute_z_loss
740
+ )
741
+ if loss_reduction == "none":
742
+ # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
743
+ ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
744
+ if z_loss is not None:
745
+ z_loss = z_loss.view(batch["input_ids"].shape[0], -1)
746
+ return ce_loss, z_loss, logits
747
+
748
+ def train_micro_batch(
749
+ self, micro_batch: Dict[str, Any], batch_size_in_tokens: int
750
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
751
+ ce_loss, z_loss, logits = self.model_forward(
752
+ micro_batch, compute_z_loss=self.cfg.softmax_auxiliary_loss, loss_reduction="sum"
753
+ )
754
+ ce_loss = ce_loss / batch_size_in_tokens
755
+
756
+ # In case this helps with memory utilization.
757
+ del micro_batch
758
+
759
+ # Get loss to optimize for.
760
+ if self.cfg.softmax_auxiliary_loss:
761
+ assert z_loss is not None
762
+ z_loss = z_loss / batch_size_in_tokens
763
+ loss = ce_loss + z_loss
764
+ else:
765
+ loss = ce_loss
766
+
767
+ del logits
768
+
769
+ return loss, ce_loss, z_loss
770
+
771
+ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
772
+ # Split into micro-batches.
773
+ micro_batches = self.split_batch(batch)
774
+ batch_size_in_tokens = batch["input_ids"].numel()
775
+
776
+ # In case this helps with memory utilization.
777
+ del batch
778
+
779
+ ce_batch_loss = torch.tensor(0.0, device=self.device)
780
+ z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
781
+ num_micro_batches = len(micro_batches)
782
+
783
+ for micro_batch_idx, micro_batch in enumerate(micro_batches):
784
+ # setup sync context for DDP for all micro-batches except the last
785
+ grad_sync_context = nullcontext
786
+ if (
787
+ self.cfg.distributed_strategy == DistributedStrategy.ddp
788
+ and self.cfg.ddp is not None
789
+ and self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch
790
+ ):
791
+ if micro_batch_idx != num_micro_batches - 1:
792
+ grad_sync_context = self.dist_model.no_sync
793
+
794
+ # Register output hooks
795
+ output_hooks: List[torch.utils.hooks.RemovableHandle] = []
796
+ output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)
797
+
798
+ with grad_sync_context():
799
+ with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
800
+ # Run forward pass.
801
+ loss, ce_loss, z_loss = self.train_micro_batch(micro_batch, batch_size_in_tokens)
802
+
803
+ # Update overall CE batch loss.
804
+ ce_batch_loss += ce_loss.detach()
805
+
806
+ # Update overall Z batch loss.
807
+ if z_loss is not None:
808
+ assert z_batch_loss is not None
809
+ z_batch_loss += z_loss.detach()
810
+
811
+ # Run backward pass.
812
+ loss.backward()
813
+
814
+ # Remove output hooks
815
+ for hook in output_hooks:
816
+ hook.remove()
817
+
818
+ return ce_batch_loss, z_batch_loss
819
+
820
+ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
821
+ metrics: Dict[str, float] = {}
822
+
823
+ # Write data-indices to file.
824
+ if self.indices_file is not None and "index" in batch:
825
+ indices = "\t".join(str(int(i)) for i in batch["index"])
826
+ self.indices_file.write(f"{self.global_step}\t{indices}\n")
827
+
828
+ # Record how many instances are going to be skipped (masked out).
829
+ if (instance_mask := batch.get("instance_mask")) is not None:
830
+ metrics["train/masked_instances_local_rank"] = (~instance_mask).sum().item()
831
+
832
+ # Zero-gradients.
833
+ self.optim.zero_grad(set_to_none=True)
834
+
835
+ # Move tensors to the right device.
836
+ batch = move_to_device(batch, self.device)
837
+
838
+ # Run forward-backward pass.
839
+ ce_batch_loss, z_batch_loss = self.train_batch(batch)
840
+
841
+ # Collect loss, potentially reducing over all ranks.
842
+ if reduce_global_loss:
843
+ dist.reduce(ce_batch_loss, 0)
844
+ ce_batch_loss.div_(get_world_size())
845
+ if z_batch_loss is not None:
846
+ dist.reduce(z_batch_loss, 0)
847
+ z_batch_loss.div_(get_world_size())
848
+
849
+ # Clip gradient norms and collect param/gradient/optim metrics.
850
+ should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
851
+ optim_metrics = self.optim.clip_grads_and_collect_metrics(
852
+ self.global_step,
853
+ collect_param_metrics=should_log_optim_metrics_this_step,
854
+ # passing this process group here ensures metrics are reduced correctly when we're using
855
+ # HYBRID sharding.
856
+ process_group=self.dist_model.process_group,
857
+ )
858
+
859
+ # Adjust the learning rate.
860
+ for group in self.optim.param_groups:
861
+ # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
862
+ # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
863
+ # the corresponding values from `self.cfg`.
864
+ group["lr"] = self.scheduler.get_lr(
865
+ self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
866
+ )
867
+ group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
868
+ self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
869
+ )
870
+ group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
871
+ self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
872
+ )
873
+
874
+ # Optimizer step.
875
+ self.optim.step()
876
+
877
+ # Collect metrics and check for NaN loss.
878
+ # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
879
+ if torch.isnan(ce_batch_loss):
880
+ raise ValueError("nan loss encountered")
881
+ if z_batch_loss is not None and torch.isnan(z_batch_loss):
882
+ raise ValueError("nan loss encountered")
883
+ for key, value in optim_metrics.items():
884
+ metrics[f"optim/{key}"] = value.item()
885
+ self.cur_train_loss = ce_batch_loss.item()
886
+ self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
887
+ metrics["train/CrossEntropyLoss"] = self.cur_train_loss
888
+ metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
889
+ if z_batch_loss is not None:
890
+ metrics["train/ZLoss"] = z_batch_loss.item()
891
+
892
+ # Maybe collect post-step optimizer-specific metrics.
893
+ if should_log_optim_metrics_this_step:
894
+ optim_metrics = self.optim.get_post_step_metrics(
895
+ self.dist_model, process_group=self.dist_model.process_group
896
+ )
897
+ for key, value in optim_metrics.items():
898
+ metrics[f"optim/{key}"] = value.item()
899
+
900
+ return metrics
901
+
902
+ def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
903
+ with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
904
+ ce_loss, _, logits = self.model_forward(batch, loss_reduction="none")
905
+ return ce_loss.mean(dim=-1), logits
906
+
907
+ def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
908
+ # Move tensors to the right device.
909
+ batch = move_to_device(batch, self.device)
910
+
911
+ # Run forward pass.
912
+ with torch.no_grad(): # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
913
+ ce_loss, logits = self.eval_batch(batch)
914
+
915
+ # Update metrics.
916
+ evaluator.update_metrics(
917
+ batch, ce_loss, logits
918
+ ) # batch includes all keys that the downstream evaluation needs
919
+
920
+ barrier()
921
+
922
+ def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
923
+ microbatch_size = self.cfg.device_train_microbatch_size
924
+ batch_size = batch["input_ids"].shape[0]
925
+ if batch_size <= microbatch_size:
926
+ return [batch]
927
+ else:
928
+ micro_batches = {}
929
+ for key, value in batch.items():
930
+ if isinstance(value, torch.Tensor):
931
+ micro_batches[key] = value.split(microbatch_size, dim=0)
932
+ elif isinstance(value, list):
933
+ micro_batches[key] = [
934
+ value[microbatch_size * i : microbatch_size * i + microbatch_size]
935
+ for i in range(math.ceil(batch_size / microbatch_size))
936
+ ]
937
+ else:
938
+ raise ValueError(f"unexpected item in batch: '{key}={value}'")
939
+ return [
940
+ {key: value[i] for key, value in micro_batches.items()} # type: ignore
941
+ for i in range(len(micro_batches["input_ids"]))
942
+ ]
943
+
944
+ def system_metrics(self) -> Dict[str, float]:
945
+ metrics = {}
946
+ if self.global_step < 3 or self.global_step % 10 == 0:
947
+ peak_gpu_mb = peak_gpu_memory()
948
+ if peak_gpu_mb is not None:
949
+ metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
950
+ return metrics
951
+
952
+ def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
953
+ def format_float(value: float) -> str:
954
+ if value < 0.0001:
955
+ return str(value) # scientific notation
956
+ elif value > 1000:
957
+ return f"{int(value):,d}"
958
+ elif value > 100:
959
+ return f"{value:.1f}"
960
+ elif value > 10:
961
+ return f"{value:.2f}"
962
+ elif value > 1:
963
+ return f"{value:.3f}"
964
+ else:
965
+ return f"{value:.4f}"
966
+
967
+ log.info(
968
+ f"{prefix}\n"
969
+ + "\n".join(
970
+ [
971
+ f" {name}={format_float(value)}"
972
+ for name, value in metrics.items()
973
+ if name == "optim/total_grad_norm"
974
+ or not name.startswith("optim/") # there's too many optimizer metrics
975
+ ]
976
+ )
977
+ )
978
+
979
+ def should_log_optim_metrics_this_step(self) -> bool:
980
+ if self.cfg.wandb is None:
981
+ # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
982
+ # to log to the console.
983
+ return False
984
+ optim_log_interval = self.cfg.optimizer.metrics_log_interval
985
+ if optim_log_interval is None:
986
+ optim_log_interval = self.cfg.wandb.log_interval
987
+ else:
988
+ optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
989
+ return self.global_step % optim_log_interval == 0
990
+
991
+ def should_log_this_step(self) -> bool:
992
+ if self.global_step % self.cfg.console_log_interval == 0:
993
+ return True
994
+ elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
995
+ return True
996
+ else:
997
+ return False
998
+
999
+ def eval(self) -> Dict[str, Any]:
1000
+ # Zero gradients and set model to 'eval' mode.
1001
+ self.optim.zero_grad(set_to_none=True)
1002
+ self.dist_model.eval()
1003
+
1004
+ eval_metrics = {}
1005
+ for evaluator in self.evaluators:
1006
+ log.info(f"Running evaluation for '{evaluator.label}'...")
1007
+
1008
+ # Reset metrics.
1009
+ evaluator.reset_metrics()
1010
+
1011
+ # Initialize data loader iterator.
1012
+ eval_batches = iter(evaluator.eval_loader)
1013
+
1014
+ # Adjust how many batches to evaluate on.
1015
+ num_eval_batches = (
1016
+ evaluator.subset_num_batches
1017
+ if evaluator.subset_num_batches is not None
1018
+ else self.cfg.eval_subset_num_batches
1019
+ )
1020
+ if num_eval_batches > 0:
1021
+ num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
1022
+ eval_batches = islice(eval_batches, num_eval_batches)
1023
+
1024
+ # Run model over batches.
1025
+ for eval_step, eval_batch in enumerate(eval_batches):
1026
+ self.eval_step(eval_batch, evaluator)
1027
+
1028
+ # Log to console.
1029
+ if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
1030
+ log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")
1031
+
1032
+ # Get final metrics.
1033
+ metrics = evaluator.compute_metrics()
1034
+ eval_metrics.update(metrics)
1035
+ self.log_metrics_to_console(f"{evaluator.label}", metrics)
1036
+
1037
+ del eval_batches
1038
+
1039
+ # Eval compiles a bunch more versions, and the result is terrible. This way we get back to zero.
1040
+ if self.cfg.compile is not None:
1041
+ torch.compiler.reset()
1042
+
1043
+ return eval_metrics
1044
+
1045
+ def check_if_cancelled(self) -> Tuple[bool, int]:
1046
+ should_cancel = False
1047
+ cancel_reason: Optional[str] = None
1048
+ extra_steps = 0
1049
+ if get_global_rank() == 0:
1050
+ if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
1051
+ # First check if we've reached the training time limit.
1052
+ should_cancel = True
1053
+ cancel_reason = "time limit reached"
1054
+ extra_steps = self.cfg.extra_steps_after_cancel
1055
+ elif (
1056
+ self.cfg.early_stopping_factor is not None
1057
+ and self.global_step > self.cfg.scheduler.t_warmup
1058
+ and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
1059
+ ):
1060
+ # Next check if early stopping loss criteria is met.
1061
+ should_cancel = True
1062
+ cancel_reason = "early stopping from loss increase"
1063
+ elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
1064
+ # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.
1065
+ # We won't see it in the run object. So we have to use the import/export API to check.
1066
+ from requests.exceptions import RequestException
1067
+ from wandb.errors import CommError
1068
+
1069
+ try:
1070
+ api = wandb.Api(api_key=api_key)
1071
+ run = api.run(wandb.run.path)
1072
+ for tag in run.tags or []:
1073
+ if tag.lower() in {"cancel", "canceled", "cancelled"}:
1074
+ should_cancel = True
1075
+ cancel_reason = "Weights & Biases tag"
1076
+ extra_steps = self.cfg.extra_steps_after_cancel
1077
+ break
1078
+ except (RequestException, CommError):
1079
+ log.info("Failed to check if W&B run is cancelled, continuing run.")
1080
+
1081
+ run_canceled = synchronize_flag(should_cancel, self.device)
1082
+ if run_canceled:
1083
+ extra_steps = synchronize_value(extra_steps, self.device)
1084
+ if cancel_reason is None:
1085
+ if extra_steps > 0:
1086
+ log.warning(f"Run canceled, stopping in {extra_steps} more steps...")
1087
+ else:
1088
+ log.warning("Run canceled")
1089
+ else:
1090
+ if extra_steps > 0:
1091
+ log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
1092
+ else:
1093
+ log.warning(f"Run canceled due to {cancel_reason}")
1094
+
1095
+ return run_canceled, extra_steps
1096
+
1097
+ def fit(self):
1098
+ if self.cfg.stop_after is not None:
1099
+ if self.cfg.stop_at is None:
1100
+ self.cfg.stop_at = self.global_step + self.cfg.stop_after
1101
+ else:
1102
+ self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)
1103
+ if self.cfg.stop_at is None:
1104
+ self.cfg.stop_at = self.max_steps + 10
1105
+
1106
+ self._start_time = time.time()
1107
+ self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.
1108
+
1109
+ # Disable automatic garbage collection, FSDP doesn't work well with it.
1110
+ if self.cfg.gen1_gc_interval is not None:
1111
+ gc.disable()
1112
+
1113
+ if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
1114
+ eval_metrics = self.eval()
1115
+ if wandb.run is not None:
1116
+ wandb.log(eval_metrics, step=self.global_step)
1117
+
1118
+ # Set model to 'train' mode.
1119
+ self.dist_model.train()
1120
+
1121
+ # Initialize monitors.
1122
+ assert self.cfg.device_train_batch_size is not None
1123
+ speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
1124
+ lr_monitor = LRMonitor(self.optim)
1125
+
1126
+ # Log system metrics at the start of training.
1127
+ sys_metrics = self.system_metrics()
1128
+ if sys_metrics:
1129
+ self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
1130
+ if wandb.run is not None:
1131
+ wandb.log(sys_metrics, step=0)
1132
+
1133
+ # Python Profiler stuff
1134
+ if self.cfg.python_profiling:
1135
+ python_profiler = cProfile.Profile()
1136
+ else:
1137
+ python_profiler = None
1138
+
1139
+ # PyTorch Profiler stuff
1140
+ if self.cfg.torch_profiling and get_global_rank() == 0:
1141
+ from torch.profiler import schedule
1142
+
1143
+ profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
1144
+
1145
+ def on_trace_ready(p):
1146
+ profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
1147
+ profiler_output_dir.mkdir(exist_ok=True)
1148
+
1149
+ output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
1150
+ log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
1151
+ output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
1152
+ log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
1153
+
1154
+ p.export_chrome_trace(
1155
+ str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
1156
+ )
1157
+ if self.cfg.remote_save_folder is not None:
1158
+ upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
1159
+ log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
1160
+ upload(trace_path, f"{upload_folder}/{trace_path.name}")
1161
+
1162
+ from torch.profiler import ProfilerActivity
1163
+
1164
+ torch_profiler = torch.profiler.profile(
1165
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1166
+ record_shapes=False,
1167
+ profile_memory=False,
1168
+ with_stack=True,
1169
+ schedule=profiling_schedule,
1170
+ on_trace_ready=on_trace_ready,
1171
+ )
1172
+ del profiling_schedule
1173
+ else:
1174
+ import contextlib
1175
+
1176
+ torch_profiler = contextlib.nullcontext()
1177
+
1178
+ # Train.
1179
+ first_batch: bool = True
1180
+ cancel_initiated: bool = False
1181
+ stop_at: int = self.cfg.stop_at
1182
+ save_checkpoints: bool = True
1183
+
1184
+ with torch_profiler as p:
1185
+ for epoch in range(self.epoch or 0, self.max_epochs):
1186
+ for batch in self.train_loader:
1187
+ # Bookkeeping.
1188
+ # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
1189
+ # batches see the same number of tokens, which should be the case for language model pre-training
1190
+ # (at least when drop_last=True).
1191
+ # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
1192
+ # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
1193
+ # fail loudly.
1194
+ batch_size, seq_len = batch["input_ids"].shape
1195
+ assert seq_len == self.cfg.model.max_sequence_length
1196
+ assert batch_size == self.cfg.device_train_batch_size
1197
+ global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks
1198
+ self.global_step += 1
1199
+ self.global_train_examples_seen_this_epoch += global_batch_size
1200
+ self.global_train_tokens_seen += global_batch_size * seq_len
1201
+ speed_monitor.batch_start(
1202
+ global_total_tokens=self.global_train_tokens_seen,
1203
+ device_batch_num_tokens=batch_size * seq_len, # num tokens in batch for this device
1204
+ # We start monitoring speed after the first batch since the first
1205
+ # batch might be an outlier due to compiling and other initialization overhead.
1206
+ num_fwd_flops=self.model.num_fwd_flops, # this is per token
1207
+ num_bck_flops=self.model.num_bck_flops, # this is per token
1208
+ record=not first_batch,
1209
+ )
1210
+
1211
+ should_log_this_step = self.should_log_this_step()
1212
+
1213
+ # Run train step on batch.
1214
+ metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
1215
+
1216
+ # Maybe collect other metrics.
1217
+ if should_log_this_step:
1218
+ # Speed metrics.
1219
+ metrics.update(speed_monitor.check())
1220
+ # System metrics.
1221
+ metrics.update(self.system_metrics())
1222
+ # Learning rate metrics.
1223
+ metrics.update(lr_monitor.check())
1224
+
1225
+ # Log metrics to console.
1226
+ if self.global_step % self.cfg.console_log_interval == 0:
1227
+ if get_global_rank() == 0:
1228
+ self.log_metrics_to_console(
1229
+ f"[step={self.global_step}/{self.max_steps},epoch={epoch}]",
1230
+ metrics,
1231
+ )
1232
+ else:
1233
+ log.info(f"[step={self.global_step}/{self.max_steps},epoch={epoch}]")
1234
+
1235
+ # Log metrics to W&B.
1236
+ if (
1237
+ wandb.run is not None
1238
+ and self.cfg.wandb is not None
1239
+ and self.global_step % self.cfg.wandb.log_interval == 0
1240
+ ):
1241
+ wandb.log(metrics, step=self.global_step)
1242
+
1243
+ # Check if/when run should be canceled.
1244
+ if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
1245
+ cancel_initiated, extra_steps = self.check_if_cancelled()
1246
+ if cancel_initiated:
1247
+ stop_at = min(stop_at, self.global_step + extra_steps)
1248
+
1249
+ # Maybe save sharded checkpoint.
1250
+ if self.cfg.distributed_strategy != DistributedStrategy.ddp:
1251
+ if save_checkpoints and (
1252
+ cancel_initiated
1253
+ or (
1254
+ self.cfg.save_interval is not None
1255
+ and self.global_step % self.cfg.save_interval == 0
1256
+ and self.cfg.save_num_checkpoints_to_keep != 0
1257
+ )
1258
+ ):
1259
+ log.info("Saving checkpoint...")
1260
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1261
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1262
+
1263
+ # Remove any ephemeral checkpoints.
1264
+ while self.ephemeral_checkpoints:
1265
+ self.remove_ephemeral_checkpoint()
1266
+
1267
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1268
+ speed_monitor.reset()
1269
+
1270
+ # If the run was just canceled this will be the final checkpoint.
1271
+ if cancel_initiated:
1272
+ save_checkpoints = False
1273
+ elif (
1274
+ self.cfg.save_interval_ephemeral is not None
1275
+ and self.global_step % self.cfg.save_interval_ephemeral == 0
1276
+ ):
1277
+ log.info("Saving ephemeral checkpoint...")
1278
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
1279
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1280
+
1281
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1282
+ speed_monitor.reset()
1283
+
1284
+ # Maybe save unsharded checkpoint.
1285
+ # This code snippet should always execute when running DDP.
1286
+ if (
1287
+ save_checkpoints
1288
+ and self.cfg.save_interval_unsharded is not None
1289
+ and self.global_step % self.cfg.save_interval_unsharded == 0
1290
+ and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
1291
+ ):
1292
+ log.info("Saving unsharded checkpoint...")
1293
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1294
+ log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1295
+
1296
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1297
+ speed_monitor.reset()
1298
+
1299
+ # Maybe run evaluations.
1300
+ if not cancel_initiated and (
1301
+ self.global_step % self.cfg.eval_interval == 0 or self.global_step >= stop_at
1302
+ ):
1303
+ eval_metrics = self.eval()
1304
+
1305
+ # Log metrics to W&B.
1306
+ if wandb.run is not None:
1307
+ wandb.log(eval_metrics, step=self.global_step)
1308
+
1309
+ # Reset speed monitor so that we don't count the time taken to run evaluations.
1310
+ speed_monitor.reset()
1311
+
1312
+ # Reset model to 'train' mode.
1313
+ self.dist_model.train()
1314
+
1315
+ # End of batch.
1316
+ first_batch = False
1317
+ if p is not None:
1318
+ p.step()
1319
+
1320
+ if self.global_step >= stop_at:
1321
+ break
1322
+
1323
+ # Run generation 1 garbage collection.
1324
+ if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
1325
+ gc.collect(1)
1326
+
1327
+ # Python Profiler stuff
1328
+ # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
1329
+ if python_profiler is not None:
1330
+ if self.global_step == 5:
1331
+ python_profiler.enable()
1332
+ elif self.global_step == 8:
1333
+ python_profiler.disable()
1334
+ python_profiler.print_stats(sort=SortKey.CUMULATIVE)
1335
+ python_profiler = None
1336
+ else:
1337
+ log.info("Training epoch complete")
1338
+ self.epoch = epoch + 1
1339
+ self.global_train_examples_seen_this_epoch = 0
1340
+ self.dataset.start_index = 0
1341
+ if self.epoch < self.max_epochs:
1342
+ log.info(f"Reshuffling data loader for epoch {self.epoch}...")
1343
+ self.dataset.reshuffle(self.epoch)
1344
+ continue
1345
+
1346
+ break
1347
+
1348
+ # Save final checkpoint.
1349
+ if save_checkpoints:
1350
+ if (
1351
+ self.cfg.save_interval_unsharded is not None
1352
+ and self.last_unsharded_checkpoint_step != self.global_step
1353
+ ):
1354
+ log.info("Saving final unsharded model checkpoint...")
1355
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1356
+ log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1357
+ elif (
1358
+ self.cfg.save_num_checkpoints_to_keep != 0
1359
+ and self.last_sharded_checkpoint_step != self.global_step
1360
+ and self.cfg.distributed_strategy == DistributedStrategy.fsdp
1361
+ ):
1362
+ log.info("Saving final checkpoint...")
1363
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1364
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1365
+
1366
+ def close(self, exit_code: int = 0) -> None:
1367
+ gc_cuda()
1368
+
1369
+ if self.indices_file is not None:
1370
+ self.indices_file.flush()
1371
+ self.indices_file.close()
1372
+ if self._gc_init_state:
1373
+ gc.enable()
1374
+ else:
1375
+ gc.disable()
1376
+ if wandb.run is not None:
1377
+ wandb.finish(exit_code=exit_code, quiet=True)
1378
+
1379
+ def __enter__(self) -> Trainer:
1380
+ return self
1381
+
1382
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
1383
+ del exc_val, exc_tb
1384
+ self.close(0 if exc_type is None else 1)
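A short, self-contained sketch of the max_duration convention parsed by the max_steps and max_tokens properties above: a plain integer means optimizer steps, a trailing "T" means a total token budget, and a trailing "ep" means epochs. The helper and numbers below are illustrative assumptions and ignore the resume offsets (global_step, global_train_tokens_seen) that the real properties account for.

import math

def duration_to_steps(value, tokens_per_batch: int, batches_per_epoch: int) -> int:
    # Mirrors the branching in Trainer.max_steps, minus resume bookkeeping.
    if isinstance(value, int):
        return value                               # already a step count
    if value.endswith("T"):                        # token budget, e.g. "2e12T"
        return math.ceil(int(float(value[:-1])) / tokens_per_batch)
    if value.endswith("ep"):                       # epoch budget, e.g. "2ep"
        return int(value[:-2]) * batches_per_epoch
    return int(float(value))                       # step count given as a string

assert duration_to_steps("2ep", tokens_per_batch=4096, batches_per_epoch=100) == 200
assert duration_to_steps("1e6T", tokens_per_batch=4096, batches_per_epoch=100) == 245  # ceil(1e6 / 4096)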
util.py ADDED
@@ -0,0 +1,929 @@
1
+ import gzip
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import socket
8
+ import sys
9
+ import time
10
+ import warnings
11
+ from datetime import datetime
12
+ from enum import Enum
13
+ from itertools import cycle, islice
14
+ from pathlib import Path
15
+ from queue import Queue
16
+ from threading import Thread
17
+ from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Union
18
+
19
+ import boto3
20
+ import botocore.exceptions as boto_exceptions
21
+ import datasets
22
+ import requests
23
+ import rich
24
+ from botocore.config import Config
25
+ from cached_path.schemes import SchemeClient, add_scheme_client
26
+ from google.api_core.retry import Retry as GCSRetry
27
+ from google.api_core.retry import if_transient_error as gcs_is_transient_error
28
+ from rich.console import Console, ConsoleRenderable
29
+ from rich.highlighter import NullHighlighter
30
+ from rich.progress import Progress
31
+ from rich.text import Text
32
+ from rich.traceback import Traceback
33
+
34
+ from olmo_data.data import get_data_path
35
+
36
+ from .aliases import PathOrStr
37
+ from .exceptions import (
38
+ OLMoCliError,
39
+ OLMoEnvironmentError,
40
+ OLMoError,
41
+ OLMoNetworkError,
42
+ OLMoThreadError,
43
+ )
44
+ from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
45
+
46
+ try:
47
+ from functools import cache
48
+ except ImportError:
49
+ from functools import lru_cache as cache
50
+
51
+
52
+ class StrEnum(str, Enum):
53
+ """
54
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
55
+ We include this here for compatibility with older versions of Python.
56
+ """
57
+
58
+ def __str__(self) -> str:
59
+ return self.value
60
+
61
+ def __repr__(self) -> str:
62
+ return f"'{str(self)}'"
63
+
64
+
65
+ _log_extra_fields: Dict[str, Any] = {}
66
+ log = logging.getLogger(__name__)
67
+
68
+
69
+ class LogFilterType(StrEnum):
70
+ rank0_only = "rank0_only"
71
+ local_rank0_only = "local_rank0_only"
72
+ all_ranks = "all_ranks"
73
+
74
+
75
+ def log_extra_field(field_name: str, field_value: Any) -> None:
76
+ global _log_extra_fields
77
+ if field_value is None:
78
+ if field_name in _log_extra_fields:
79
+ del _log_extra_fields[field_name]
80
+ else:
81
+ _log_extra_fields[field_name] = field_value
82
+
83
+
84
+ def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
85
+ """
86
+ :param log_filter_type: determines which ranks emit INFO-and-below log messages (rank 0 only, local rank 0 only, or all ranks).
87
+ """
88
+ log_extra_field("hostname", socket.gethostname())
89
+ if is_distributed():
90
+ log_extra_field("node_rank", get_node_rank())
91
+ log_extra_field("local_rank", get_local_rank())
92
+ log_extra_field("global_rank", get_global_rank())
93
+ else:
94
+ log_extra_field("node_rank", 0)
95
+ log_extra_field("local_rank", 0)
96
+ log_extra_field("global_rank", 0)
97
+
98
+ old_log_record_factory = logging.getLogRecordFactory()
99
+
100
+ def log_record_factory(*args, **kwargs) -> logging.LogRecord:
101
+ record = old_log_record_factory(*args, **kwargs)
102
+ for field_name, field_value in _log_extra_fields.items():
103
+ setattr(record, field_name, field_value)
104
+ return record
105
+
106
+ logging.setLogRecordFactory(log_record_factory)
107
+
108
+ handler: logging.Handler
109
+ if (
110
+ os.environ.get("OLMo_NONINTERACTIVE", False)
111
+ or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
112
+ or not sys.stdout.isatty()
113
+ ):
114
+ handler = logging.StreamHandler(sys.stdout)
115
+ formatter = logging.Formatter(
116
+ "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
117
+ )
118
+ formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
119
+ formatter.default_msec_format = "%s.%03d"
120
+ handler.setFormatter(formatter)
121
+ else:
122
+ handler = RichHandler()
123
+
124
+ def rank0_filter(record: logging.LogRecord) -> int:
125
+ if record.levelno > logging.INFO:
126
+ return 1
127
+ if getattr(record, "global_rank", 0) == 0:
128
+ return 1
129
+ else:
130
+ return 0
131
+
132
+ def local_rank0_filter(record: logging.LogRecord) -> int:
133
+ if record.levelno > logging.INFO:
134
+ return 1
135
+ if getattr(record, "local_rank", 0) == 0:
136
+ return 1
137
+ else:
138
+ return 0
139
+
140
+ if log_filter_type == LogFilterType.rank0_only:
141
+ filter = rank0_filter
142
+ elif log_filter_type == LogFilterType.local_rank0_only:
143
+ filter = local_rank0_filter # type: ignore
144
+ elif log_filter_type == LogFilterType.all_ranks:
145
+ filter = None
146
+ else:
147
+ raise ValueError(log_filter_type)
148
+
149
+ if filter is not None:
150
+ handler.addFilter(filter) # type: ignore
151
+ logging.basicConfig(handlers=[handler], level=logging.INFO)
152
+
153
+ logging.captureWarnings(True)
154
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
155
+
156
+
157
+ def excepthook(exctype, value, traceback):
158
+ """
159
+ Used to patch `sys.excepthook` in order to log exceptions.
160
+ """
161
+ if issubclass(exctype, KeyboardInterrupt):
162
+ sys.__excepthook__(exctype, value, traceback)
163
+ elif issubclass(exctype, OLMoCliError):
164
+ rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
165
+ elif issubclass(exctype, OLMoError):
166
+ rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
167
+ else:
168
+ log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))
169
+
170
+
171
+ def install_excepthook():
172
+ sys.excepthook = excepthook
173
+
174
+
175
+ def filter_warnings():
176
+ # Filter internal deprecation warnings from torch
177
+ warnings.filterwarnings(
178
+ action="ignore",
179
+ category=UserWarning,
180
+ message="torch.distributed.*_base is a private function and will be deprecated.*",
181
+ )
182
+ warnings.filterwarnings(
183
+ action="ignore",
184
+ category=UserWarning,
185
+ message="TypedStorage is deprecated.*",
186
+ )
187
+ warnings.filterwarnings(
188
+ action="ignore",
189
+ category=UserWarning,
190
+ message="Please use DTensor instead.*",
191
+ )
192
+ # Torchvision warnings. We don't actually use torchvision.
193
+ warnings.filterwarnings(
194
+ action="ignore",
195
+ message="failed to load.*",
196
+ module="torchvision.io.image",
197
+ )
198
+
199
+
200
+ def set_env_variables():
201
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
202
+
203
+
204
+ def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
205
+ if log_filter_type is None:
206
+ log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
207
+ rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
208
+ setup_logging(log_filter_type=log_filter_type)
209
+ install_excepthook()
210
+ filter_warnings()
211
+ set_env_variables()
212
+
213
+
214
+ def clean_opt(arg: str) -> str:
215
+ if "=" not in arg:
216
+ arg = f"{arg}=True"
217
+ name, val = arg.split("=", 1)
218
+ name = name.strip("-").replace("-", "_")
219
+ return f"{name}={val}"
220
+
221
+
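A quick usage example of clean_opt above, which normalizes command-line overrides before they reach the config loader; the option names here are hypothetical.

print(clean_opt("--save-overwrite"))                 # -> "save_overwrite=True"
print(clean_opt("--optimizer.learning_rate=3e-4"))   # -> "optimizer.learning_rate=3e-4"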
222
+ class RichHandler(logging.Handler):
223
+ """
224
+ A simplified version of rich.logging.RichHandler from
225
+ https://github.com/Textualize/rich/blob/master/rich/logging.py
226
+ """
227
+
228
+ def __init__(
229
+ self,
230
+ *,
231
+ level: Union[int, str] = logging.NOTSET,
232
+ console: Optional[Console] = None,
233
+ markup: bool = False,
234
+ ) -> None:
235
+ super().__init__(level=level)
236
+ self.console = console or rich.get_console()
237
+ self.highlighter = NullHighlighter()
238
+ self.markup = markup
239
+
240
+ def emit(self, record: logging.LogRecord) -> None:
241
+ try:
242
+ if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
243
+ self.console.print(record.msg)
244
+ else:
245
+ msg: Any = record.msg
246
+ if isinstance(record.msg, str):
247
+ msg = self.render_message(record=record, message=record.getMessage())
248
+ renderables = [
249
+ self.get_time_text(record),
250
+ self.get_level_text(record),
251
+ self.get_location_text(record),
252
+ msg,
253
+ ]
254
+ if record.exc_info is not None:
255
+ tb = Traceback.from_exception(*record.exc_info) # type: ignore
256
+ renderables.append(tb)
257
+ self.console.print(*renderables)
258
+ except Exception:
259
+ self.handleError(record)
260
+
261
+ def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
262
+ use_markup = getattr(record, "markup", self.markup)
263
+ message_text = Text.from_markup(message) if use_markup else Text(message)
264
+
265
+ highlighter = getattr(record, "highlighter", self.highlighter)
266
+ if highlighter:
267
+ message_text = highlighter(message_text)
268
+
269
+ return message_text
270
+
271
+ def get_time_text(self, record: logging.LogRecord) -> Text:
272
+ log_time = datetime.fromtimestamp(record.created)
273
+ time_str = log_time.strftime("[%Y-%m-%d %X]")
274
+ return Text(time_str, style="log.time", end=" ")
275
+
276
+ def get_level_text(self, record: logging.LogRecord) -> Text:
277
+ level_name = record.levelname
278
+ level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
279
+ level_text.style = "log.level"
280
+ level_text.end = " "
281
+ return level_text
282
+
283
+ def get_location_text(self, record: logging.LogRecord) -> Text:
284
+ name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
285
+ text = f"[{name_and_line}, rank={record.local_rank}]" # type: ignore
286
+ return Text(text, style="log.path")
287
+
288
+
289
+ def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
290
+ """Wait for the condition function to return True."""
291
+ start_time = time.monotonic()
292
+ while not condition():
293
+ time.sleep(0.5)
294
+ if time.monotonic() - start_time > timeout:
295
+ raise TimeoutError(f"{description} timed out")
296
+
297
+
298
+ def is_url(path: PathOrStr) -> bool:
299
+ return re.match(r"[a-z0-9]+://.*", str(path)) is not None
300
+
301
+
302
+ def dir_is_empty(dir: PathOrStr) -> bool:
303
+ dir = Path(dir)
304
+ if not dir.is_dir():
305
+ return True
306
+ try:
307
+ next(dir.glob("*"))
308
+ return False
309
+ except StopIteration:
310
+ return True
311
+
312
+
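# Quick reference for the two predicates above (derived from the regex / glob logic,
# included here purely as documentation):
#   is_url("s3://bucket/checkpoints/step1000")   -> True
#   is_url("gs://bucket/run/step500-unsharded")  -> True
#   is_url("/local/path/to/checkpoint")          -> False
#   dir_is_empty("does-not-exist")               -> True   (a missing dir counts as empty)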
313
+ def get_progress_bar() -> Progress:
314
+ from cached_path import get_download_progress
315
+
316
+ return get_download_progress()
317
+
318
+
319
+ def resource_path(
320
+ folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
321
+ ) -> Path:
322
+ if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
323
+ log.info(f"Found local cache of {fname} at {local_path}")
324
+ return local_path
325
+ else:
326
+ from cached_path import cached_path
327
+
328
+ return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)
329
+
330
+
331
+ def file_size(path: PathOrStr) -> int:
332
+ """
333
+ Get the size of a local or remote file in bytes.
334
+ """
335
+ if is_url(path):
336
+ from urllib.parse import urlparse
337
+
338
+ parsed = urlparse(str(path))
339
+ if parsed.scheme == "gs":
340
+ return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
341
+ elif parsed.scheme in ("s3", "r2", "weka"):
342
+ return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
343
+ elif parsed.scheme in ("http", "https"):
344
+ return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
345
+ elif parsed.scheme == "file":
346
+ return file_size(str(path).replace("file://", "", 1))
347
+ else:
348
+ raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
349
+ else:
350
+ return os.stat(path).st_size
351
+
352
+
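# Illustrative sketch (not part of the original module): `file_size` and `get_bytes_range`
# (defined further below) share the same scheme dispatch, so the same call works for local
# paths as well as gs://, s3://, r2://, weka://, http(s):// and file:// URLs. The function
# name and the object path are hypothetical.
def _example_remote_read() -> None:
    path = "s3://my-bucket/data/part-000.npy"  # hypothetical object
    total = file_size(path)
    header = get_bytes_range(path, 0, min(16, total))
    log.info("%s is %d bytes; first bytes: %r", path, total, header)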
353
+ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
354
+ """Upload source file to a target location on GCS or S3."""
355
+ from urllib.parse import urlparse
356
+
357
+ source = Path(source)
358
+ assert source.is_file()
359
+ parsed = urlparse(target)
360
+ if parsed.scheme == "gs":
361
+ _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
362
+ elif parsed.scheme in ("s3", "r2", "weka"):
363
+ _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
364
+ else:
365
+ raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
366
+
367
+
368
+ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
369
+ if is_url(source):
370
+ from urllib.parse import urlparse
371
+
372
+ parsed = urlparse(str(source))
373
+ if parsed.scheme == "gs":
374
+ return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
375
+ elif parsed.scheme in ("s3", "r2", "weka"):
376
+ return _s3_get_bytes_range(
377
+ parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
378
+ )
379
+ elif parsed.scheme in ("http", "https"):
380
+ return _http_get_bytes_range(
381
+ parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
382
+ )
383
+ elif parsed.scheme == "file":
384
+ return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
385
+ else:
386
+ raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files")
387
+ else:
388
+ with open(source, "rb") as f:
389
+ f.seek(bytes_start)
390
+ return f.read(num_bytes)
391
+
392
+
393
+ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
394
+ if is_url(dir):
395
+ from urllib.parse import urlparse
396
+
397
+ parsed = urlparse(str(dir))
398
+ if parsed.scheme == "gs":
399
+ return _gcs_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
400
+ elif parsed.scheme in ("s3", "r2", "weka"):
401
+ return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
402
+ elif parsed.scheme == "file":
403
+ return find_latest_checkpoint(str(dir).replace("file://", "", 1))
404
+ else:
405
+ raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
406
+ else:
407
+ latest_step = 0
408
+ latest_checkpoint: Optional[Path] = None
409
+ for path in Path(dir).glob("step*"):
410
+ if path.is_dir():
411
+ try:
412
+ step = int(path.name.replace("step", "").replace("-unsharded", ""))
413
+ except ValueError:
414
+ continue
415
+ # We prioritize sharded checkpoints over unsharded checkpoints.
416
+ if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
417
+ latest_step = step
418
+ latest_checkpoint = path
419
+ return latest_checkpoint
420
+
421
+
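# Illustrative sketch (not part of the original module): resolving the most recent
# checkpoint in a run directory. Checkpoint dirs are expected to be named `step<N>` or
# `step<N>-unsharded`, and on ties the sharded variant wins. Works for local paths as
# well as gs://, s3://, r2:// and weka:// prefixes. The run directory is hypothetical.
def _example_resume() -> None:
    run_dir = "s3://my-bucket/checkpoints/my-run"  # hypothetical location
    latest = find_latest_checkpoint(run_dir)
    if latest is None:
        log.info("no checkpoint found under %s", run_dir)
    else:
        log.info("resuming from %s", latest)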
422
+ # Google Storage API is unhinged and requires you to specify the retry policy on every single call you make.
423
+ def _gcs_is_retriable(exception: Exception) -> bool:
424
+ if gcs_is_transient_error(exception):
425
+ return True
426
+ if isinstance(exception, requests.exceptions.ReadTimeout):
427
+ return True
428
+ return False
429
+
430
+
431
+ _gcs_retry = GCSRetry(predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0)
432
+
433
+
434
+ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
435
+ storage_client = _get_gcs_client()
436
+ bucket = storage_client.bucket(bucket_name)
437
+ blob = bucket.blob(key)
438
+ if not save_overwrite and blob.exists():
439
+ raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
440
+ blob.upload_from_filename(source, retry=_gcs_retry)
441
+
442
+
443
+ def _gcs_file_size(bucket_name: str, key: str) -> int:
444
+ from google.api_core.exceptions import NotFound
445
+
446
+ storage_client = _get_gcs_client()
447
+ bucket = storage_client.bucket(bucket_name)
448
+ blob = bucket.blob(key)
449
+ try:
450
+ blob.reload(retry=_gcs_retry)
451
+ except NotFound:
452
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
453
+ assert blob.size is not None
454
+ return blob.size
455
+
456
+
457
+ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
458
+ from google.api_core.exceptions import NotFound
459
+
460
+ storage_client = _get_gcs_client()
461
+ bucket = storage_client.bucket(bucket_name)
462
+ blob = bucket.blob(key)
463
+ try:
464
+ return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry)
465
+ except NotFound:
466
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
467
+
468
+
469
+ @cache
470
+ def _get_gcs_client():
471
+ from google.cloud import storage as gcs
472
+
473
+ return gcs.Client()
474
+
475
+
476
+ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
477
+ if not prefix.endswith("/"):
478
+ prefix = f"{prefix}/"
479
+
480
+ storage_client = _get_gcs_client()
481
+ bucket = storage_client.bucket(bucket_name)
482
+ suffix = "/config.yaml"
483
+ latest_step: Optional[int] = None
484
+ latest_checkpoint: Optional[str] = None
485
+ for blob in bucket.list_blobs(prefix=prefix, match_glob=f"**{suffix}"):
486
+ # Disregard checkpoints that have an empty config file.
487
+ if blob.size <= 0:
488
+ continue
489
+
490
+ name = blob.name[len(prefix) : -len(suffix)]
491
+
492
+ if "/" in name:
493
+ # We're not considering checkpoints in subdirectories.
494
+ continue
495
+
496
+ if not name.startswith("step"):
497
+ continue
498
+ name = name[4:]
499
+
500
+ if name.endswith("-unsharded"):
501
+ name = name[: -len("-unsharded")]
502
+
503
+ try:
504
+ step = int(name)
505
+ except ValueError:
506
+ continue
507
+
508
+ # we prefer sharded checkpoints to unsharded ones
509
+ if (
510
+ latest_step is None
511
+ or step > latest_step
512
+ or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded"))
513
+ ):
514
+ latest_step = step
515
+ latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}"
516
+
517
+ return latest_checkpoint
518
+
519
+
520
+ def _get_s3_profile_name(scheme: str) -> Optional[str]:
521
+ if scheme == "s3":
522
+ # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
523
+ return os.environ.get("S3_PROFILE")
524
+ if scheme == "r2":
525
+ profile_name = os.environ.get("R2_PROFILE")
526
+ if profile_name is None:
527
+ raise OLMoEnvironmentError(
528
+ "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
529
+ )
530
+
531
+ return profile_name
532
+ if scheme == "weka":
533
+ profile_name = os.environ.get("WEKA_PROFILE")
534
+ if profile_name is None:
535
+ raise OLMoEnvironmentError(
536
+ "Weka profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
537
+ )
538
+
539
+ return profile_name
540
+
541
+ raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
542
+
543
+
544
+ def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
545
+ if scheme == "s3":
546
+ return None
547
+ if scheme == "r2":
548
+ r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
549
+ if r2_endpoint_url is None:
550
+ raise OLMoEnvironmentError(
551
+ "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
552
+ )
553
+
554
+ return r2_endpoint_url
555
+ if scheme == "weka":
556
+ weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
557
+ if weka_endpoint_url is None:
558
+ raise OLMoEnvironmentError(
559
+ "Weka endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
560
+ )
561
+
562
+ return weka_endpoint_url
563
+
564
+ raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
565
+
566
+
567
+ @cache
568
+ def _get_s3_client(scheme: str):
569
+ session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
570
+ return session.client(
571
+ "s3",
572
+ endpoint_url=_get_s3_endpoint_url(scheme),
573
+ config=Config(retries={"max_attempts": 10, "mode": "standard"}),
574
+ use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
575
+ )
576
+
577
+
578
+ def _wait_before_retry(attempt: int):
579
+ time.sleep(min(0.5 * 2**attempt, 3.0))
580
+
581
+
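# For reference, the delay above grows exponentially and is capped at 3 seconds:
# attempt 1 -> 1.0s, attempt 2 -> 2.0s, attempt 3 and beyond -> 3.0s.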
582
+ def _s3_upload(
583
+ source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
584
+ ):
585
+ err: Optional[Exception] = None
586
+ if not save_overwrite:
587
+ for attempt in range(1, max_attempts + 1):
588
+ try:
589
+ _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
590
+ raise FileExistsError(
591
+ f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
592
+ )
593
+ except boto_exceptions.ClientError as e:
594
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
595
+ err = None
596
+ break
597
+ err = e
598
+
599
+ if attempt < max_attempts:
600
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
601
+ _wait_before_retry(attempt)
602
+
603
+ if err is not None:
604
+ raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err
605
+
606
+ try:
607
+ _get_s3_client(scheme).upload_file(source, bucket_name, key)
608
+ except boto_exceptions.ClientError as e:
609
+ raise OLMoNetworkError(f"Failed to upload to {scheme}") from e
610
+
611
+
612
+ def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
613
+ err: Optional[Exception] = None
614
+ for attempt in range(1, max_attempts + 1):
615
+ try:
616
+ return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
617
+ except boto_exceptions.ClientError as e:
618
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
619
+ raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
620
+ err = e
621
+
622
+ if attempt < max_attempts:
623
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
624
+ _wait_before_retry(attempt)
625
+
626
+ raise OLMoNetworkError(f"Failed to get {scheme} file size") from err
627
+
628
+
629
+ def _s3_get_bytes_range(
630
+ scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
631
+ ) -> bytes:
632
+ err: Optional[Exception] = None
633
+ for attempt in range(1, max_attempts + 1):
634
+ try:
635
+ return (
636
+ _get_s3_client(scheme)
637
+ .get_object(
638
+ Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
639
+ )["Body"]
640
+ .read()
641
+ )
642
+ except boto_exceptions.ClientError as e:
643
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
644
+ raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
645
+ err = e
646
+ except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
647
+ # ResponseStreamingError (subclass of HTTPClientError) can happen as
648
+ # a result of a failed read from the stream (http.client.IncompleteRead).
649
+ # Retrying can help in this case.
650
+ err = e
651
+
652
+ if attempt < max_attempts:
653
+ log.warning(
654
+ "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
655
+ )
656
+ _wait_before_retry(attempt)
657
+
658
+ # When torch's DataLoader intercepts exceptions, it may try to re-raise them
659
+ # by recalling their constructor with a single message arg. Torch has some
660
+ # logic to deal with the absence of a single-parameter constructor, but it
661
+ # doesn't gracefully handle other possible failures in calling such a constructor.
662
+ # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
663
+ # in us losing the true exception info. To avoid this, we change the exception
664
+ # to a type that has a single-parameter constructor.
665
+ raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
666
+
667
+
668
+ def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
669
+ if not prefix.endswith("/"):
670
+ prefix = f"{prefix}/"
671
+ response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
672
+ assert not response["IsTruncated"] # need to handle this if it happens
673
+ latest_step = 0
674
+ latest_checkpoint: Optional[str] = None
675
+ for item in response.get("CommonPrefixes", []):
676
+ prefix = item["Prefix"].strip("/")
677
+ checkpoint_name = os.path.split(prefix)[-1]
678
+ if not checkpoint_name.startswith("step"):
679
+ continue
680
+ try:
681
+ step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
682
+ except ValueError:
683
+ continue
684
+ # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
685
+ # (upload might have failed partway through).
686
+ try:
687
+ _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
688
+ except FileNotFoundError:
689
+ continue
690
+ # We prioritize sharded checkpoints over unsharded ones.
691
+ if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
692
+ latest_step = step
693
+ latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
694
+ return latest_checkpoint
695
+
696
+
697
+ def _http_file_size(scheme: str, host_name: str, path: str) -> int:
698
+ import requests
699
+
700
+ response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True)
701
+ content_length = response.headers.get("content-length")
+ assert content_length is not None, f"no content-length header for {scheme}://{host_name}/{path}"
+ return int(content_length)
702
+
703
+
704
+ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes:
705
+ import requests
706
+
707
+ response = requests.get(
708
+ f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"}
709
+ )
710
+ result = response.content
711
+ assert (
712
+ len(result) == num_bytes
713
+ ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything
714
+ return result
715
+
716
+
717
+ def save_hf_dataset_to_disk(
718
+ dataset: datasets.DatasetDict | datasets.Dataset,
719
+ hf_path: str,
720
+ name: Optional[str],
721
+ split: str,
722
+ datasets_dir: PathOrStr,
723
+ ):
724
+ """
725
+ Saves a HF dataset to disk under the `datasets_dir`. It can be used to add a HF dataset
726
+ to `olmo_data` as follows:
727
+
728
+ ```
729
+ import datasets
730
+
731
+ from olmo.util import save_hf_dataset_to_disk
732
+
733
+ path, name, split = ...
734
+
735
+ dataset = datasets.load_dataset(path, name=name, split=split)
736
+ save_hf_dataset_to_disk(dataset, path, name, split, "olmo_data/hf_datasets")
737
+ ```
738
+ """
739
+ dataset_path = Path(datasets_dir) / hf_path / (name or "none") / split
740
+ return dataset.save_to_disk(str(dataset_path))
741
+
742
+
743
+ def load_hf_dataset(path: str, name: Optional[str], split: str):
744
+ """
745
+ Loads a HuggingFace dataset. The dataset is assumed to have been saved with
746
+ `save_hf_dataset_to_disk` and to live under `olmo_data/hf_datasets`.
747
+ """
748
+ dataset_rel_path = os.path.join("hf_datasets", path, name or "none", split)
749
+ with get_data_path(dataset_rel_path) as dataset_path:
750
+ if not dataset_path.is_dir():
751
+ raise NotADirectoryError(
752
+ f"HF dataset {path} name {name} split {split} not found in directory {dataset_rel_path}"
753
+ )
754
+ return datasets.load_from_disk(str(dataset_path))
755
+
756
+
757
+ def load_oe_eval_requests(path: str, name: Optional[str] = None, split: Optional[str] = None):
758
+ """
759
+ Loads an oe-eval request file from `olmo_data/oe_eval_tasks`.
760
+ TODO: Add support for loading from S3 instead?
761
+ """
762
+ dataset_rel_path = os.path.join("oe_eval_tasks", path)
763
+ if name is not None:
764
+ dataset_rel_path = os.path.join(dataset_rel_path, name)
765
+ with get_data_path(dataset_rel_path) as dataset_path:
766
+ if not dataset_path.is_dir():
767
+ raise NotADirectoryError(f"OE Eval dataset not found in directory {dataset_rel_path}")
768
+ data_file = dataset_path / "requests.jsonl.gz"
769
+ if not data_file.is_file():
770
+ data_file = dataset_path / "requests.jsonl"
771
+ if not data_file.is_file():
772
+ raise FileNotFoundError(
773
+ f"OE Eval dataset file requests-{split}.jsonl(.gz) missing in directory {dataset_rel_path}"
774
+ )
775
+ requests = []
776
+ if data_file.suffix == ".gz":
777
+ with gzip.open(data_file, "r") as file:
778
+ for line in file:
779
+ requests.append(json.loads(line.decode("utf-8").strip()))
780
+ else:
781
+ with open(data_file, "r") as file:
782
+ for line2 in file:
783
+ requests.append(json.loads(line2.strip()))
784
+ config = None
785
+ config_file = dataset_path / "config.json"
786
+ if config_file.is_file():
787
+ with open(config_file, "r") as file:
788
+ config = json.load(file)
789
+ return config, requests
790
+
791
+
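# Illustrative sketch (not part of the original module): loading a packaged oe-eval task
# and inspecting its requests. The task path and name below are hypothetical.
def _example_load_oe_eval() -> None:
    task_config, reqs = load_oe_eval_requests("arc_challenge", name="rc_0shot")
    log.info("loaded %d requests (config present: %s)", len(reqs), task_config is not None)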
792
+ def default_thread_count() -> int:
793
+ return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
794
+
795
+
796
+ def pass_through_fn(fn, *args, **kwargs):
797
+ return fn(*args, **kwargs)
798
+
799
+
800
+ def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
801
+ q: Queue = Queue(maxsize=maxsize)
802
+
803
+ sentinel = object()
804
+
805
+ def fill_queue():
806
+ try:
807
+ for value in g:
808
+ q.put(value)
809
+ except Exception as e:
810
+ q.put(e)
811
+ finally:
812
+ q.put(sentinel)
813
+
814
+ thread_name = thread_name or repr(g)
815
+ thread = Thread(name=thread_name, target=fill_queue, daemon=True)
816
+ thread.start()
817
+
818
+ for x in iter(q.get, sentinel):
819
+ if isinstance(x, Exception):
820
+ raise OLMoThreadError(f"generator thread {thread_name} failed") from x
821
+ else:
822
+ yield x
823
+
824
+
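# Illustrative sketch (not part of the original module): `threaded_generator` runs a slow
# producer on a daemon thread and hands items over through a bounded queue, so the consumer
# overlaps its own work with the producer's I/O. The function names are hypothetical.
def _example_threaded_generator() -> None:
    def slow_producer():
        for i in range(10):
            time.sleep(0.1)  # stand-in for slow I/O
            yield i

    for item in threaded_generator(slow_producer(), maxsize=4, thread_name="slow_producer"):
        log.info("got %s", item)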
825
+ def roundrobin(*iterables):
826
+ """
827
+ Call the given iterables in a round-robin fashion. For example:
828
+ ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
829
+ """
830
+ # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
831
+ num_active = len(iterables)
832
+ nexts = cycle(iter(it).__next__ for it in iterables)
833
+ while num_active:
834
+ try:
835
+ for next in nexts:
836
+ yield next()
837
+ except StopIteration:
838
+ # Remove the iterator we just exhausted from the cycle.
839
+ num_active -= 1
840
+ nexts = cycle(islice(nexts, num_active))
841
+
842
+
843
+ def add_cached_path_clients():
844
+ add_scheme_client(WekaClient)
845
+
846
+
847
+ class WekaClient(SchemeClient):
848
+ recoverable_errors = SchemeClient.recoverable_errors + (
849
+ boto_exceptions.HTTPClientError,
850
+ boto_exceptions.ConnectionError,
851
+ )
852
+
853
+ scheme = "weka"
854
+
855
+ def __init__(self, resource: str) -> None:
856
+ SchemeClient.__init__(self, resource)
857
+ self.bucket_name, self.path = WekaClient._split_cloud_path(resource, "weka")
858
+ self.s3 = _get_s3_client("weka")
859
+ self.object_info = None
860
+
861
+ @staticmethod
862
+ def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
863
+ """Split a full s3 path into the bucket name and path."""
864
+ from urllib.parse import urlparse
865
+
866
+ parsed = urlparse(url)
867
+ if not parsed.netloc or not parsed.path:
868
+ raise ValueError("bad {} path {}".format(provider, url))
869
+ bucket_name = parsed.netloc
870
+ provider_path = parsed.path
871
+ # Remove '/' at beginning of path.
872
+ if provider_path.startswith("/"):
873
+ provider_path = provider_path[1:]
874
+ return bucket_name, provider_path
875
+
876
+ def _ensure_object_info(self):
877
+ if self.object_info is None:
878
+ try:
879
+ self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
880
+ except boto_exceptions.ClientError as e:
881
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
882
+ raise FileNotFoundError(f"weka://{self.bucket_name}/{self.path}") from e
883
+ raise e
884
+
885
+ def get_etag(self) -> Optional[str]:
886
+ self._ensure_object_info()
887
+ assert self.object_info is not None
888
+ return self.object_info.get("ETag")
889
+
890
+ def get_size(self) -> Optional[int]:
891
+ self._ensure_object_info()
892
+ assert self.object_info is not None
893
+ return self.object_info.get("ContentLength")
894
+
895
+ def get_resource(self, temp_file: io.BufferedWriter) -> None:
896
+ self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)
897
+
898
+ def get_bytes_range(self, index: int, length: int) -> bytes:
899
+ response = self.s3.get_object(
900
+ Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
901
+ )
902
+ return response["Body"].read()
903
+
904
+
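# Illustrative sketch (not part of the original module): once `add_cached_path_clients()`
# has registered `WekaClient`, `weka://` URLs can be resolved with cached_path like any
# other supported scheme. The bucket and key below are hypothetical.
def _example_weka_cached_path() -> None:
    from cached_path import cached_path

    add_cached_path_clients()
    local_file = cached_path("weka://my-bucket/datasets/tokens.npy")  # hypothetical URL
    log.info("cached to %s", local_file)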
905
+ def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False):
906
+ """
907
+ Flatten a nested dictionary into a single-level dictionary.
908
+
909
+ Args:
910
+ dictionary (dict): The nested dictionary to be flattened.
911
+ parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "".
912
+ separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".".
913
+ include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False.
914
+
915
+ Returns:
916
+ dict: The flattened dictionary.
917
+
918
+ """
919
+ d: Dict[str, Any] = {}
920
+ for key, value in dictionary.items():
921
+ new_key = parent_key + separator + key if parent_key else key
922
+ # convert lists to dict with key <int>
923
+ if isinstance(value, list) and include_lists:
924
+ value = {f"{i}": v for i, v in enumerate(value)}
925
+ if isinstance(value, MutableMapping):
926
+ d.update(**flatten_dict(value, new_key, separator=separator, include_lists=include_lists))
927
+ else:
928
+ d[new_key] = value
929
+ return d
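# Illustrative examples (not part of the original module) of `flatten_dict` behavior:
#   flatten_dict({"optimizer": {"lr": 1e-4}, "layers": [8, 16]})
#     -> {"optimizer.lr": 1e-4, "layers": [8, 16]}
#   flatten_dict({"layers": [8, 16]}, include_lists=True)
#     -> {"layers.0": 8, "layers.1": 16}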
version.py ADDED
@@ -0,0 +1,11 @@
1
+ _MAJOR = "0"
2
+ _MINOR = "6"
3
+ # On main and in a nightly release the patch should be one ahead of the last
4
+ # released build.
5
+ _PATCH = "0"
6
+ # This is mainly for nightly builds which have the suffix ".dev$DATE". See
7
+ # https://semver.org/#is-v123-a-semantic-version for the semantics.
8
+ _SUFFIX = ""
9
+
10
+ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
11
+ VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)