Upload 15 files
- beam_search.py +1078 -0
- checkpoint.py +2023 -0
- config.json +2 -13
- config.py +1371 -0
- exceptions.py +50 -0
- initialization.py +22 -0
- model.py +1959 -0
- modeling_fan.py +271 -0
- optim.py +1040 -0
- safetensors_util.py +81 -0
- torch_util.py +158 -0
- train.py +1384 -0
- util.py +929 -0
- version.py +11 -0
beam_search.py
ADDED
@@ -0,0 +1,1078 @@
"""
This is a self-contained and flexible beam search implementation adapted from
AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
"""

import copy
import warnings
from abc import abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast

import torch

__all__ = [
    "Sampler",
    "DeterministicSampler",
    "MultinomialSampler",
    "TopKSampler",
    "TopPSampler",
    "GumbelSampler",
    "FinalSequenceScorer",
    "SequenceLogProbabilityScorer",
    "LengthNormalizedSequenceLogProbabilityScorer",
    "Constraint",
    "RepeatedNGramBlockingConstraint",
    "BeamSearch",
]

StateType = Dict[str, torch.Tensor]
StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
"""
The type of step function that can be passed to [`BeamSearch.search`](#search).

This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
"""

ConstraintStateType = List[List[Dict[str, Any]]]


class Sampler:
    """
    An abstract class that can be used to sample candidates (either nodes or beams)
    within `BeamSearch`.

    A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.

    `init_state()` takes three arguments:

    - a tensor of starting log probs with shape `(batch_size, num_classes)`,
    - the batch size, an int,
    - and the number of classes, also an int.

    It returns a state dictionary with any state tensors needed for subsequent
    calls to `sample_nodes()` and `sample_beams()`.

    By default this method just returns an empty dictionary.

    Both `sample_nodes()` and `sample_beams()` should take three arguments:

    - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
    - an integer representing the number of samples to take for each example in the batch,
    - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
      track of state.

    For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
    `num_examples = beam_size * per_node_beam_size`.

    The return value should be a tuple containing:

    - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
    - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
    - and the updated state dictionary.

    A default implementation of `sample_beams` is provided, which just deterministically
    picks the `k` examples with highest log probability.
    """

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        del start_class_log_probabilities, batch_size, num_classes
        return {}

    @abstractmethod
    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        raise NotImplementedError

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        del state
        selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
        return selected_log_probs, selected_indices, {}


class DeterministicSampler(Sampler):
    """
    A `Sampler` that just deterministically returns the `k` nodes or beams with highest
    log probability.
    """

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        del state
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, {}


class MultinomialSampler(Sampler):
    """
    A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
    in the default, non-deterministic way.

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: Whether to sample with replacement.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ) -> None:
        self.temperature = temperature
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if self.temperature != 1.0:
            _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
        else:
            _probabilities = log_probs.exp()

        selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)

        return torch.gather(log_probs, 1, selected_indices), selected_indices, state


class TopKSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among the
    top `k` choices, then samples from that subset after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param k: The number of top choices to be selected from.
    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
    """

    def __init__(
        self,
        k: int = 1,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        self.k = k
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= self.k <= log_probs.size()[1]:
            raise ValueError(
                "k must be a positive integer no less than per_node_beam_size and no greater than vocabulary size"
            )

        # shape (both): (batch_size, k)
        top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)

        # Apply temperature if necessary.
        # shape: (batch_size, k)
        if self.temperature != 1.0:
            top_k_log_probs = top_k_log_probs / self.temperature

        # Re-normalize the subset.
        # shape: (batch_size, k)
        normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        indices = top_k_indices.gather(-1, sampled_indices)

        return log_probs.gather(1, indices), indices, state


class TopPSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among
    the top choices with a cumulative probability of at least `p`, then samples from that subset
    after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param p:
        The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
        examples to sample from. If `with_replacement` is `False` and the number of possible samples is
        insufficient to sample without replacement from when calling `sample_nodes`, then the top
        `per_node_beam_size` examples will be chosen.
    :param temperature:
        A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement:
        If set to `True`, samples will be selected with replacement from the top choices.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        if p < 0.0 or p > 1.0:
            raise ValueError("p must be a positive float no greater than 1.0")
        self.p = p
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= log_probs.size()[1]:
            raise ValueError("per_node_beam_size cannot be greater than vocabulary size")

        # First apply temperature coefficient:
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # Sort the probabilities in descending order to then find cumulative sum
        log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)

        # shape: (batch_size, num_classes)
        probabilities_descending = log_probs_descending.exp()
        probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)

        # Create a mask for filtering out probabilities that don't make the top `p`.
        # shape: (batch_size, num_classes)
        exclusion_mask = probabilities_summed >= self.p

        # We want to include the first index where probabilities_summed >= p, so we shift over one.
        exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
        exclusion_mask[..., 0] = False

        # Make sure there's at least `per_node_beam_size` options to be selected.
        if not self.with_replacement:
            exclusion_mask[..., :per_node_beam_size] = False

        log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min

        # Now re-normalize the included log probs.
        # shape: (batch_size, num_classes)
        filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        selected_indices = sorting_indices.gather(-1, sampled_indices)

        # Return (selected log probabilities, selected classes)
        # shape: (len(log_probs),1) , (len(log_probs), 1)
        return torch.gather(log_probs, 1, selected_indices), selected_indices, state


class GumbelSampler(Sampler):
    """
    A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
    [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
    Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
    (https://api.semanticscholar.org/CorpusID:76662039).

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    """

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        # shape: (batch_size, num_classes)
        zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))

        # shape: (batch_size, num_classes)
        G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)

        return {"G_phi_S": G_phi_S}

    def sample_nodes(
        self,
        log_probs: torch.Tensor,
        per_node_beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # First apply temperature coefficient:
        # shape: (batch_size * beam_size, num_classes)
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # shape: (group_size,)
        phi_S = state["phi_S"]

        # shape: (group_size, num_classes)
        phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)

        # shape: (group_size, num_classes)
        phi_S_new = phi_S + _log_probs

        # shape: (group_size, 1)
        G_phi_S = state["G_phi_S"].unsqueeze(-1)

        # shape: (group_size, num_classes)
        G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)

        # Replace NaNs with very negative number.
        # shape: (group_size, num_classes)
        # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min

        # shape (both): (group_size, per_node_beam_size)
        top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)

        # shape: (group_size, per_node_beam_size)
        top_log_probs = log_probs.gather(1, top_indices)

        return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}

    def sample_beams(
        self,
        log_probs: torch.Tensor,
        beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Returns the beams with the highest perturbed log probabilities.
        """
        # shape (log_probs): (batch_size, beam_size * per_node_beam_size)

        batch_size = log_probs.size()[0]

        # shape: (batch_size * beam_size, per_node_beam_size)
        G_phi_S = state["G_phi_S"]

        # shape: (batch_size, beam_size * per_node_beam_size)
        G_phi_S = G_phi_S.reshape_as(log_probs)

        # shape (both): (batch_size, beam_size)
        G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)

        # shape: (batch_size, beam_size)
        selected_log_probs = log_probs.gather(1, selected_indices)

        # Now sort the selected beams by their true log prob.
        # shape (all): (batch_size, beam_size)
        selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
        selected_indices = selected_indices.gather(1, sort_indices)
        G_phi_S_new = G_phi_S_new.gather(1, sort_indices)

        # shape: (batch_size * beam_size,)
        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)

        # shape: (batch_size * beam_size,)
        phi_S = selected_log_probs.reshape(batch_size * beam_size)

        return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}

    def gumbel(self, phi) -> torch.Tensor:
        """
        Sample `Gumbel(phi)`.

        `phi` should have shape `(batch_size, num_classes)`.
        """
        return -torch.log(-torch.log(torch.rand_like(phi))) + phi

    def gumbel_with_max(self, phi, T) -> torch.Tensor:
        """
        Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.

        `phi` should have shape `(batch_size, num_classes)` and `T` should have
        shape `(batch_size, 1)`.
        """
        # Shape: (batch_size, num_classes)
        G_phi = self.gumbel(phi)

        # Now we find the maximum from these samples.
        # Shape: (batch_size,)
        Z, _ = G_phi.max(dim=-1)

        # Shape: (batch_size, num_classes)
        v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))

        # Shape: (batch_size, num_classes)
        return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))


class FinalSequenceScorer:
    """
    An abstract class that can be used to score the final generated sequences found
    by beam search. Given the predicted sequences and the corresponding log probabilities of
    those sequences, the class calculates and returns the final score of the sequences.

    The default implementation scores the sequences using the sum of the log probabilities of
    the sequence, which is passed as input.
    """

    @abstractmethod
    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """
        Score the final predictions found by beam search.
        Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.

        :param predictions: A tensor containing the predictions with shape `(batch_size, beam_size, max_steps)`.
        :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
            of the log probabilities per token, with shape `(batch_size, beam_size)`.
        :param end_index: The index of the end symbol.
        """
        raise NotImplementedError


class SequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
    across the sequence's tokens.
    """

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        del predictions, end_index
        # The sum of the sequence log probabilities is the input parameter, so just
        # return it.
        return log_probabilities


class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
    tokens in the sequence. It optionally includes a length penalty which promotes
    or demotes sequences based on their lengths. The final score for a sequence will
    be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
    here includes the end token.

    :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
        A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
    """

    def __init__(self, length_penalty: float = 1.0):
        super().__init__()
        self.length_penalty = length_penalty

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        # shape: (batch_size, beam_size)
        lengths = (predictions != end_index).long().sum(dim=2)

        # If the sequence ended during beam search, the `log_probabilities` will include
        # the transition to the end token. Therefore, in such situations, `lengths` is
        # actually off by 1. This corrects for that.
        # shape: (batch_size, beam_size)
        is_end_token = predictions[:, :, -1] == end_index
        lengths += is_end_token.long()

        # shape: (batch_size, beam_size)
        average_log_probs = log_probabilities / (lengths**self.length_penalty)
        return average_log_probs


class Constraint:
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulating the class log probabilities during beam search.

    A `Constraint` just has three methods that need to be implemented by subclasses:
    `init_state()`, `apply()` and `_update_state()`.

    `init_state()` takes one argument:

    - the batch size, an int

    It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
    calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
    Each inner list should be of length 1.

    `apply()` takes two arguments:

    - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
      and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
    - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
      log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.

    The `apply()` method should return new `class_log_probabilities` that enforce the constraint
    for this step of beam search. For instance, it may prevent a specific class from being selected by setting
    the corresponding log probability to a negligible value such as `float("-inf")` or
    `torch.finfo(class_log_probabilities.dtype).min`.

    `_update_state()` takes two arguments:

    - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
      copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
      directly edited in-place without affecting the others.
    - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
      step of beam search.

    The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
    length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
    """

    @abstractmethod
    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _copy_state(
        state: ConstraintStateType,
        batch_size: int,
        beam_size: int,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """
        Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
        is not appropriate for your constraint, you will need to implement the copying yourself.
        """
        new_state = []
        for i in range(batch_size):
            batch_state = []
            for j in range(beam_size):
                if last_backpointer is None:
                    # This is the first prediction, so the backpointer is 0
                    backpointer = 0
                else:
                    backpointer = last_backpointer[i, j].item()
                batch_state.append(copy.deepcopy(state[i][backpointer]))  # type: ignore
            new_state.append(batch_state)
        return new_state

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        batch_size, beam_size = last_prediction.size()
        new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
        return self._update_state(new_state, last_prediction)

    @abstractmethod
    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        raise NotImplementedError


class RepeatedNGramBlockingConstraint(Constraint):
    def __init__(self, ngram_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ngram_size = ngram_size

    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                current_prefix = tuple(beam["current_prefix"])
                seen_ngrams = beam["seen_ngrams"]
                try:
                    disallowed_indices = seen_ngrams[current_prefix]
                    class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
                        class_log_probabilities.dtype
                    ).min
                except KeyError:
                    # We have not seen this prefix before, so there is no index
                    # that needs to be blocked
                    pass
        return class_log_probabilities

    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                prediction = last_prediction[i, j].item()
                prefix = beam["current_prefix"]
                seen_ngrams = beam["seen_ngrams"]

                if len(prefix) == self.ngram_size - 1:
                    # This is a new ngram that we have to remember
                    if tuple(prefix) not in seen_ngrams:
                        seen_ngrams[tuple(prefix)] = []
                    seen_ngrams[tuple(prefix)].append(prediction)

                # Create the new prefix, removing the oldest index if the prefix
                # is too long
                prefix.append(prediction)
                if len(prefix) == self.ngram_size:
                    prefix.pop(0)
        return state


class BeamSearch:
    """
    Implements the beam search algorithm for decoding the most likely sequences.

    :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.

    :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
        of the predicted sequences.

    :param beam_size: The width of the beam used.

    :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
        If not given, this just defaults to `beam_size`. Setting this parameter
        to a number smaller than `beam_size` may give better results, as it can introduce
        more diversity into the search. See
        [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
        (https://api.semanticscholar.org/CorpusID:2229477).

    :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
        If not specified, `DeterministicSampler` will be used, which just takes the
        `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.

        Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
        [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

    :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
        the predicted sequences. This does not include the start or end tokens. If `None`,
        no minimum is enforced.

    :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
        The output from this module is what is returned by the `search` method. If not
        specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
        by the sum of the token log probabilities.

    :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
        provided, no constraints will be enforced.
    """

    def __init__(
        self,
        end_index: int,
        *,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
        sampler: Optional[Sampler] = None,
        min_steps: Optional[int] = None,
        final_sequence_scorer: Optional[FinalSequenceScorer] = None,
        constraints: Optional[List[Constraint]] = None,
    ) -> None:
        if not max_steps > 0:
            raise ValueError("max_steps must be positive")
        if not beam_size > 0:
            raise ValueError("beam_size must be positive")
        if per_node_beam_size is not None and not per_node_beam_size > 0:
            raise ValueError("per_node_beam_size must be positive")
        if min_steps is not None:
            if not min_steps >= 0:
                raise ValueError("min_steps must be non-negative")
            if not min_steps <= max_steps:
                raise ValueError("min_steps must be less than or equal to max_steps")

        self._end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size
        self.sampler = sampler or DeterministicSampler()
        self.min_steps = min_steps or 0
        self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
        self.constraints = constraints or []

    @staticmethod
    def _reconstruct_sequences(predictions, backpointers):
        # Reconstruct the sequences.
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]

        if not backpointers:
            return reconstructed_predictions

        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]

        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)

            reconstructed_predictions.append(cur_preds)

            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)

        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)

        reconstructed_predictions.append(final_preds)

        return reconstructed_predictions

    def search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionType,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a starting state and a step function, apply beam search to find the
        most likely target sequences.

        Returns a tuple of `(predictions, final_scores)`, where `predictions`
        has shape `(batch_size, beam_size, max_steps)` and `final_scores`
        has shape `(batch_size, beam_size)`.

        .. note::
            If your step function returns `-inf` for some log probabilities
            (like if you're using a masked log-softmax) then some of the "best"
            sequences returned may also have `-inf` log probability. Specifically
            this happens when the beam size is smaller than the number of actions
            with finite log probability (non-zero probability) returned by the step function.
            Therefore if you're using a mask you may want to check the results from `search`
            and potentially discard sequences with non-finite log probability.

        :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
            Usually the initial predictions are just the index of the "start" token
            in the target vocabulary.

        :param start_state: The initial state passed to the `step` function. Each value of the state dict
            should be a tensor of shape `(batch_size, *)`, where `*` means any other
            number of dimensions.

        :param step: A function that is responsible for computing the next most likely tokens,
            given the current state and the predictions from the last time step.
            The function should accept two or three arguments:

            - a tensor of shape `(group_size,)` representing the index of the predicted
              tokens from the last time step,
            - the current state, a `StateType`, and
            - optionally, the timestep, an `int`.

            The `group_size` will be `batch_size * beam_size`, except in the initial
            step, for which it will just be `batch_size`.

            The function is expected to return a tuple, where the first element
            is a tensor of shape `(group_size, vocab_size)` containing
            the log probabilities of the tokens for the next step, and the second
            element is the updated state. The tensor in the state should have shape
            `(group_size, *)`, where `*` means any other number of dimensions.
        """
        step_signature = signature(step)
        if len(step_signature.parameters) < 3:
            # If the step function we're given does not take the time step argument, wrap it
            # in one that does.
            old_step = cast(StepFunctionTypeNoTimestep, step)

            def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
                del time_step
                return old_step(last_predictions, state)

            return self._search(start_predictions, start_state, new_step)
        else:
            return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))

    def _search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionTypeWithTimestep,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = start_predictions.size()[0]

        # List of (batch_size, beam_size) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions: List[torch.Tensor] = []

        # List of (batch_size, beam_size) tensors. One for each time step. None for
        # the first. Stores the index n for the parent prediction, i.e.
        # predictions[t-1][i][n], that it came from.
        backpointers: List[torch.Tensor] = []

        constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)

        num_classes = start_class_log_probabilities.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ValueError(
                f"Vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size."
            )

        sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)

        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, 1, num_classes)
            expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                expanded_start_class_log_probabilities = constraint.apply(
                    constraint_state, expanded_start_class_log_probabilities
                )
            start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)

        # Prevent selecting the end symbol if there is any min_steps constraint
        if self.min_steps >= 1:
            start_class_log_probabilities[:, self._end_index] = torch.finfo(
                start_class_log_probabilities.dtype
            ).min

        # Get the initial predicted classes and their log probabilities.
        # shape: (batch_size, beam_size), (batch_size, beam_size)
        (
            start_top_log_probabilities,
            start_predicted_classes,
            sampler_state,
        ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)

        if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # The log probabilities for the last time step.
        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities

        # shape: [(batch_size, beam_size)]
        predictions.append(start_predicted_classes)

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes),
            torch.finfo(start_class_log_probabilities.dtype).min,
        )
        log_probs_after_end[:, self._end_index] = 0.0

        # Set the same state for each element in the beam.
        self._update_initial_state(state, batch_size)

        for i, constraint in enumerate(self.constraints):
            constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)

            # If every predicted token from the last step is `self._end_index`,
            # then we can stop early.
            if (last_predictions == self._end_index).all():
                break
            # Take a step. This gets the predicted log probs of the next classes
            # and updates the state.
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Apply all constraints.
            if self.constraints:
                # shape: (batch_size, beam_size, num_classes)
                reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
                for constraint, constraint_state in zip(self.constraints, constraint_states):
                    reshaped_class_log_probabilities = constraint.apply(
                        constraint_state, reshaped_class_log_probabilities
                    )
                # shape: (batch_size * beam_size, num_classes)
                class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)

            # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
            # of the sequence (because `timestep` is 0-indexed and we generated the first token
            # before the for loop). Here we block the end index if the search is not allowed to
            # terminate on this iteration.
            if timestep + 2 <= self.min_steps:
                class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end,
                class_log_probabilities,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
                cleaned_log_probabilities, self.per_node_beam_size, sampler_state
            )

            # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
            # so that we can add them to the current log probs for this timestep.
            # This lets us maintain the log probability of each element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` beam indices.
            # shape (both): (batch_size, beam_size)
            (
                restricted_beam_log_probs,
                restricted_beam_indices,
                sampler_state,
            ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)

            predictions.append(restricted_predicted_classes)

            # shape: (batch_size, beam_size)
            last_log_probabilities = restricted_beam_log_probs

            # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
            # indices with a common ancestor are grouped together. Hence
            # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
            # division as the tensor is a LongTensor.)
            # shape: (batch_size, beam_size)
            backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
            backpointers.append(backpointer)

            # Keep only the pieces of the state tensors corresponding to the
            # ancestors created this iteration.
            self._update_state(state, backpointer)

            for i, constraint in enumerate(self.constraints):
                constraint_states[i] = constraint.update_state(
                    constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
                )

        # Warn about "-inf" log probabilities if not using any constraints (negligible
        # log probabilities are expected when using constraints).
        if not self.constraints and (
            not torch.isfinite(last_log_probabilities).all()
            or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
        ):
            warnings.warn(
                "Negligible log probabilities encountered ('-inf' or equivalent). "
                "Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)

        # shape: (batch_size, beam_size, max_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

        # Calculate the final sequence scores
        # shape: (batch_size, beam_size)
        final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)

        # Sort the sequences based on the final scores so the best scoring
        # sequence is at index 0
        sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
        sorted_all_predictions = torch.gather(
            all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
        )

        return sorted_all_predictions, sorted_final_scores

    def _update_initial_state(self, state: StateType, batch_size: int):
        """
        Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
        """
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            # shape: (batch_size * beam_size, *)
            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, self.beam_size, *last_dims)
                .reshape(batch_size * self.beam_size, *last_dims)
            )

    def _update_state(self, state: StateType, backpointer: torch.Tensor):
        batch_size = backpointer.size()[0]

        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
                batch_size, self.beam_size, *last_dims
            )
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, self.beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * self.beam_size, *last_dims)
            )
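For readers skimming this diff, the snippet below is a minimal sketch of how the `BeamSearch` class added here can be driven. The toy `step` function, vocabulary size, end-of-sequence index, and start tokens are illustrative assumptions, not part of this commit; a real model would compute log probabilities from its decoder and carry cached tensors in `state`.

import torch
from beam_search import BeamSearch, StateType  # assumes beam_search.py is importable as a top-level module

VOCAB_SIZE = 5  # illustrative toy vocabulary size
EOS_INDEX = 4   # illustrative end-of-sequence token id

def step(last_predictions: torch.Tensor, state: StateType, timestep: int):
    # Stand-in for a decoder forward pass: produce log probabilities over the
    # vocabulary for each sequence in the group and return the (unchanged) state.
    group_size = last_predictions.size(0)
    logits = torch.randn(group_size, VOCAB_SIZE)
    return torch.nn.functional.log_softmax(logits, dim=-1), state

beam_search = BeamSearch(end_index=EOS_INDEX, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)  # (batch_size,) "start" token ids
start_state: StateType = {}                           # this toy step function keeps no state
predictions, scores = beam_search.search(start_predictions, start_state, step)
# predictions: (batch_size, beam_size, num_generated_steps); scores: (batch_size, beam_size)

Passing `sampler=TopPSampler(p=0.9)` or `constraints=[RepeatedNGramBlockingConstraint(ngram_size=3)]` to the `BeamSearch` constructor would exercise the sampler and constraint hooks defined in the file above.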
checkpoint.py
ADDED
@@ -0,0 +1,2023 @@
import gc
import io
import logging
import pickle
import shutil
import traceback
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
from multiprocessing import shared_memory
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast

import numpy as np
import torch
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
import torch.nn as nn
from packaging import version
from torch.distributed import _remote_device
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from torch.distributed.fsdp.flat_param import FlatParamHandle  # type: ignore
except ModuleNotFoundError:
    from torch.distributed.fsdp._flat_param import FlatParamHandle  # type: ignore

from olmo import util

from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import (
    barrier,
    gc_cuda,
    get_fs_local_rank,
    get_global_rank,
    get_local_rank,
    get_local_world_size,
    get_world_size,
)
from .util import (
    _get_s3_client,
    default_thread_count,
    dir_is_empty,
    get_bytes_range,
    get_progress_bar,
    resource_path,
    upload,
    wait_for,
)

__all__ = [
    "save_fsdp_model_and_optim_state",
    "load_fsdp_model_and_optim_state",
    "load_fsdp_optim_state",
    "save_state_dict",
    "load_state_dict",
    "load_model_state",
    "RemoteFileSystemWriter",
    "RemoteFileSystemReader",
    "Checkpointer",
    "FullCheckpointer",
    "TorchNewStyleShardedCheckpointer",
    "TorchLegacyShardedCheckpointer",
    "LocalShardedCheckpointer",
    "build_sharded_checkpointer",
]


log = logging.getLogger(__name__)

MODEL_AND_OPTIM_FOLDER = "model_and_optim"

def save_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
):
    """
    Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The directory to save to.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
    :param save_overwrite: Overwrite existing files.

    :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
    """
    checkpoint_dir = Path(checkpoint_dir)
    target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
    if save_overwrite:
        if get_fs_local_rank() == 0:
            shutil.rmtree(target_dir, ignore_errors=True)
    elif not dir_is_empty(target_dir):
        raise FileExistsError(target_dir)
    barrier()
    if get_fs_local_rank() == 0:
        target_dir.mkdir(exist_ok=True, parents=True)
    barrier()
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_and_optim_state = {
            "model": fsdp_model.state_dict(),
            "optim": FSDP.optim_state_dict(fsdp_model, optim),
        }
        dist_cp.save_state_dict(
            model_and_optim_state,
            RemoteFileSystemWriter(
                target_dir,
                upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
                save_overwrite=save_overwrite,
            ),
        )

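# NOTE (editorial illustration, not part of the original file): a minimal usage sketch for
# `save_fsdp_model_and_optim_state`, assuming a training loop that already holds an
# FSDP-wrapped model and its optimizer; the local path and bucket name are hypothetical.
# Every rank must make this call, since each rank writes its own shards through
# `torch.distributed.checkpoint`.
#
#   save_fsdp_model_and_optim_state(
#       "checkpoints/step1000",
#       fsdp_model,
#       optim,
#       upload_to="s3://my-bucket/run-01/step1000",
#       save_overwrite=True,
#   )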
def load_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
):
    """
    Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.
    :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.

    :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
    """
    load_path = str(checkpoint_dir).rstrip("/")
    local_cache = None if local_cache is None else Path(local_cache)
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        # Load the model state dict in place.
        log.info("Loading model state...")
        model_state = {"model": fsdp_model.state_dict()}
        dist_cp.load_state_dict(
            model_state,
            RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        fsdp_model.load_state_dict(model_state["model"])

        if not load_optimizer_state:
            return

        # Load optim state dict in place.
        log.info("Loading sharded optimizer state...")
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=model_state["model"],
            optimizer_key="optim",
            storage_reader=RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        # optim_state["optim"] = {
        #     'state': { fqn: { 'grad_norm_exp_avg': Tensor, 'step': Tensor, 'exp_avg': ShardedTensor, 'exp_avg_sq': ShardedTensor } },
        #     'param_groups': [{ 'param_names': [ fsdp_fqn, ... ], 'params': [ fqn, ... ], ... }],
        # }
        del model_state

        # Make sure tensors are on CPU! PyTorch puts them on GPU even though we have `offload_to_cpu=True`.
        for state in optim_state["optim"]["state"].values():
            for k in state.keys():
                state[k] = state[k].cpu()
        gc_cuda()

        load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])

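# NOTE (editorial illustration, not part of the original file): the matching restore call,
# again with hypothetical paths. The remote checkpoint is read directly; `local_cache` only
# short-circuits the download when the same files already exist on local disk.
#
#   load_fsdp_model_and_optim_state(
#       "s3://my-bucket/run-01/step1000",
#       fsdp_model,
#       optim,
#       local_cache="checkpoints/step1000",
#       load_optimizer_state=True,
#   )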
def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
    log.info("Flattening sharded optimizer state...")
    # flattened_osd = {
    #     'state': { id: { 'grad_norm_exp_avg': Tensor, 'step': Tensor, 'exp_avg': Tensor, 'exp_avg_sq': Tensor } },
    #     'param_groups': [{ 'param_names': [ fsdp_fqn, ... ], 'params': [ id, ... ], ... }],
    # }
    # NOTE: Careful! The order of these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
    if version.parse(torch.__version__) < version.parse("2.1.0"):
        flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim)  # type: ignore
    else:
        flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state)  # type: ignore

    del optim_state
    gc_cuda()

    log.info("Loading flattened optimizer state...")

    # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
    # which takes up unnecessary GPU memory.
    for state in flattened_osd["state"].values():
        for k in state.keys():
            state[k] = state[k].cpu()
    gc_cuda()

    optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))

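# NOTE (editorial illustration, not part of the original file): the `version.parse` guard
# above is needed because `FSDP.optim_state_dict_to_load` swapped its argument order
# between torch 2.0.x and 2.1.0. For example:
#
#   from packaging import version
#   version.parse("2.0.1") < version.parse("2.1.0")   # True  -> old argument order
#   version.parse("2.1.2") < version.parse("2.1.0")   # False -> new argument order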
def save_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    state_dict: Dict[str, Any],
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
    synchronize: bool = True,
):
    """
    Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
    This can be used during distributed training or not. If used during distributed training, ``fname`` should
    be unique for each rank.

    :param checkpoint_dir: The directory to save to.
    :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
    :param state_dict: The state dict to save.
    :param upload_to: Optional, a remote "directory" to upload the file to.
    :param save_overwrite: Overwrite existing files.
    :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
        this function from a single rank.

    :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
    """
    checkpoint_dir = Path(checkpoint_dir)
    target_path = checkpoint_dir / fname
    if save_overwrite:
        target_path.unlink(missing_ok=True)
    elif target_path.is_file():
        raise FileExistsError(target_path)
    if synchronize:
        barrier()
    target_path.parent.mkdir(exist_ok=True, parents=True)
    if synchronize:
        barrier()
    torch.save(state_dict, target_path)
    if upload_to is not None:
        upload_target = f"{upload_to.rstrip('/')}/{fname}"
        log.info(f"Uploading {target_path} to {upload_target}...")
        upload(target_path, upload_target, save_overwrite=save_overwrite)

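# NOTE (editorial illustration, not part of the original file): `save_state_dict` is the
# plain (non-sharded) helper used for per-rank artifacts such as trainer state. A sketch
# with hypothetical values:
#
#   save_state_dict(
#       "checkpoints/step1000",
#       "train/rank0.pt",
#       {"global_step": 1000, "epoch": 1},
#       save_overwrite=True,
#       synchronize=False,   # set False when only a single rank makes the call
#   )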
def load_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    *,
    local_cache: Optional[PathOrStr] = None,
    map_location: Optional[str] = None,
):
    """
    Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
    This can be used during distributed training or not.

    :param checkpoint_dir: A local or remote checkpoint directory.
    :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.

    :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
    """
    if fname.endswith(".pt"):
        # Try safetensors version first.
        try:
            path = resource_path(
                str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
            )
            return safetensors_file_to_state_dict(path, map_location=map_location)
        except FileNotFoundError:
            pass

    path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
    return torch.load(path, map_location=map_location)


def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
    """
    Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
    Note that ``model`` should not be wrapped with FSDP.
    """
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict,
        RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])

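# NOTE (editorial illustration, not part of the original file): `load_model_state` reads a
# sharded `model_and_optim` checkpoint into an *unwrapped* module, e.g. for evaluation on a
# single process; `build_model` and the path below are hypothetical.
#
#   model = build_model(cfg)            # hypothetical constructor, not FSDP-wrapped
#   load_model_state("checkpoints/step1000", model)
#   model.eval()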
class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
    """
    A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
    directly to a cloud bucket when ``upload_to`` is specified.
    """

    def __init__(
        self,
        path: PathOrStr,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: Optional[int] = None,
        per_thread_copy_ahead: int = 10_000_000,
        upload_to: Optional[str] = None,
        save_overwrite: bool = False,
    ) -> None:
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
            # returns because uploading big checkpoint files with multiple threads causes
            # boto3 to fail in weird ways.
            thread_count=thread_count or 1,
            per_thread_copy_ahead=per_thread_copy_ahead,
        )
        self.upload_to = None if upload_to is None else upload_to.rstrip("/")
        self.save_overwrite = save_overwrite

    def write_data(
        self,
        plan: dist_cp.SavePlan,
        planner: dist_cp.SavePlanner,
    ) -> Future[List[WriteResult]]:
        fut = super().write_data(plan, planner)
        if self.upload_to is not None:
            files_to_upload = set()
            for write_result in fut.wait():
                files_to_upload.add(write_result.storage_data.relative_path)

            # Create the global S3 client up front to work around a threading issue in boto.
            if self.upload_to.startswith("s3://"):
                _get_s3_client("s3")
            elif self.upload_to.startswith("r2://"):
                _get_s3_client("r2")
            elif self.upload_to.startswith("weka://"):
                _get_s3_client("weka")

            with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
                futures = []
                for fname in files_to_upload:
                    source = self.path / fname
                    target = f"{self.upload_to}/{fname}"
                    log.info(f"Uploading {source} to {target}...")
                    futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
                for f in as_completed(futures):
                    try:
                        f.result()
                    except BaseException:
                        # NOTE: we might get an error here that can't be pickled, which causes a different failure
                        # later when PyTorch tries to reduce that error across ranks. So here we just make
                        # sure we're raising a simple error type that can be pickled.
                        raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        super().finish(metadata, results)
        if self.upload_to is not None:
            source = self.path / ".metadata"
            target = f"{self.upload_to}/.metadata"
            log.info(f"Uploading {source} to {target}...")
            upload(source, target, save_overwrite=self.save_overwrite)

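# NOTE (editorial illustration, not part of the original file): the writer is normally
# constructed by `save_fsdp_model_and_optim_state` above, but it can also be passed to
# `torch.distributed.checkpoint` directly (paths and bucket name hypothetical):
#
#   writer = RemoteFileSystemWriter(
#       "checkpoints/step1000/model_and_optim",
#       upload_to="s3://my-bucket/run-01/step1000/model_and_optim",
#       save_overwrite=True,
#   )
#   dist_cp.save_state_dict(state_dict, writer)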
class RemoteFileSystemReader(dist_cp.StorageReader):
    """
    A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
    that can read data directly from cloud storage as well as a local directory.
    """

    def __init__(
        self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
    ):
        super().__init__()
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        self.path = str(path).rstrip("/")
        self.cache = None if local_cache is None else Path(local_cache)
        self.thread_count = thread_count or default_thread_count()
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
        self._metadata: Optional[Metadata] = None

    def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
        if self.cache is not None and (path := self.cache / relative_path).is_file():
            return get_bytes_range(path, offset, length)
        else:
            return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

    def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
        sinfo = self.storage_data[read_item.storage_index]
        content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
        return (read_item, content)

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
        # Create the global S3 client up front to work around a threading issue in boto.
        if isinstance(self.path, str):
            if self.path.startswith("s3://"):
                _get_s3_client("s3")
            elif self.path.startswith("r2://"):
                _get_s3_client("r2")
            elif self.path.startswith("weka://"):
                _get_s3_client("weka")

        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            read_item_content_futures = []
            for read_item in plan.items:
                read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
            read_item_content_results = []
            for f in as_completed(read_item_content_futures):
                try:
                    read_item_content_results.append(f.result())
                except BaseException:
                    # NOTE: we might get an error here that can't be pickled, which causes a different failure
                    # later when PyTorch tries to reduce that error across ranks. So here we just make
                    # sure we're raising a simple error type that can be pickled.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

        # Modified from `FileSystemReader.read_data()`
        for read_item, content in read_item_content_results:
            bytes = io.BytesIO(content)
            bytes.seek(0)
            if read_item.type == LoadItemType.BYTE_IO:
                planner.load_bytes(read_item, bytes)
            else:
                tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
                tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
                target_tensor = planner.resolve_tensor(read_item).detach()

                assert (
                    target_tensor.size() == tensor.size()
                ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                target_tensor.copy_(tensor)
                planner.commit_tensor(read_item, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def read_metadata(self) -> Metadata:
        if self._metadata is None:
            with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
                self._metadata = pickle.load(metadata_file)
        return self._metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        del is_coordinator
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
        return global_plan

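# NOTE (editorial illustration, not part of the original file): the reader mirrors the
# writer; with `no_dist=True` it can also be used offline, without an initialized process
# group, as `load_model_state` above does (paths hypothetical):
#
#   reader = RemoteFileSystemReader(
#       "s3://my-bucket/run-01/step1000/model_and_optim",
#       local_cache="checkpoints/step1000/model_and_optim",
#   )
#   dist_cp.load_state_dict(state_dict, reader, no_dist=True)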
class Checkpointer(metaclass=ABCMeta):
    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
        self.cfg = cfg
        self.thread_count = thread_count or default_thread_count()

    @abstractmethod
    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        train_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
        """
        raise NotImplementedError

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint.

        Note this is not marked abstract because child classes are not required to implement this.
        """
        raise NotImplementedError

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
        checkpoint_dir = Path(dir)
        if not dir_is_empty(checkpoint_dir):
            if self.cfg.save_overwrite:
                if get_fs_local_rank() == 0:
                    shutil.rmtree(checkpoint_dir, ignore_errors=True)
            else:
                raise FileExistsError(checkpoint_dir)
        # No need to mkdir here since we'll directly replace the temporary directory with
        # this directory below.
        barrier()

        # Prepare temporary directory. We don't have to be as careful here, we can
        # just remove it if it already exists.
        checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
        if get_fs_local_rank() == 0:
            shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
            checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)

        # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
        # creating the temp directory from rank 0 might not be immediately
        # realized in the file systems of the other ranks.
        # So we wait here across all ranks until that tmp checkpoint directory is visible.
        wait_for(lambda: checkpoint_dir_tmp.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

        # Yield temporary directory for `.save_checkpoint()` to use.
        yield checkpoint_dir_tmp

        barrier()

        # Finally if all went well replace the temporary directory with the actual
        # checkpoint directory.
        if get_fs_local_rank() == 0:
            # Replace temp directory with target checkpoint directory.
            try:
                checkpoint_dir_tmp.replace(checkpoint_dir)
            except FileNotFoundError:
                # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
                # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
                # file-systems.
                if not checkpoint_dir.exists():
                    raise

        # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
        # replacing the temp directory with the final directory from rank 0 might not be immediately
        # realized in the file systems of the other ranks.
        # So we wait here across all ranks until that final checkpoint directory is visible.
        wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

    def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        if get_global_rank() == 0:
            log.info("Saving config...")
            self.cfg.save(config_path := Path(dir) / "config.yaml")
            if upload_to is not None:
                upload_target = f"{upload_to}/config.yaml"
                log.info(f"Uploading {config_path} to {upload_target}")
                upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)

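# NOTE (editorial illustration, not part of the original file): concrete checkpointers are
# expected to wrap their writes in `_temporary_wd`, which stages everything in a
# "<dir>-tmp" directory and atomically renames it to the final directory only after all
# ranks have finished. The typical shape of a subclass's `save_checkpoint` is:
#
#   with self._temporary_wd(dir) as checkpoint_dir:
#       ...  # write model/optim/trainer state into checkpoint_dir
#       self._save_config(checkpoint_dir, upload_to=upload_to)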
class FullCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as checkpoint_dir:
            if isinstance(dist_model, FSDP):
                with FSDP.state_dict_type(
                    dist_model,
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                    optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
                ):
                    # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
                    # First the model state.
                    model_state_dict = dist_model.state_dict()
                    self._write_model_dict(
                        model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                    )

                    # Then the optimizer state.
                    optim_state_dict = FSDP.optim_state_dict(dist_model, optim)
                    self._write_optim_dict(
                        optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                    )
            elif isinstance(dist_model, DDP):
                # _write_model_dict and _write_optim_dict only write checkpoints for rank 0
                # First, get the model state dict from DDP wrapped model
                model_state_dict = dist_model.module.state_dict()
                self._write_model_dict(
                    model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                )

                # Then get the optimizer state dict
                optim_state_dict = optim.state_dict()
                self._write_optim_dict(
                    optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                )
            else:
                log.info(
                    "`FullCheckpointer.save_checkpoint` only supported for FSDP and DDP distributed strategies!"
                )

            # Save trainer state.
            if get_global_rank() == 0:
                log.info("Saving trainer state...")
                save_state_dict(
                    checkpoint_dir,
                    "train.pt",
                    trainer_state,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                    synchronize=False,
                )
            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        if isinstance(dist_model, FSDP):
            with FSDP.state_dict_type(
                dist_model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
                optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
            ):
                with torch.no_grad():
                    # fill everything with NaN, so we can check afterwards that every parameter has been restored
                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            param.fill_(torch.nan)

                    # restore params from checkpoint
                    state_dict_to_load = load_state_dict(
                        load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                    )
                    (
                        state_dict_to_load,
                        og_keys_to_new,
                    ) = dist_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)

                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            assert param._is_flat_param
                            for fqn, spi in zip(param._fqns, param._shard_param_infos):
                                if not spi.in_shard:
                                    continue
                                key = f"{module_name}.{fqn}"
                                key = key.replace("_fsdp_wrapped_module.", "")
                                key = key.lstrip(".")
                                t = state_dict_to_load[key]
                                t = t.flatten()
                                param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
                                    t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
                                )

                    # make sure that every parameter has been restored
                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            if torch.isnan(param).any():
                                raise ValueError(
                                    f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
                                )

            # Load optimizer state.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim_state_dict_to_load = self._make_optim_state_dict_compatible(
                    optim_state_dict_to_load,
                    og_keys_to_new,
                )
                gc.collect()
                torch.cuda.empty_cache()
                barrier()
                for turn in range(get_local_world_size()):
                    log.info("Loading optimizer state turn %d ...", turn)
                    if turn == get_local_rank():
                        load_fsdp_optim_state(dist_model, optim, optim_state_dict_to_load)
                    gc.collect()
                    torch.cuda.empty_cache()
                    barrier()
                del optim_state_dict_to_load
        elif isinstance(dist_model, DDP):
            # Load model state.
            with torch.no_grad():
                state_dict_to_load = load_state_dict(
                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                )
                dist_model.module.load_state_dict(state_dict_to_load, strict=True)

            # Load optimizer state.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim.load_state_dict(optim_state_dict_to_load)

            gc.collect()
            torch.cuda.empty_cache()
            barrier()
        else:
            raise NotImplementedError(
                "`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
            )

        # Load other state.
        try:
            trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
        except FileNotFoundError:
            # for backwards compatibility
            trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _write_model_dict(self, model_state_dict, checkpoint_dir, upload_to, save_overwrite):
        if get_global_rank() == 0:
            log.info("Saving model state...")
            save_state_dict(
                checkpoint_dir,
                "model.pt",
                model_state_dict,
                upload_to=upload_to,
                save_overwrite=save_overwrite,
                synchronize=False,
            )

        del model_state_dict
        barrier()

    def _write_optim_dict(self, optim_state_dict, checkpoint_dir, upload_to, save_overwrite):
        if get_global_rank() == 0:
            log.info("Saving optim state...")
            save_state_dict(
                checkpoint_dir,
                "optim.pt",
                optim_state_dict,
                upload_to=upload_to,
                save_overwrite=save_overwrite,
                synchronize=False,
            )

        del optim_state_dict
        barrier()

    def _make_optim_state_dict_compatible(
        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
    ) -> Dict[str, Any]:
        # This state dict comes in two forms: one where the state keys are integers and one where the
        # keys are fully qualified parameter names. The latter case is easier to deal with here so we
        # first transform the integer key form into the FQN key form.
        if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
            id_to_fqn: Dict[int, str] = {}
            for group in optim_state_dict["param_groups"]:
                new_param_names = []
                for fqn, id in zip(group["param_names"], group["params"]):
                    fqn = fqn.replace("_fsdp_wrapped_module.", "")
                    id_to_fqn[id] = fqn
                    new_param_names.append(fqn)
                group["param_names"] = new_param_names
                group["params"] = new_param_names
            for id in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
        else:
            # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
            for group in optim_state_dict["param_groups"]:
                group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
                group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
                assert group["param_names"] == group["params"]
            for key in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
                    "state"
                ].pop(key)

        # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
        # First fix param names in the state.
        for og_key, new_keys in og_keys_to_new.items():
            og_state = optim_state_dict["state"].pop(og_key, None)
            if og_state is None:
                continue
            for i, new_key in enumerate(new_keys):
                if i == len(new_keys) - 1:
                    optim_state_dict["state"][new_key] = og_state
                else:
                    optim_state_dict["state"][new_key] = deepcopy(og_state)
        # Now fix param names in the param groups.
        for group in optim_state_dict["param_groups"]:
            og_names = group["params"]
            new_names = []
            for og_key in og_names:
                for new_key in og_keys_to_new[og_key]:
                    new_names.append(new_key)
            group["params"] = new_names
            group["param_names"] = new_names

        return optim_state_dict

    def load_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
        device = device if device is not None else torch.device("cpu")
        model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device)  # type: ignore
        optim_state = None
        if load_optimizer_state:
            optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device)  # type: ignore
        return model_state, optim_state

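# NOTE (editorial illustration, not part of the original file): `FullCheckpointer.load_checkpoint`
# can also be used outside of training to pull an unsharded checkpoint onto a single device;
# `cfg` is assumed to be an already-loaded TrainConfig and the path is hypothetical.
#
#   checkpointer = FullCheckpointer(cfg)
#   model_state, optim_state = checkpointer.load_checkpoint(
#       "checkpoints/step1000-unsharded",
#       load_optimizer_state=False,   # optim_state comes back as None in this case
#   )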
class TorchNewStyleShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
        with self._temporary_wd(dir) as checkpoint_dir:
            # Save model and optim state.
            save_fsdp_model_and_optim_state(
                checkpoint_dir,
                dist_model,
                optim,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save trainer state.
            log.info("Saving trainer state...")
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        # Load model and optimizer state in place.
        log.info("Loading model and optimizer state...")
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."

        load_fsdp_model_and_optim_state(
            load_path,
            dist_model,
            optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )

        # Load trainer state dict.
        log.info("Loading trainer state...")
        try:
            trainer_state = load_state_dict(
                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
            )
        except FileNotFoundError:
            # Fall back to rank 0 train state.
            # This can happen when we're restoring a checkpoint with a different world size.
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        barrier()
        return trainer_state

967 |
+
class TorchLegacyShardedCheckpointer(Checkpointer):
|
968 |
+
"""
|
969 |
+
A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
|
970 |
+
and optim state.
|
971 |
+
|
972 |
+
The world size must be kept consistent when using this checkpointer.
|
973 |
+
"""
|
974 |
+
|
975 |
+
def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None, use_shared_mem_impl: bool = False):
|
976 |
+
super().__init__(cfg, thread_count)
|
977 |
+
self.use_shared_mem_impl = use_shared_mem_impl
|
978 |
+
|
979 |
+
def save_checkpoint(
|
980 |
+
self,
|
981 |
+
dir: PathOrStr,
|
982 |
+
dist_model: nn.Module,
|
983 |
+
optim: Optimizer,
|
984 |
+
trainer_state: Dict[str, Any],
|
985 |
+
*,
|
986 |
+
upload_to: Optional[str] = None,
|
987 |
+
) -> None:
|
988 |
+
assert isinstance(
|
989 |
+
dist_model, FSDP
|
990 |
+
), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
|
991 |
+
with self._temporary_wd(dir) as checkpoint_dir:
|
992 |
+
with FSDP.state_dict_type(
|
993 |
+
dist_model,
|
994 |
+
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
995 |
+
state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
|
996 |
+
optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
|
997 |
+
):
|
998 |
+
state_dict = {
|
999 |
+
"model": dist_model.state_dict(),
|
1000 |
+
"optim": FSDP.optim_state_dict(dist_model, optim),
|
1001 |
+
**trainer_state,
|
1002 |
+
}
|
1003 |
+
save_state_dict(
|
1004 |
+
checkpoint_dir,
|
1005 |
+
f"rank{get_global_rank()}.pt",
|
1006 |
+
state_dict,
|
1007 |
+
upload_to=upload_to,
|
1008 |
+
save_overwrite=self.cfg.save_overwrite,
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
# Save config.
|
1012 |
+
self._save_config(checkpoint_dir, upload_to=upload_to)
|
1013 |
+
|
1014 |
+
def restore_checkpoint(
|
1015 |
+
self,
|
1016 |
+
load_path: PathOrStr,
|
1017 |
+
dist_model: nn.Module,
|
1018 |
+
optim: Optimizer,
|
1019 |
+
*,
|
1020 |
+
local_cache: Optional[PathOrStr] = None,
|
1021 |
+
load_optimizer_state: bool = True,
|
1022 |
+
) -> Dict[str, Any]:
|
1023 |
+
assert isinstance(
|
1024 |
+
dist_model, FSDP
|
1025 |
+
), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
|
1026 |
+
with FSDP.state_dict_type(
|
1027 |
+
dist_model,
|
1028 |
+
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
1029 |
+
state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
|
1030 |
+
optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
|
1031 |
+
):
|
1032 |
+
# Deserialize state dict.
|
1033 |
+
state_dict = load_state_dict(
|
1034 |
+
load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
|
1035 |
+
)
|
1036 |
+
|
1037 |
+
# Load model and optimizer state.
|
1038 |
+
log.info("Loading model state...")
|
1039 |
+
dist_model.load_state_dict(state_dict["model"])
|
1040 |
+
del state_dict["model"]
|
1041 |
+
if load_optimizer_state:
|
1042 |
+
log.info("Loading optimizer state...")
|
1043 |
+
load_fsdp_optim_state(dist_model, optim, state_dict["optim"])
|
1044 |
+
del state_dict["optim"]
|
1045 |
+
|
1046 |
+
barrier()
|
1047 |
+
return state_dict
|
1048 |
+
|
1049 |
+
def unshard_checkpoint(
|
1050 |
+
self,
|
1051 |
+
load_path: PathOrStr,
|
1052 |
+
*,
|
1053 |
+
local_cache: Optional[PathOrStr] = None,
|
1054 |
+
load_optimizer_state: bool = True,
|
1055 |
+
load_trainer_state: bool = True,
|
1056 |
+
device: Optional[torch.device] = None,
|
1057 |
+
) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
1058 |
+
assert local_cache is None, "this method currently only supports local files"
|
1059 |
+
full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
|
1060 |
+
model_state = full_state_dict.pop("model")
|
1061 |
+
optim_state = full_state_dict.pop("optim")
|
1062 |
+
return (
|
1063 |
+
model_state,
|
1064 |
+
optim_state if load_optimizer_state else None,
|
1065 |
+
full_state_dict if load_trainer_state else None,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
|
1069 |
+
key = tuple() if key is None else key
|
1070 |
+
if isinstance(state, (list, tuple, set)):
|
1071 |
+
for i, sub_state in enumerate(state):
|
1072 |
+
self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
|
1073 |
+
elif isinstance(state, dict):
|
1074 |
+
for name in state.keys():
|
1075 |
+
self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
|
1076 |
+
elif isinstance(state, ShardedTensor):
|
1077 |
+
self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
|
1078 |
+
return
|
1079 |
+
else:
|
1080 |
+
return
|
1081 |
+
|
1082 |
+
def _get_shard_placement_and_rank_sizes(
|
1083 |
+
self, shards_metadata: List[ShardMetadata], world_size: int
|
1084 |
+
) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
|
1085 |
+
def shard_size(shard_md):
|
1086 |
+
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
|
1087 |
+
|
1088 |
+
rank_sizes = [0 for _ in range(world_size)]
|
1089 |
+
shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
|
1090 |
+
for shard_md in shards_metadata:
|
1091 |
+
shard_rank = cast(_remote_device, shard_md.placement).rank()
|
1092 |
+
assert shard_rank is not None
|
1093 |
+
if shard_rank >= world_size:
|
1094 |
+
raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
|
1095 |
+
|
1096 |
+
shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
|
1097 |
+
rank_sizes[shard_rank] += shard_size(shard_md)
|
1098 |
+
|
1099 |
+
return shard_placement, rank_sizes
|
1100 |
+
|
1101 |
+
def _copy_sharded_tensor_to_shared_mem(
|
1102 |
+
self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
|
1103 |
+
) -> Any:
|
1104 |
+
shard0_md = sharded_tensor.metadata()
|
1105 |
+
shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
|
1106 |
+
shard0_md.shards_metadata, world_size
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
rank_size = rank_sizes[rank]
|
1110 |
+
assert rank_size >= 0
|
1111 |
+
if rank_size == 0:
|
1112 |
+
return
|
1113 |
+
|
1114 |
+
assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
|
1115 |
+
numpy_type = np.float32
|
1116 |
+
|
1117 |
+
sharded_memory_name = "-".join(key + (str(rank),))
|
1118 |
+
|
1119 |
+
shm = shared_memory.SharedMemory(
|
1120 |
+
create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
|
1121 |
+
)
|
1122 |
+
np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
|
1123 |
+
|
1124 |
+
for local_shard in sharded_tensor.local_shards():
|
1125 |
+
shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
|
1126 |
+
assert shard_rank == rank
|
1127 |
+
|
1128 |
+
src = local_shard.tensor.flatten()
|
1129 |
+
shard_offset = shard_placement[local_shard.metadata][1]
|
1130 |
+
|
1131 |
+
np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
|
1132 |
+
|
1133 |
+
shm.close()
|
1134 |
+
|
1135 |
+
def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
|
1136 |
+
shard_number = int(shard_filepath.name[4:-3])
|
1137 |
+
log.info("Starting unsharding shard number %d to shared memory", shard_number)
|
1138 |
+
|
1139 |
+
with self._patch_sharded_tensor_load():
|
1140 |
+
shard = torch.load(shard_filepath, map_location="cpu")
|
1141 |
+
log.debug("Done loading shard number %d", shard_number)
|
1142 |
+
|
1143 |
+
self._copy_sharded_tensors_to_shared_mem(
|
1144 |
+
shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
|
1145 |
+
)
|
1146 |
+
log.info("Done unsharding shard number %d to shared memory", shard_number)
|
1147 |
+
|
1148 |
+
def _unshard_using_sharded_mem(
|
1149 |
+
self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
|
1150 |
+
) -> Any:
|
1151 |
+
return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
|
1152 |
+
|
1153 |
+
def _unshard_state_using_shared_mem(
|
1154 |
+
self, state: Any, world_size: int, device: torch.device, key: Tuple
|
1155 |
+
) -> Any:
|
1156 |
+
if isinstance(state, (list, tuple, set)):
|
1157 |
+
return state.__class__(
|
1158 |
+
self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
|
1159 |
+
for i, sub_state in enumerate(state)
|
1160 |
+
)
|
1161 |
+
elif isinstance(state, dict):
|
1162 |
+
return {
|
1163 |
+
name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
|
1164 |
+
for name in state.keys()
|
1165 |
+
}
|
1166 |
+
elif isinstance(state, ShardedTensor):
|
1167 |
+
return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
|
1168 |
+
elif isinstance(state, torch.Tensor):
|
1169 |
+
return state.to(device=device)
|
1170 |
+
else:
|
1171 |
+
return state
|
1172 |
+
|
1173 |
+
def _unshard_tensor_using_shared_mem(
|
1174 |
+
self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
|
1175 |
+
) -> torch.Tensor:
|
1176 |
+
shard0_md = sharded_tensor.metadata()
|
1177 |
+
|
1178 |
+
def shard_size(shard_md):
|
1179 |
+
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
|
1180 |
+
|
1181 |
+
shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
|
1182 |
+
shard0_md.shards_metadata, world_size
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
|
1186 |
+
numpy_type = np.float32
|
1187 |
+
|
1188 |
+
out = torch.empty(
|
1189 |
+
*sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
|
1190 |
+
)
|
1191 |
+
dims = len(sharded_tensor.metadata().size)
|
1192 |
+
for shard_md, (rank, rank_offset) in shard_placement.items():
|
1193 |
+
if rank >= world_size:
|
1194 |
+
raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
|
1195 |
+
|
1196 |
+
sharded_memory_name = "-".join(key + (str(rank),))
|
1197 |
+
shm = shared_memory.SharedMemory(name=sharded_memory_name)
|
1198 |
+
|
1199 |
+
rank_size = rank_sizes[rank]
|
1200 |
+
assert rank_size >= 0
|
1201 |
+
if rank_size == 0:
|
1202 |
+
continue
|
1203 |
+
|
1204 |
+
np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
|
1205 |
+
|
1206 |
+
tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
|
1207 |
+
tensor = tensor.view(shard_md.shard_sizes)
|
1208 |
+
|
1209 |
+
out_narrow_view = out
|
1210 |
+
for dim in range(dims):
|
1211 |
+
out_narrow_view = out_narrow_view.narrow(
|
1212 |
+
dim,
|
1213 |
+
shard_md.shard_offsets[dim],
|
1214 |
+
shard_md.shard_sizes[dim],
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
out_narrow_view.copy_(tensor)
|
1218 |
+
|
1219 |
+
shm.close()
|
1220 |
+
shm.unlink()
|
1221 |
+
|
1222 |
+
return out
|
1223 |
+
|
1224 |
+
@contextmanager
|
1225 |
+
def _patch_sharded_tensor_load(self):
|
1226 |
+
"""
|
1227 |
+
Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
|
1228 |
+
"""
|
1229 |
+
|
1230 |
+
def _rebuild_from_type_v2_monkey(func, new_type, args, state):
|
1231 |
+
ret = func(*args)
|
1232 |
+
if type(ret) is not new_type:
|
1233 |
+
ret = ret.as_subclass(new_type)
|
1234 |
+
|
1235 |
+
# Shortcut the construction of ShardedTensor
|
1236 |
+
# This is in the top 5 of my worst hacks.
|
1237 |
+
if isinstance(ret, ShardedTensor):
|
1238 |
+
ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
|
1239 |
+
return ret
|
1240 |
+
|
1241 |
+
# The rest of this function ought to be in the top 5 of somebody else's worst hacks.
|
1242 |
+
# Tensor does define __setstate__ even though it doesn't define
|
1243 |
+
# __getstate__. So only use __setstate__ if it is NOT the one defined
|
1244 |
+
# on Tensor
|
1245 |
+
if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
|
1246 |
+
ret.__setstate__(state)
|
1247 |
+
else:
|
1248 |
+
ret = torch._utils._set_obj_state(ret, state)
|
1249 |
+
return ret
|
1250 |
+
|
1251 |
+
original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
|
1252 |
+
try:
|
1253 |
+
torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
|
1254 |
+
yield
|
1255 |
+
finally:
|
1256 |
+
torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
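The try/finally above is the usual patch-and-restore shape. A generic, hedged sketch of the same idea (names are illustrative):

from contextlib import contextmanager

@contextmanager
def monkeypatched(owner, attr_name, replacement):
    # Temporarily swap `owner.attr_name` for `replacement`, restoring it on exit
    # even if the body raises.
    original = getattr(owner, attr_name)
    setattr(owner, attr_name, replacement)
    try:
        yield
    finally:
        setattr(owner, attr_name, original)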
|
1257 |
+
|
1258 |
+
def _unshard_using_shared_memory(
|
1259 |
+
self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None
|
1260 |
+
):
|
1261 |
+
"""
|
1262 |
+
This unsharding implementation consists of:
|
1263 |
+
|
1264 |
+
1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
|
1265 |
+
2. Loading 1 shard on the main process as a base unsharded object.
|
1266 |
+
3. Using the sharded tensors in shared memory to populate the base unsharded object.
|
1267 |
+
|
1268 |
+
This implementation is an alternative to a prior implementation that instead loaded
|
1269 |
+
all shards using threads, because that implementation turned out to
|
1270 |
+
be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
|
1271 |
+
The current implementation is slower than the old one in many scenarios,
|
1272 |
+
but is significantly faster in the above-mentioned case (e.g. 30 minutes)
|
1273 |
+
if there are enough CPUs.
|
1274 |
+
|
1275 |
+
We keep the other implementation since this one can be more unreliable,
|
1276 |
+
likely due to its dependence on a large amount of shared memory.
|
1277 |
+
"""
|
1278 |
+
|
1279 |
+
input_dir = Path(input_dir)
|
1280 |
+
skip_keys = skip_keys or set()
|
1281 |
+
|
1282 |
+
shard_filepaths = list(input_dir.glob("rank*.pt"))
|
1283 |
+
world_size = len(shard_filepaths)
|
1284 |
+
if world_size == 0:
|
1285 |
+
raise RuntimeError("No shards found for unsharding")
|
1286 |
+
|
1287 |
+
log.info("Number of shards: %d", world_size)
|
1288 |
+
shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
|
1289 |
+
min_ram_required_estimate_gb = shard_size_gb * world_size
|
1290 |
+
log.info(
|
1291 |
+
"Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
|
1292 |
+
)
|
1293 |
+
|
1294 |
+
log.info("Copying sharded tensors to shared memory using multiple processes")
|
1295 |
+
# Copy sharded data to shared memory using multiple processes, so this process can load
|
1296 |
+
# from memory rather than disk. We spawn a new process instead of forking since shared memory
|
1297 |
+
# appears to get deleted when forked processes end for some reason.
|
1298 |
+
executor = ProcessPoolExecutor(
|
1299 |
+
mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
|
1300 |
+
)
|
1301 |
+
futures = []
|
1302 |
+
for shard_filepath in shard_filepaths:
|
1303 |
+
shard_rank = int(shard_filepath.name[4:-3])
|
1304 |
+
|
1305 |
+
if shard_rank >= world_size:
|
1306 |
+
raise RuntimeError(
|
1307 |
+
f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
|
1308 |
+
)
|
1309 |
+
|
1310 |
+
futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
|
1311 |
+
|
1312 |
+
for f in as_completed(futures):
|
1313 |
+
f.result()
|
1314 |
+
executor.shutdown()
|
1315 |
+
|
1316 |
+
log.info("Loading a shard on the main process to be unsharded state")
|
1317 |
+
with self._patch_sharded_tensor_load():
|
1318 |
+
state = torch.load(shard_filepaths[0], map_location="cpu")
|
1319 |
+
|
1320 |
+
for key in skip_keys:
|
1321 |
+
if key in state:
|
1322 |
+
del state[key]
|
1323 |
+
|
1324 |
+
log.info("Unsharding from %d shards ...", world_size)
|
1325 |
+
return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
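Step 1 above fans the shard files out over spawned worker processes. A stripped-down sketch of that fan-out, where `unshard_one` stands in for `_copy_sharded_data_to_shared_mem`:

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed

def fan_out(shard_paths, unshard_one):
    # Spawn (rather than fork) so the shared-memory blocks outlive the workers.
    executor = ProcessPoolExecutor(mp_context=mp.get_context("spawn"))
    futures = [executor.submit(unshard_one, path) for path in shard_paths]
    for future in as_completed(futures):
        future.result()  # re-raise any worker exception in the parent
    executor.shutdown()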
|
1326 |
+
|
1327 |
+
def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
|
1328 |
+
if self.use_shared_mem_impl:
|
1329 |
+
return self._unshard_using_shared_memory(input_dir, device, skip_keys)
|
1330 |
+
|
1331 |
+
input_dir = Path(input_dir)
|
1332 |
+
skip_keys = skip_keys or set()
|
1333 |
+
|
1334 |
+
with self._patch_sharded_tensor_load():
|
1335 |
+
# We load in threads because it's faster.
|
1336 |
+
executor = ThreadPoolExecutor()
|
1337 |
+
shards_dict = {}
|
1338 |
+
for shard_name in input_dir.glob("rank*.pt"):
|
1339 |
+
log.info("Loading %s ...", shard_name)
|
1340 |
+
shard_number = int(shard_name.name[4:-3]) # shard names look like "rankXX.pt"
|
1341 |
+
shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
|
1342 |
+
shards = [None] * len(shards_dict)
|
1343 |
+
for rank, shard_future in shards_dict.items():
|
1344 |
+
shard = shard_future.result()
|
1345 |
+
for key in skip_keys:
|
1346 |
+
if key in shard:
|
1347 |
+
del shard[key]
|
1348 |
+
shards[rank] = shard
|
1349 |
+
assert all(shard is not None for shard in shards)
|
1350 |
+
executor.shutdown()
|
1351 |
+
del shards_dict
|
1352 |
+
|
1353 |
+
log.info("Unsharding from %d shards ...", len(shards))
|
1354 |
+
|
1355 |
+
unsharded_state_dict = self._unshard_object(shards, device=device)
|
1356 |
+
# At this point in time we need 2x memory :-(
|
1357 |
+
del shards
|
1358 |
+
|
1359 |
+
return unsharded_state_dict
|
1360 |
+
|
1361 |
+
def _unshard_object(self, os: List[Any], device: torch.device) -> Any:
|
1362 |
+
rank0_item = os[0]
|
1363 |
+
assert all(type(o) is type(rank0_item) for o in os)
|
1364 |
+
if isinstance(rank0_item, str):
|
1365 |
+
assert all(o == rank0_item for o in os)
|
1366 |
+
return rank0_item
|
1367 |
+
elif isinstance(rank0_item, (list, tuple, set)):
|
1368 |
+
assert all(len(o) == len(rank0_item) for o in os)
|
1369 |
+
return rank0_item.__class__(self._unshard_object(o, device=device) for o in zip(*os))
|
1370 |
+
elif isinstance(rank0_item, dict):
|
1371 |
+
assert all(o.keys() == rank0_item.keys() for o in os)
|
1372 |
+
return {key: self._unshard_object([o[key] for o in os], device=device) for key in rank0_item.keys()}
|
1373 |
+
elif isinstance(rank0_item, ShardedTensor):
|
1374 |
+
return self._gather(os, device=device)
|
1375 |
+
else:
|
1376 |
+
assert all(self._objects_are_equal(o, rank0_item) for o in os)
|
1377 |
+
return rank0_item
|
1378 |
+
|
1379 |
+
def _gather(self, shards: List[ShardedTensor], device: torch.device) -> torch.Tensor:
|
1380 |
+
world_size = len(shards)
|
1381 |
+
shard0_md = shards[0].metadata()
|
1382 |
+
# Make sure all shards agree on the metadata
|
1383 |
+
assert all(shard.metadata() == shard0_md for shard in shards)
|
1384 |
+
# Make sure the nth shard expects to be the nth shard.
|
1385 |
+
assert all(
|
1386 |
+
shard_md.placement.rank() == rank # type: ignore
|
1387 |
+
for rank, shard_md in enumerate(shard0_md.shards_metadata)
|
1388 |
+
)
|
1389 |
+
|
1390 |
+
def shard_size(shard_md):
|
1391 |
+
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
|
1392 |
+
|
1393 |
+
rank_sizes = [0 for _ in range(world_size)]
|
1394 |
+
max_rank_size = 0
|
1395 |
+
shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
|
1396 |
+
for shard_md in shard0_md.shards_metadata:
|
1397 |
+
shard_rank = cast(_remote_device, shard_md.placement).rank()
|
1398 |
+
assert shard_rank is not None
|
1399 |
+
|
1400 |
+
shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
|
1401 |
+
rank_sizes[shard_rank] += shard_size(shard_md)
|
1402 |
+
max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
|
1403 |
+
|
1404 |
+
gather_list: List[torch.Tensor] = [torch.empty((max_rank_size,)) for _ in range(world_size)]
|
1405 |
+
|
1406 |
+
datas = []
|
1407 |
+
with torch.no_grad():
|
1408 |
+
for shard in shards:
|
1409 |
+
data = torch.empty(max_rank_size)
|
1410 |
+
|
1411 |
+
for local_shard in shard.local_shards():
|
1412 |
+
src = local_shard.tensor.flatten()
|
1413 |
+
shard_offset = shard_placement[local_shard.metadata][1]
|
1414 |
+
data[shard_offset : shard_offset + src.numel()].copy_(src)
|
1415 |
+
|
1416 |
+
datas.append(data)
|
1417 |
+
|
1418 |
+
# torch.gather in a nutshell
|
1419 |
+
for rank, data in enumerate(datas):
|
1420 |
+
gather_list[rank].copy_(data)
|
1421 |
+
|
1422 |
+
full_size = shard0_md.size
|
1423 |
+
out = torch.empty(*full_size, dtype=shard0_md.tensor_properties.dtype, device=device)
|
1424 |
+
dims = len(full_size)
|
1425 |
+
for shard_md in shard0_md.shards_metadata:
|
1426 |
+
rank, rank_offset = shard_placement[shard_md]
|
1427 |
+
tensor = gather_list[rank]
|
1428 |
+
tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
|
1429 |
+
tensor = tensor.view(shard_md.shard_sizes)
|
1430 |
+
|
1431 |
+
out_narrow_view = out
|
1432 |
+
for dim in range(dims):
|
1433 |
+
out_narrow_view = out_narrow_view.narrow(
|
1434 |
+
dim,
|
1435 |
+
shard_md.shard_offsets[dim],
|
1436 |
+
shard_md.shard_sizes[dim],
|
1437 |
+
)
|
1438 |
+
|
1439 |
+
out_narrow_view.copy_(tensor)
|
1440 |
+
|
1441 |
+
return out
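The placement loop above writes each shard into the full tensor through a chain of `narrow()` views, one per dimension. A small self-contained example of that narrow-then-copy step (shapes and offsets are made up):

import torch

full = torch.zeros(4, 6)
shard = torch.ones(2, 3)
offsets = (2, 3)                       # this shard covers rows 2-3, cols 3-5
view = full
for dim, (offset, size) in enumerate(zip(offsets, shard.shape)):
    view = view.narrow(dim, offset, size)
view.copy_(shard)                      # full[2:4, 3:6] is now all ones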
|
1442 |
+
|
1443 |
+
def _objects_are_equal(self, a: Any, b: Any) -> bool:
|
1444 |
+
if type(a) is not type(b):
|
1445 |
+
return False
|
1446 |
+
if isinstance(a, np.ndarray):
|
1447 |
+
return np.array_equal(a, b)
|
1448 |
+
elif isinstance(a, torch.Tensor):
|
1449 |
+
return torch.equal(a, b)
|
1450 |
+
else:
|
1451 |
+
return a == b
|
1452 |
+
|
1453 |
+
|
1454 |
+
@dataclass
|
1455 |
+
class _LocalShardedCheckpointerMetadata(BaseConfig):
|
1456 |
+
world_size: int = field(default_factory=get_world_size)
|
1457 |
+
|
1458 |
+
|
1459 |
+
@dataclass
|
1460 |
+
class _FlatParamShard:
|
1461 |
+
full_shape: torch.Size
|
1462 |
+
shard_offsets: Tuple[int, int]
|
1463 |
+
shard_data: Optional[torch.Tensor]
|
1464 |
+
|
1465 |
+
def copy_into(self, full_tensor: torch.Tensor) -> None:
|
1466 |
+
assert self.shard_data is not None
|
1467 |
+
full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
|
1468 |
+
assert self.shard_data.shape == full_tensor_shard_view.shape
|
1469 |
+
full_tensor_shard_view.copy_(self.shard_data)
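Note that `shard_offsets` is an inclusive (start, end) pair into the flattened full tensor, hence the `+ 1` above. A tiny example with made-up values:

import torch

full = torch.zeros(2, 4)               # flattened length 8
shard_offsets = (3, 5)                 # inclusive start/end into the flat view
shard_data = torch.tensor([1.0, 2.0, 3.0])
full.view(-1)[shard_offsets[0] : shard_offsets[1] + 1].copy_(shard_data)
# full is now [[0, 0, 0, 1], [2, 3, 0, 0]]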
|
1470 |
+
|
1471 |
+
|
1472 |
+
class LocalShardedCheckpointer(Checkpointer):
|
1473 |
+
"""
|
1474 |
+
A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
|
1475 |
+
The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
|
1476 |
+
|
1477 |
+
The world size must be kept consistent when using this checkpointer. However, you can easily
|
1478 |
+
reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
|
1479 |
+
using :meth:`unshard_checkpoint()` (no distributed initialization required).
|
1480 |
+
"""
|
1481 |
+
|
1482 |
+
# These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
|
1483 |
+
_FLAT_PARAM_METADATA_TO_SAVE = (
|
1484 |
+
"_fqns",
|
1485 |
+
"_shard_param_offsets",
|
1486 |
+
"_shard_indices",
|
1487 |
+
"_numels",
|
1488 |
+
"_numels_with_padding",
|
1489 |
+
"_shapes",
|
1490 |
+
"_shard_numel_padded",
|
1491 |
+
"_shard_param_infos",
|
1492 |
+
)
|
1493 |
+
|
1494 |
+
def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
|
1495 |
+
"""
|
1496 |
+
Returns a list of FSDP modules with their FQN.
|
1497 |
+
"""
|
1498 |
+
modules = []
|
1499 |
+
for name, module in fsdp_model.named_modules():
|
1500 |
+
if isinstance(module, FSDP):
|
1501 |
+
modules.append((name, module))
|
1502 |
+
return modules
|
1503 |
+
|
1504 |
+
def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
|
1505 |
+
from torch.distributed.fsdp._runtime_utils import _lazy_init
|
1506 |
+
|
1507 |
+
# TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
|
1508 |
+
# an FSDP state dict through the built-in methods.
|
1509 |
+
if torch.cuda.is_available():
|
1510 |
+
torch.cuda.synchronize()
|
1511 |
+
_lazy_init(fsdp_model, fsdp_model)
|
1512 |
+
|
1513 |
+
def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
|
1514 |
+
if version.parse(torch.__version__) < version.parse("2.1.0"):
|
1515 |
+
return fsdp_model._handles # type: ignore
|
1516 |
+
elif version.parse(torch.__version__) < version.parse("2.3.0"):
|
1517 |
+
# Handle could be None if the FSDP wrapper doesn't manage any parameters.
|
1518 |
+
if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
|
1519 |
+
return [fsdp_model._handle] # type: ignore
|
1520 |
+
else:
|
1521 |
+
return []
|
1522 |
+
else:
|
1523 |
+
# Need to verify FSDP internals with newer versions.
|
1524 |
+
raise NotImplementedError
|
1525 |
+
|
1526 |
+
@torch.no_grad()
|
1527 |
+
def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
|
1528 |
+
self._prepare_fsdp_model(fsdp_model)
|
1529 |
+
module_data = []
|
1530 |
+
for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
|
1531 |
+
handle_data = []
|
1532 |
+
for handle in self._fsdp_handles(fsdp_module):
|
1533 |
+
data: Dict[str, Any] = {}
|
1534 |
+
# This is a `FlatParameter` instance.
|
1535 |
+
# See `torch.distributed.fsdp.flat_param` for the API.
|
1536 |
+
flat_param = handle.flat_param
|
1537 |
+
data["flat_param.data"] = flat_param.detach()
|
1538 |
+
for key in self._FLAT_PARAM_METADATA_TO_SAVE:
|
1539 |
+
if hasattr(flat_param, key):
|
1540 |
+
data[f"flat_param.{key}"] = getattr(flat_param, key)
|
1541 |
+
handle_data.append(data)
|
1542 |
+
module_data.append({"handles": handle_data, "name": module_fqn})
|
1543 |
+
return {"modules": module_data}
|
1544 |
+
|
1545 |
+
@torch.no_grad()
|
1546 |
+
def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
|
1547 |
+
"""Load the state produced from `self._get_flat_param_state_to_save()`."""
|
1548 |
+
self._prepare_fsdp_model(fsdp_model)
|
1549 |
+
fsdp_modules = self._fsdp_modules(fsdp_model)
|
1550 |
+
assert len(model_state["modules"]) == len(fsdp_modules)
|
1551 |
+
for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
|
1552 |
+
handles = self._fsdp_handles(fsdp_module)
|
1553 |
+
assert len(handles) == len(module_data["handles"])
|
1554 |
+
for handle, data in zip(handles, module_data["handles"]):
|
1555 |
+
flat_param = handle.flat_param
|
1556 |
+
# Make sure metadata matches.
|
1557 |
+
for key in self._FLAT_PARAM_METADATA_TO_SAVE:
|
1558 |
+
if hasattr(flat_param, key):
|
1559 |
+
assert getattr(flat_param, key) == data[f"flat_param.{key}"]
|
1560 |
+
# Load the flat sharded data.
|
1561 |
+
flat_param.copy_(data["flat_param.data"])
|
1562 |
+
|
1563 |
+
def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
|
1564 |
+
if get_fs_local_rank() == 0:
|
1565 |
+
log.info("Saving metadata...")
|
1566 |
+
metadata = _LocalShardedCheckpointerMetadata()
|
1567 |
+
metadata.save(metadata_path := Path(dir) / "metadata.yaml")
|
1568 |
+
if upload_to is not None and get_global_rank() == 0:
|
1569 |
+
upload_target = f"{upload_to}/metadata.yaml"
|
1570 |
+
log.info(f"Uploading {metadata_path} to {upload_target}")
|
1571 |
+
upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
|
1572 |
+
|
1573 |
+
def _load_metadata(
|
1574 |
+
self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
|
1575 |
+
) -> _LocalShardedCheckpointerMetadata:
|
1576 |
+
metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
|
1577 |
+
return _LocalShardedCheckpointerMetadata.load(metadata_path)
|
1578 |
+
|
1579 |
+
def save_checkpoint(
|
1580 |
+
self,
|
1581 |
+
dir: PathOrStr,
|
1582 |
+
dist_model: nn.Module,
|
1583 |
+
optim: Optimizer,
|
1584 |
+
trainer_state: Dict[str, Any],
|
1585 |
+
*,
|
1586 |
+
upload_to: Optional[str] = None,
|
1587 |
+
) -> None:
|
1588 |
+
assert isinstance(
|
1589 |
+
dist_model, FSDP
|
1590 |
+
), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
|
1591 |
+
|
1592 |
+
with self._temporary_wd(dir) as checkpoint_dir:
|
1593 |
+
# Gather local FSDP flat params data to save.
|
1594 |
+
# We also save some flat param metadata like the corresponding fully qualified names (fqns)
|
1595 |
+
# of each original parameter so we can validate that the sharding is the same when loading
|
1596 |
+
# one of these checkpoints.
|
1597 |
+
log.info("Saving local FSDP flat params data...")
|
1598 |
+
save_state_dict(
|
1599 |
+
checkpoint_dir,
|
1600 |
+
f"model/rank{get_global_rank()}.pt",
|
1601 |
+
self._get_flat_param_state_to_save(dist_model),
|
1602 |
+
upload_to=upload_to,
|
1603 |
+
save_overwrite=self.cfg.save_overwrite,
|
1604 |
+
)
|
1605 |
+
|
1606 |
+
# Save optimizer state.
|
1607 |
+
log.info("Saving local optimizer state...")
|
1608 |
+
save_state_dict(
|
1609 |
+
checkpoint_dir,
|
1610 |
+
f"optim/rank{get_global_rank()}.pt",
|
1611 |
+
optim.state_dict(),
|
1612 |
+
upload_to=upload_to,
|
1613 |
+
save_overwrite=self.cfg.save_overwrite,
|
1614 |
+
)
|
1615 |
+
|
1616 |
+
# Save trainer state.
|
1617 |
+
log.info("Saving trainer state...")
|
1618 |
+
save_state_dict(
|
1619 |
+
checkpoint_dir,
|
1620 |
+
f"train/rank{get_global_rank()}.pt",
|
1621 |
+
trainer_state,
|
1622 |
+
upload_to=upload_to,
|
1623 |
+
save_overwrite=self.cfg.save_overwrite,
|
1624 |
+
)
|
1625 |
+
|
1626 |
+
# Save metadata.
|
1627 |
+
self._save_metadata(checkpoint_dir, upload_to=upload_to)
|
1628 |
+
|
1629 |
+
# Save config. We do this last b/c the presence of a config in a remote checkpoint
|
1630 |
+
# "directory" indicates that the folder is valid, as a opposed to a partially
|
1631 |
+
# uploaded checkpoint directory that failed before completing.
|
1632 |
+
self._save_config(checkpoint_dir, upload_to=upload_to)
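Taken together, the calls above produce one file per rank under `model/`, `optim/` and `train/`, plus `metadata.yaml` and the config written by `_save_config`. A quick sketch of the file names for a hypothetical world size of 4:

world_size = 4  # illustrative; the real value comes from the process group
files = [
    f"{kind}/rank{rank}.pt"
    for kind in ("model", "optim", "train")
    for rank in range(world_size)
] + ["metadata.yaml"]  # the config file saved by _save_config is not listed here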
|
1633 |
+
|
1634 |
+
def restore_checkpoint(
|
1635 |
+
self,
|
1636 |
+
load_path: PathOrStr,
|
1637 |
+
dist_model: nn.Module,
|
1638 |
+
optim: Optimizer,
|
1639 |
+
*,
|
1640 |
+
local_cache: Optional[PathOrStr] = None,
|
1641 |
+
load_optimizer_state: bool = True,
|
1642 |
+
) -> Dict[str, Any]:
|
1643 |
+
# Load metadata and make sure checkpoint is compatible.
|
1644 |
+
metadata = self._load_metadata(load_path, local_cache=local_cache)
|
1645 |
+
assert metadata.world_size == get_world_size()
|
1646 |
+
|
1647 |
+
# Load local FSDP flat param data.
|
1648 |
+
log.info("Loading local FSDP flat params data...")
|
1649 |
+
assert isinstance(
|
1650 |
+
dist_model, FSDP
|
1651 |
+
), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
|
1652 |
+
|
1653 |
+
model_state = load_state_dict(
|
1654 |
+
load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
|
1655 |
+
)
|
1656 |
+
self._load_flat_param_state(dist_model, model_state)
|
1657 |
+
del model_state
|
1658 |
+
|
1659 |
+
# Load local optim state.
|
1660 |
+
if load_optimizer_state:
|
1661 |
+
log.info("Loading local optimizer state...")
|
1662 |
+
optim_state = load_state_dict(
|
1663 |
+
load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
|
1664 |
+
)
|
1665 |
+
# HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
|
1666 |
+
# in every rank, and keep this in the optimizer state. But this causes issues when loading the
|
1667 |
+
# state since torch sees the state is non-empty for some params which would normally be empty,
|
1668 |
+
# and then assumes it should have all of the other state tensors for that param, which it doesn't.
|
1669 |
+
# So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
|
1670 |
+
# Not the end of the world but there's probably a better way around this without resetting
|
1671 |
+
# the metric.
|
1672 |
+
for param_id in list(optim_state["state"].keys()):
|
1673 |
+
state = optim_state["state"][param_id]
|
1674 |
+
if "grad_norm_exp_avg" in state:
|
1675 |
+
del state["grad_norm_exp_avg"]
|
1676 |
+
if len(state) == 0:
|
1677 |
+
del optim_state["state"][param_id]
|
1678 |
+
optim.load_state_dict(optim_state)
|
1679 |
+
del optim_state
|
1680 |
+
|
1681 |
+
# Load local trainer state.
|
1682 |
+
log.info("Loading local trainer state...")
|
1683 |
+
trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
|
1684 |
+
barrier()
|
1685 |
+
return trainer_state
|
1686 |
+
|
1687 |
+
def _iter_flat_param_shards(
|
1688 |
+
self, model_state: Dict[str, Any]
|
1689 |
+
) -> Generator[Tuple[str, _FlatParamShard], None, None]:
|
1690 |
+
for module_data in model_state["modules"]:
|
1691 |
+
module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
|
1692 |
+
for handle in module_data["handles"]:
|
1693 |
+
flat_data = handle["flat_param.data"]
|
1694 |
+
if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
|
1695 |
+
# If there's padding in the flat param it should be on the right.
|
1696 |
+
assert (flat_data[-num_padding:] == 0).all()
|
1697 |
+
# NOTE: this changes depending on the torch version, but we don't do a version
|
1698 |
+
# check since we might be trying to unshard an old checkpoint that was stored
|
1699 |
+
# with a different torch version than we're currently running with.
|
1700 |
+
if "flat_param._shard_indices" in handle:
|
1701 |
+
# torch <=2.0.1
|
1702 |
+
param_start = handle["flat_param._shard_indices"][0]
|
1703 |
+
current_flat_index = 0
|
1704 |
+
for relative_fqn, full_shape, (offset_start, offset_end) in zip(
|
1705 |
+
handle["flat_param._fqns"][param_start:],
|
1706 |
+
handle["flat_param._shapes"][param_start:],
|
1707 |
+
handle["flat_param._shard_param_offsets"],
|
1708 |
+
):
|
1709 |
+
root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
|
1710 |
+
numel_shard = offset_end - offset_start + 1
|
1711 |
+
flat_param_shard = _FlatParamShard(
|
1712 |
+
full_shape=full_shape,
|
1713 |
+
shard_offsets=(offset_start, offset_end),
|
1714 |
+
shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
|
1715 |
+
)
|
1716 |
+
current_flat_index += numel_shard
|
1717 |
+
yield root_fqn, flat_param_shard
|
1718 |
+
else:
|
1719 |
+
# torch >=2.1.0
|
1720 |
+
for relative_fqn, full_shape, shard_param_info in zip(
|
1721 |
+
handle["flat_param._fqns"],
|
1722 |
+
handle["flat_param._shapes"],
|
1723 |
+
handle["flat_param._shard_param_infos"],
|
1724 |
+
):
|
1725 |
+
if not shard_param_info.in_shard:
|
1726 |
+
continue
|
1727 |
+
root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
|
1728 |
+
flat_param_shard = _FlatParamShard(
|
1729 |
+
full_shape=full_shape,
|
1730 |
+
shard_offsets=(
|
1731 |
+
shard_param_info.intra_param_start_idx,
|
1732 |
+
shard_param_info.intra_param_end_idx,
|
1733 |
+
),
|
1734 |
+
shard_data=flat_data[
|
1735 |
+
shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
|
1736 |
+
+ shard_param_info.numel_in_shard
|
1737 |
+
],
|
1738 |
+
)
|
1739 |
+
yield root_fqn, flat_param_shard
|
1740 |
+
|
1741 |
+
def unshard_checkpoint(
|
1742 |
+
self,
|
1743 |
+
load_path: PathOrStr,
|
1744 |
+
*,
|
1745 |
+
local_cache: Optional[PathOrStr] = None,
|
1746 |
+
load_optimizer_state: bool = True,
|
1747 |
+
load_trainer_state: bool = True,
|
1748 |
+
device: Optional[torch.device] = None,
|
1749 |
+
) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
1750 |
+
device = device or torch.device("cpu")
|
1751 |
+
metadata = self._load_metadata(load_path, local_cache=local_cache)
|
1752 |
+
|
1753 |
+
# Gather the paths to each rank's model state file, potentially downloading them.
|
1754 |
+
log.info("Gathering model state dicts...")
|
1755 |
+
model_state_paths = self._gather_state_dict_paths(
|
1756 |
+
load_path, "model", metadata.world_size, local_cache=local_cache
|
1757 |
+
)
|
1758 |
+
|
1759 |
+
# Load model state dicts one-by-one, materializing and populating the full parameters as we go.
|
1760 |
+
log.info("Materializing full parameters...")
|
1761 |
+
full_model_state: Dict[str, torch.Tensor] = {}
|
1762 |
+
# We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
|
1763 |
+
# the full optimizer state below without having to reload the model state dicts.
|
1764 |
+
flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
|
1765 |
+
for rank, path in enumerate(model_state_paths):
|
1766 |
+
log.info(f"Loading shards from rank {rank}...")
|
1767 |
+
model_state = torch.load(path, map_location="cpu")
|
1768 |
+
for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
|
1769 |
+
if root_fqn not in full_model_state:
|
1770 |
+
log.info(
|
1771 |
+
f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
|
1772 |
+
)
|
1773 |
+
assert flat_param_shard.shard_data is not None
|
1774 |
+
full_model_state[root_fqn] = torch.empty(
|
1775 |
+
flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
|
1776 |
+
)
|
1777 |
+
# Fill with NaNs so we can validate that the whole parameter has been populated
|
1778 |
+
# afterwards.
|
1779 |
+
full_model_state[root_fqn].fill_(torch.nan)
|
1780 |
+
# Copy over the local shard to the relevant part of the full parameter.
|
1781 |
+
full_param = full_model_state[root_fqn]
|
1782 |
+
log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
|
1783 |
+
flat_param_shard.copy_into(full_param)
|
1784 |
+
flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
|
1785 |
+
|
1786 |
+
log.info("Validating full parameters...")
|
1787 |
+
for key, tensor in full_model_state.items():
|
1788 |
+
if torch.isnan(tensor).any():
|
1789 |
+
raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
|
1790 |
+
|
1791 |
+
trainer_state: Optional[Dict[str, Any]] = None
|
1792 |
+
if load_trainer_state:
|
1793 |
+
trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
|
1794 |
+
|
1795 |
+
if not load_optimizer_state:
|
1796 |
+
return full_model_state, None, trainer_state
|
1797 |
+
|
1798 |
+
log.info("Gathering optim state dicts...")
|
1799 |
+
optim_state_paths = self._gather_state_dict_paths(
|
1800 |
+
load_path, "optim", metadata.world_size, local_cache=local_cache
|
1801 |
+
)
|
1802 |
+
|
1803 |
+
log.info("Materializing full optim state...")
|
1804 |
+
full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
|
1805 |
+
fqn_to_id: Dict[str, int] = {}
|
1806 |
+
id_to_fqn: Dict[int, str] = {}
|
1807 |
+
for rank, path in enumerate(optim_state_paths):
|
1808 |
+
log.info(f"Loading sharded optim state from rank {rank}...")
|
1809 |
+
optim_state = torch.load(path, map_location="cpu")
|
1810 |
+
|
1811 |
+
# Initialize param groups.
|
1812 |
+
# We assume parameter groups are the same across all ranks.
|
1813 |
+
# The only thing that differs across ranks is the state for each local sharded param.
|
1814 |
+
if "param_groups" not in full_optim_state:
|
1815 |
+
full_optim_state["param_groups"] = optim_state["param_groups"]
|
1816 |
+
else:
|
1817 |
+
assert full_optim_state["param_groups"] == optim_state["param_groups"]
|
1818 |
+
|
1819 |
+
# Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
|
1820 |
+
if not fqn_to_id or not id_to_fqn:
|
1821 |
+
for group in full_optim_state["param_groups"]:
|
1822 |
+
for fqn, id in zip(group["param_names"], group["params"]):
|
1823 |
+
fqn = fqn.replace("_fsdp_wrapped_module.", "")
|
1824 |
+
fqn_to_id[fqn] = id
|
1825 |
+
id_to_fqn[id] = fqn
|
1826 |
+
|
1827 |
+
# Iterate over local shard state and copy into the full state.
|
1828 |
+
for id, shard_state in optim_state["state"].items():
|
1829 |
+
fqn = id_to_fqn[id]
|
1830 |
+
flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
|
1831 |
+
full_state = full_optim_state["state"][id]
|
1832 |
+
for key, shard_value in shard_state.items():
|
1833 |
+
assert isinstance(shard_value, torch.Tensor)
|
1834 |
+
if shard_value.shape == torch.Size([]):
|
1835 |
+
# Add singleton tensors directly to full state. These should be the same across
|
1836 |
+
# all ranks.
|
1837 |
+
assert key in ("step", "grad_norm_exp_avg") # sanity check
|
1838 |
+
if key not in full_state:
|
1839 |
+
full_state[key] = shard_value.to(device)
|
1840 |
+
else:
|
1841 |
+
assert full_state[key] == shard_value
|
1842 |
+
else:
|
1843 |
+
# Otherwise we have a sharded param state.
|
1844 |
+
# If the corresponding full param state hasn't been materialized yet, do so now.
|
1845 |
+
assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
|
1846 |
+
if key not in full_state:
|
1847 |
+
log.info(
|
1848 |
+
f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
|
1849 |
+
)
|
1850 |
+
full_state[key] = torch.empty(
|
1851 |
+
flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
|
1852 |
+
)
|
1853 |
+
full_state_value = full_state[key]
|
1854 |
+
|
1855 |
+
# Copy over the local shard state to the relevant part of the full parameter state.
|
1856 |
+
log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
|
1857 |
+
replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
|
1858 |
+
|
1859 |
+
# Lastly, clean up the parameter names in param groups.
|
1860 |
+
for group in full_optim_state["param_groups"]:
|
1861 |
+
group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
|
1862 |
+
|
1863 |
+
return full_model_state, full_optim_state, trainer_state
|
1864 |
+
|
1865 |
+
def _get_state_dict_path(
|
1866 |
+
self,
|
1867 |
+
load_path: PathOrStr,
|
1868 |
+
state_dict_type: str,
|
1869 |
+
rank: int,
|
1870 |
+
*,
|
1871 |
+
local_cache: Optional[PathOrStr] = None,
|
1872 |
+
progress=None,
|
1873 |
+
) -> Tuple[int, Path]:
|
1874 |
+
fname = f"{state_dict_type}/rank{rank}.pt"
|
1875 |
+
return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
|
1876 |
+
|
1877 |
+
def _gather_state_dict_paths(
|
1878 |
+
self,
|
1879 |
+
load_path: PathOrStr,
|
1880 |
+
state_dict_type: str,
|
1881 |
+
world_size: int,
|
1882 |
+
*,
|
1883 |
+
local_cache: Optional[PathOrStr] = None,
|
1884 |
+
) -> List[Path]:
|
1885 |
+
progress = get_progress_bar()
|
1886 |
+
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
|
1887 |
+
futures = []
|
1888 |
+
for rank in range(world_size):
|
1889 |
+
future = executor.submit(
|
1890 |
+
self._get_state_dict_path,
|
1891 |
+
load_path,
|
1892 |
+
state_dict_type,
|
1893 |
+
rank,
|
1894 |
+
local_cache=local_cache,
|
1895 |
+
progress=progress,
|
1896 |
+
)
|
1897 |
+
futures.append(future)
|
1898 |
+
|
1899 |
+
results: Dict[int, Path] = {}
|
1900 |
+
for future in as_completed(futures):
|
1901 |
+
rank, path = future.result()
|
1902 |
+
results[rank] = path
|
1903 |
+
|
1904 |
+
return [results[rank] for rank in range(world_size)]
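This is a common fan-out/fan-in shape: submit one download per rank to a thread pool, collect results as they complete, and return them in rank order. A generic sketch, with `fetch_one` standing in for `_get_state_dict_path`:

from concurrent.futures import ThreadPoolExecutor, as_completed

def gather_in_rank_order(world_size, fetch_one):
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(fetch_one, rank): rank for rank in range(world_size)}
        results = {}
        for future in as_completed(futures):
            results[futures[future]] = future.result()
    return [results[rank] for rank in range(world_size)]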
|
1905 |
+
|
1906 |
+
|
1907 |
+
class OlmoCoreCheckpointer(Checkpointer):
|
1908 |
+
def save_checkpoint(
|
1909 |
+
self,
|
1910 |
+
dir: PathOrStr,
|
1911 |
+
dist_model: nn.Module,
|
1912 |
+
optim: Optimizer,
|
1913 |
+
trainer_state: Dict[str, Any],
|
1914 |
+
*,
|
1915 |
+
upload_to: Optional[str] = None,
|
1916 |
+
) -> None:
|
1917 |
+
from olmo_core.distributed.checkpoint import ( # type: ignore
|
1918 |
+
save_model_and_optim_state,
|
1919 |
+
)
|
1920 |
+
|
1921 |
+
with self._temporary_wd(dir) as checkpoint_dir:
|
1922 |
+
log.info("Saving model and optim state...")
|
1923 |
+
if get_fs_local_rank() == 0:
|
1924 |
+
(checkpoint_dir / "model").mkdir(exist_ok=True, parents=True)
|
1925 |
+
(checkpoint_dir / "optim").mkdir(exist_ok=True, parents=True)
|
1926 |
+
(checkpoint_dir / "train").mkdir(exist_ok=True, parents=True)
|
1927 |
+
|
1928 |
+
wait_for(
|
1929 |
+
lambda: (checkpoint_dir / "model").exists(), "Waiting for checkpoint model directory", timeout=10.0
|
1930 |
+
)
|
1931 |
+
wait_for(
|
1932 |
+
lambda: (checkpoint_dir / "optim").exists(), "Waiting for checkpoint optim directory", timeout=10.0
|
1933 |
+
)
|
1934 |
+
wait_for(
|
1935 |
+
lambda: (checkpoint_dir / "train").exists(), "Waiting for checkpoint train directory", timeout=10.0
|
1936 |
+
)
|
1937 |
+
|
1938 |
+
local_files_created = save_model_and_optim_state(checkpoint_dir, dist_model, optim)
|
1939 |
+
if upload_to is not None:
|
1940 |
+
for path in local_files_created:
|
1941 |
+
path = Path(path)
|
1942 |
+
upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
|
1943 |
+
log.info(f"Uploading {path} to {upload_target}...")
|
1944 |
+
upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)
|
1945 |
+
|
1946 |
+
log.info("Saving trainer state...")
|
1947 |
+
save_state_dict(
|
1948 |
+
checkpoint_dir,
|
1949 |
+
f"train/rank{get_global_rank()}.pt",
|
1950 |
+
trainer_state,
|
1951 |
+
upload_to=upload_to,
|
1952 |
+
save_overwrite=self.cfg.save_overwrite,
|
1953 |
+
)
|
1954 |
+
|
1955 |
+
self._save_config(checkpoint_dir, upload_to=upload_to)
|
1956 |
+
|
1957 |
+
def restore_checkpoint(
|
1958 |
+
self,
|
1959 |
+
load_path: PathOrStr,
|
1960 |
+
dist_model: nn.Module,
|
1961 |
+
optim: Optimizer,
|
1962 |
+
*,
|
1963 |
+
local_cache: Optional[PathOrStr] = None,
|
1964 |
+
load_optimizer_state: bool = True,
|
1965 |
+
) -> Dict[str, Any]:
|
1966 |
+
from olmo_core.distributed.checkpoint import ( # type: ignore
|
1967 |
+
load_model_and_optim_state,
|
1968 |
+
)
|
1969 |
+
|
1970 |
+
log.info("Loading model and optim state...")
|
1971 |
+
load_model_and_optim_state(load_path, dist_model, optim if load_optimizer_state else None)
|
1972 |
+
|
1973 |
+
log.info("Loading trainer state...")
|
1974 |
+
try:
|
1975 |
+
trainer_state = load_state_dict(
|
1976 |
+
load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
|
1977 |
+
)
|
1978 |
+
except FileNotFoundError:
|
1979 |
+
# Fall back to rank 0 train state.
|
1980 |
+
# This can happen when we're restoring a checkpoint with a different world size.
|
1981 |
+
trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
|
1982 |
+
|
1983 |
+
barrier()
|
1984 |
+
return trainer_state
|
1985 |
+
|
1986 |
+
def unshard_checkpoint(
|
1987 |
+
self,
|
1988 |
+
load_path: PathOrStr,
|
1989 |
+
*,
|
1990 |
+
local_cache: Optional[PathOrStr] = None,
|
1991 |
+
load_optimizer_state: bool = True,
|
1992 |
+
load_trainer_state: bool = True,
|
1993 |
+
device: Optional[torch.device] = None,
|
1994 |
+
) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
1995 |
+
from olmo_core.distributed.checkpoint import ( # type: ignore
|
1996 |
+
unshard_model_state,
|
1997 |
+
unshard_optim_state,
|
1998 |
+
)
|
1999 |
+
|
2000 |
+
model_state = unshard_model_state(load_path, device=device)
|
2001 |
+
optim_state: Optional[Dict[str, Any]] = None
|
2002 |
+
train_state: Optional[Dict[str, Any]] = None
|
2003 |
+
if load_optimizer_state:
|
2004 |
+
optim_state = cast(Dict[str, Any], unshard_optim_state(load_path, device=device))
|
2005 |
+
if load_trainer_state:
|
2006 |
+
train_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
|
2007 |
+
return model_state, optim_state, train_state
|
2008 |
+
|
2009 |
+
|
2010 |
+
def build_sharded_checkpointer(
|
2011 |
+
cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None, use_shared_mem_impl: bool = False
|
2012 |
+
) -> Checkpointer:
|
2013 |
+
name = name or cfg.sharded_checkpointer
|
2014 |
+
if name == ShardedCheckpointerType.torch_new:
|
2015 |
+
return TorchNewStyleShardedCheckpointer(cfg)
|
2016 |
+
elif name == ShardedCheckpointerType.torch_legacy:
|
2017 |
+
return TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=use_shared_mem_impl)
|
2018 |
+
elif name == ShardedCheckpointerType.local:
|
2019 |
+
return LocalShardedCheckpointer(cfg)
|
2020 |
+
elif name == ShardedCheckpointerType.olmo_core:
|
2021 |
+
return OlmoCoreCheckpointer(cfg)
|
2022 |
+
else:
|
2023 |
+
raise NotImplementedError(name)
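A hypothetical usage sketch of the factory, assuming `cfg` is a `TrainConfig` loaded elsewhere and the checkpoint path is illustrative:

checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)
model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
    "path/to/sharded/checkpoint", load_optimizer_state=False
)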
|
config.json
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "/nfs100/dongyh/FANformer-1B",
|
3 |
"activation_type": "swiglu",
|
4 |
"alibi": false,
|
5 |
"alibi_bias_max": 8.0,
|
6 |
"architectures": [
|
7 |
"OLMoForCausalLM"
|
8 |
],
|
9 |
-
"att_nolinear": false,
|
10 |
"attention_activation": null,
|
11 |
"attention_dropout": 0.0,
|
12 |
"attention_layer_norm": false,
|
@@ -25,7 +23,6 @@
|
|
25 |
"embedding_layer_norm": false,
|
26 |
"embedding_size": 50304,
|
27 |
"eos_token_id": 50279,
|
28 |
-
"ffn_activation": null,
|
29 |
"flash_attention": true,
|
30 |
"include_bias": false,
|
31 |
"init_cutoff_factor": null,
|
@@ -55,17 +52,9 @@
|
|
55 |
"rope_theta": 10000,
|
56 |
"scale_emb_init": false,
|
57 |
"scale_logits": false,
|
58 |
-
"
|
59 |
-
"transformers_version": "4.49.0",
|
60 |
-
"use_A": false,
|
61 |
-
"use_ATF": true,
|
62 |
"use_cache": true,
|
63 |
-
"
|
64 |
-
"use_fpneq": false,
|
65 |
-
"use_fpnnow": false,
|
66 |
-
"use_fpnpn": false,
|
67 |
-
"use_mod": false,
|
68 |
-
"use_mod_ffn": 0,
|
69 |
"vocab_size": 50280,
|
70 |
"weight_tying": true
|
71 |
}
|
|
|
1 |
{
|
|
|
2 |
"activation_type": "swiglu",
|
3 |
"alibi": false,
|
4 |
"alibi_bias_max": 8.0,
|
5 |
"architectures": [
|
6 |
"OLMoForCausalLM"
|
7 |
],
|
|
|
8 |
"attention_activation": null,
|
9 |
"attention_dropout": 0.0,
|
10 |
"attention_layer_norm": false,
|
|
|
23 |
"embedding_layer_norm": false,
|
24 |
"embedding_size": 50304,
|
25 |
"eos_token_id": 50279,
|
|
|
26 |
"flash_attention": true,
|
27 |
"include_bias": false,
|
28 |
"init_cutoff_factor": null,
|
|
|
52 |
"rope_theta": 10000,
|
53 |
"scale_emb_init": false,
|
54 |
"scale_logits": false,
|
55 |
+
"transformers_version": "4.46.0",
|
|
|
|
|
|
|
56 |
"use_cache": true,
|
57 |
+
"use_ATF": true,
|
|
|
|
|
|
|
|
|
|
|
58 |
"vocab_size": 50280,
|
59 |
"weight_tying": true
|
60 |
}
|
config.py
ADDED
@@ -0,0 +1,1371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
from dataclasses import asdict, dataclass, field
|
5 |
+
from glob import glob
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import (
|
8 |
+
Any,
|
9 |
+
Dict,
|
10 |
+
Iterable,
|
11 |
+
List,
|
12 |
+
Optional,
|
13 |
+
Tuple,
|
14 |
+
Type,
|
15 |
+
TypeVar,
|
16 |
+
Union,
|
17 |
+
cast,
|
18 |
+
)
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
from omegaconf import DictConfig, ListConfig
|
23 |
+
from omegaconf import OmegaConf as om
|
24 |
+
from omegaconf.errors import OmegaConfBaseException
|
25 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
26 |
+
|
27 |
+
from .aliases import PathOrStr
|
28 |
+
from .exceptions import OLMoConfigurationError
|
29 |
+
from .util import StrEnum
|
30 |
+
|
31 |
+
__all__ = [
|
32 |
+
"ActivationType",
|
33 |
+
"ActivationCheckpointingStrategy",
|
34 |
+
"BlockType",
|
35 |
+
"LayerNormType",
|
36 |
+
"InitFnType",
|
37 |
+
"ModelConfig",
|
38 |
+
"OptimizerType",
|
39 |
+
"OptimizerConfig",
|
40 |
+
"SchedulerType",
|
41 |
+
"SchedulerConfig",
|
42 |
+
"DataConfig",
|
43 |
+
"InstanceFilterConfig",
|
44 |
+
"EvaluatorConfig",
|
45 |
+
"TokenizerConfig",
|
46 |
+
"TrainConfig",
|
47 |
+
"PaddingDirection",
|
48 |
+
"TruncationDirection",
|
49 |
+
"SpeedMonitorConfig",
|
50 |
+
"WandbConfig",
|
51 |
+
"CompilerConfig",
|
52 |
+
"WandbConfig",
|
53 |
+
"DDPConfig",
|
54 |
+
"DistributedStrategy",
|
55 |
+
"DDPGradSyncMode",
|
56 |
+
"FSDPPrecision",
|
57 |
+
"FSDPWrapStrategy",
|
58 |
+
"FSDPConfig",
|
59 |
+
"SingleGPUConfig",
|
60 |
+
"CheckpointType",
|
61 |
+
]
|
62 |
+
|
63 |
+
C = TypeVar("C", bound="BaseConfig")
|
64 |
+
D = TypeVar("D", bound="DictConfig|ListConfig")
|
65 |
+
|
66 |
+
|
67 |
+
class BaseConfig:
|
68 |
+
@classmethod
|
69 |
+
def _register_resolvers(cls, validate_paths: bool = True):
|
70 |
+
# Expands path globs into a list.
|
71 |
+
def path_glob(*paths) -> List[str]:
|
72 |
+
out = []
|
73 |
+
for path in paths:
|
74 |
+
matches = sorted(glob(path))
|
75 |
+
if not matches and validate_paths:
|
76 |
+
raise FileNotFoundError(f"{path} does not match any files or dirs")
|
77 |
+
out.extend(matches)
|
78 |
+
return out
|
79 |
+
|
80 |
+
# Chooses the first path in the arguments that exists.
|
81 |
+
def path_choose(*paths) -> str:
|
82 |
+
from .util import is_url
|
83 |
+
|
84 |
+
for path in paths:
|
85 |
+
if is_url(path) or Path(path).exists():
|
86 |
+
return path
|
87 |
+
if validate_paths:
|
88 |
+
raise FileNotFoundError(", ".join(paths))
|
89 |
+
else:
|
90 |
+
return ""
|
91 |
+
|
92 |
+
# Finds the latest checkpoint in a folder.
|
93 |
+
def path_last_checkpoint(path) -> str:
|
94 |
+
from .util import find_latest_checkpoint
|
95 |
+
|
96 |
+
latest_checkpoint = find_latest_checkpoint(path)
|
97 |
+
if latest_checkpoint is None:
|
98 |
+
if validate_paths:
|
99 |
+
raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
|
100 |
+
else:
|
101 |
+
return ""
|
102 |
+
else:
|
103 |
+
return str(latest_checkpoint)
|
104 |
+
|
105 |
+
om.register_new_resolver("path.glob", path_glob, replace=True)
|
106 |
+
om.register_new_resolver("path.choose", path_choose, replace=True)
|
107 |
+
om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
|
108 |
+
|
109 |
+
@classmethod
|
110 |
+
def update_legacy_settings(cls, config: D) -> D:
|
111 |
+
"""
|
112 |
+
Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
|
113 |
+
"""
|
114 |
+
return config
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def new(cls: Type[C], **kwargs) -> C:
|
118 |
+
cls._register_resolvers()
|
119 |
+
conf = om.structured(cls)
|
120 |
+
try:
|
121 |
+
if kwargs:
|
122 |
+
conf = om.merge(conf, kwargs)
|
123 |
+
return cast(C, om.to_object(conf))
|
124 |
+
except OmegaConfBaseException as e:
|
125 |
+
raise OLMoConfigurationError(str(e))
|
126 |
+
|
127 |
+
@classmethod
|
128 |
+
def load(
|
129 |
+
cls: Type[C],
|
130 |
+
path: PathOrStr,
|
131 |
+
overrides: Optional[List[str]] = None,
|
132 |
+
key: Optional[str] = None,
|
133 |
+
validate_paths: bool = True,
|
134 |
+
) -> C:
|
135 |
+
"""Load from a YAML file."""
|
136 |
+
cls._register_resolvers(validate_paths=validate_paths)
|
137 |
+
schema = om.structured(cls)
|
138 |
+
try:
|
139 |
+
raw = om.load(str(path))
|
140 |
+
if key is not None:
|
141 |
+
raw = raw[key] # type: ignore
|
142 |
+
raw = cls.update_legacy_settings(raw)
|
143 |
+
conf = om.merge(schema, raw)
|
144 |
+
if overrides:
|
145 |
+
conf = om.merge(conf, om.from_dotlist(overrides))
|
146 |
+
return cast(C, om.to_object(conf))
|
147 |
+
except OmegaConfBaseException as e:
|
148 |
+
raise OLMoConfigurationError(str(e))
|
149 |
+
|
150 |
+
def save(self, path: PathOrStr) -> None:
|
151 |
+
"""Save to a YAML file."""
|
152 |
+
om.save(config=self, f=str(path))
|
153 |
+
|
154 |
+
def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
|
155 |
+
out = asdict(self) # type: ignore
|
156 |
+
if exclude is not None:
|
157 |
+
for name in exclude:
|
158 |
+
if name in out:
|
159 |
+
del out[name]
|
160 |
+
return out
|
161 |
+
|
162 |
+
def update_with(self, **kwargs):
|
163 |
+
result = deepcopy(self)
|
164 |
+
for key, value in kwargs.items():
|
165 |
+
setattr(result, key, value)
|
166 |
+
return result
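A minimal round-trip using the helpers above, assuming the concrete `ModelConfig` dataclass defined later in this file (file name and override values are illustrative):

cfg = ModelConfig.new(d_model=1024, n_heads=16)                      # build with keyword overrides
cfg.save("model.yaml")                                               # serialize to YAML
loaded = ModelConfig.load("model.yaml", overrides=["n_layers=24"])   # apply a dotlist override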
|
167 |
+
|
168 |
+
|
169 |
+
class LayerNormType(StrEnum):
|
170 |
+
default = "default"
|
171 |
+
"""
|
172 |
+
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
173 |
+
"""
|
174 |
+
|
175 |
+
low_precision = "low_precision"
|
176 |
+
"""
|
177 |
+
A low-precision version of the default LayerNorm.
|
178 |
+
"""
|
179 |
+
|
180 |
+
rms = "rms"
|
181 |
+
"""
|
182 |
+
An RMSNorm implementation. When using ``torch.compile`` this is
|
183 |
+
probably the fastest implementation.
|
184 |
+
"""
|
185 |
+
|
186 |
+
|
187 |
+
class ActivationType(StrEnum):
|
188 |
+
gelu = "gelu"
|
189 |
+
relu = "relu"
|
190 |
+
swiglu = "swiglu"
|
191 |
+
|
192 |
+
|
193 |
+
class BlockType(StrEnum):
|
194 |
+
sequential = "sequential"
|
195 |
+
|
196 |
+
llama = "llama"
|
197 |
+
"""
|
198 |
+
A block similar to the sequential block with slightly different
|
199 |
+
implementations of operations like attention to imitate the behavior of Llama.
|
200 |
+
"""
|
201 |
+
|
202 |
+
|
203 |
+
class InitFnType(StrEnum):
|
204 |
+
mitchell = "mitchell"
|
205 |
+
"""
|
206 |
+
The strategy suggested to us by Mitchell Wortsman from UW.
|
207 |
+
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
208 |
+
on the size of the weights as well as the depth of the layer.
|
209 |
+
"""
|
210 |
+
|
211 |
+
normal = "normal"
|
212 |
+
"""
|
213 |
+
All weights are initialized from the same normal distribution.
|
214 |
+
"""
|
215 |
+
|
216 |
+
kaiming_normal = "kaiming_normal"
|
217 |
+
"""
|
218 |
+
All weights are initialized with the Kaiming method from a normal distribution.
|
219 |
+
Note this currently won't work with FSDP.
|
220 |
+
"""
|
221 |
+
|
222 |
+
fan_in = "fan_in"
|
223 |
+
"""
|
224 |
+
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
225 |
+
is the input dimensionality of the kernel.
|
226 |
+
"""
|
227 |
+
|
228 |
+
full_megatron = "full_megatron"
|
229 |
+
"""
|
230 |
+
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
231 |
+
"""
|
232 |
+
|
233 |
+
|
234 |
+
@dataclass
|
235 |
+
class ModelConfig(BaseConfig):
|
236 |
+
"""
|
237 |
+
OLMo (model) configuration.
|
238 |
+
"""
|
239 |
+
|
240 |
+
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
241 |
+
|
242 |
+
d_model: int = 768
|
243 |
+
"""
|
244 |
+
The hidden size of the model.
|
245 |
+
"""
|
246 |
+
|
247 |
+
n_heads: int = 12
|
248 |
+
"""
|
249 |
+
The number of self-attention heads.
|
250 |
+
"""
|
251 |
+
|
252 |
+
n_kv_heads: Optional[int] = None
|
253 |
+
"""
|
254 |
+
The number of heads to use for keys and values. Defaults to `n_heads`.
|
255 |
+
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
256 |
+
Set this to 1 for multi-query attention.
|
257 |
+
Set it to some in-between value for Llama2-style grouped query attention.
|
258 |
+
"""
|
259 |
+
|
260 |
+
clip_qkv: Optional[float] = None
|
261 |
+
"""
|
262 |
+
Clip QKV to this value when set.
|
263 |
+
"""
|
264 |
+
|
265 |
+
n_layers: int = 12
|
266 |
+
"""
|
267 |
+
The number of layers/blocks.
|
268 |
+
"""
|
269 |
+
|
270 |
+
mlp_ratio: int = 4
|
271 |
+
"""
|
272 |
+
The ratio of the inner MLP dimensionality to ``d_model``.
|
273 |
+
This is only used when ``mlp_hidden_size`` is not set.
|
274 |
+
"""
|
275 |
+
|
276 |
+
mlp_hidden_size: Optional[int] = None
|
277 |
+
"""
|
278 |
+
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
279 |
+
"""
|
280 |
+
|
281 |
+
activation_type: ActivationType = ActivationType.swiglu
|
282 |
+
"""
|
283 |
+
The activation function to use within the MLP layers.
|
284 |
+
"""
|
285 |
+
|
286 |
+
block_type: BlockType = BlockType.sequential
|
287 |
+
"""
|
288 |
+
The transformer block implementation.
|
289 |
+
"""
|
290 |
+
|
291 |
+
block_group_size: int = 1
|
292 |
+
"""
|
293 |
+
The number of blocks to group together into a single parent block.
|
294 |
+
This has no effect on the number of parameters in the model and is only used to wrap groups
|
295 |
+
of blocks together with a single FSDP wrapper during training.
|
296 |
+
"""
|
297 |
+
|
298 |
+
alibi: bool = False
|
299 |
+
"""
|
300 |
+
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
301 |
+
"""
|
302 |
+
|
303 |
+
alibi_bias_max: float = 8.0
|
304 |
+
"""
|
305 |
+
Maximum absolute value of ALiBi bias.
|
306 |
+
"""
|
307 |
+
|
308 |
+
rope: bool = False
|
309 |
+
"""
|
310 |
+
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
311 |
+
"""
|
312 |
+
|
313 |
+
rope_full_precision: bool = True
|
314 |
+
"""
|
315 |
+
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
316 |
+
apply RoPE at the precision of the input.
|
317 |
+
"""
|
318 |
+
|
319 |
+
rope_theta: int = 10_000
|
320 |
+
"""
|
321 |
+
The theta setting for RoPE.
|
322 |
+
"""
|
323 |
+
|
324 |
+
flash_attention: bool = False
|
325 |
+
"""
|
326 |
+
If ``True``, use ``FlashAttention``.
|
327 |
+
"""
|
328 |
+
|
329 |
+
attention_dropout: float = 0.1
|
330 |
+
"""
|
331 |
+
The dropout probability within the attention modules.
|
332 |
+
"""
|
333 |
+
|
334 |
+
multi_query_attention: Optional[bool] = None
|
335 |
+
"""
|
336 |
+
Deprecated. Use n_kv_heads instead.
|
337 |
+
"""
|
338 |
+
|
339 |
+
attention_layer_norm: bool = False
|
340 |
+
"""
|
341 |
+
Apply layer norm to the keys and queries within the attention mechanism.
|
342 |
+
This can help stabilize training.
|
343 |
+
"""
|
344 |
+
|
345 |
+
residual_dropout: float = 0.1
|
346 |
+
"""
|
347 |
+
The dropout probability for the MLP and attention output within each block.
|
348 |
+
"""
|
349 |
+
|
350 |
+
embedding_dropout: float = 0.1
|
351 |
+
"""
|
352 |
+
The dropout probability for embeddings.
|
353 |
+
"""
|
354 |
+
|
355 |
+
embedding_layer_norm: bool = False
|
356 |
+
"""
|
357 |
+
Apply layer norm directly to the embeddings.
|
358 |
+
"""
|
359 |
+
|
360 |
+
layer_norm_type: LayerNormType = LayerNormType.default
|
361 |
+
"""
|
362 |
+
The layernorm implementation to use.
|
363 |
+
"""
|
364 |
+
|
365 |
+
layer_norm_with_affine: bool = True
|
366 |
+
"""
|
367 |
+
Whether to include bias and weight parameters for the layer norms.
|
368 |
+
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
369 |
+
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
370 |
+
to ``False``.
|
371 |
+
"""
|
372 |
+
|
373 |
+
layer_norm_eps: float = 1e-05
|
374 |
+
|
375 |
+
attention_layer_norm_with_affine: bool = True
|
376 |
+
"""
|
377 |
+
Toggle affine transform for the QK norms.
|
378 |
+
"""
|
379 |
+
|
380 |
+
max_sequence_length: int = 1024
|
381 |
+
"""
|
382 |
+
The maximum input sequence length supported by the model.
|
383 |
+
"""
|
384 |
+
|
385 |
+
include_bias: bool = True
|
386 |
+
"""
|
387 |
+
Whether or not to include bias parameters in linear layers.
|
388 |
+
In PaLM, they got rid of all bias terms because they found that large
|
389 |
+
models tend to have near 0 bias terms anyway.
|
390 |
+
"""
|
391 |
+
|
392 |
+
bias_for_layer_norm: Optional[bool] = None
|
393 |
+
"""
|
394 |
+
Whether or not to include bias parameters in layer norm.
|
395 |
+
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
396 |
+
layer norm.
|
397 |
+
When this is None (the default), it inherits the setting from include_bias.
|
398 |
+
"""
|
399 |
+
|
400 |
+
scale_logits: bool = False
|
401 |
+
"""
|
402 |
+
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
403 |
+
"""
|
404 |
+
|
405 |
+
vocab_size: int = 50257
|
406 |
+
"""
|
407 |
+
Vocabulary size of the model.
|
408 |
+
"""
|
409 |
+
|
410 |
+
embedding_size: Optional[int] = 50304
|
411 |
+
"""
|
412 |
+
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
413 |
+
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
414 |
+
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
415 |
+
substantially.
|
416 |
+
"""
|
417 |
+
|
418 |
+
weight_tying: bool = True
|
419 |
+
"""
|
420 |
+
Whether to tie output linear weights to the input embedding.
|
421 |
+
"""
|
422 |
+
|
423 |
+
eos_token_id: int = 50256
|
424 |
+
"""
|
425 |
+
The ID of the end-of-sentence special token.
|
426 |
+
"""
|
427 |
+
|
428 |
+
pad_token_id: int = 50256
|
429 |
+
"""
|
430 |
+
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
431 |
+
"""
|
432 |
+
|
433 |
+
init_device: Optional[str] = None
|
434 |
+
"""
|
435 |
+
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
436 |
+
"""
|
437 |
+
|
438 |
+
init_fn: InitFnType = InitFnType.normal
|
439 |
+
"""
|
440 |
+
The weight initialization strategy.
|
441 |
+
"""
|
442 |
+
|
443 |
+
init_std: float = 0.02
|
444 |
+
"""
|
445 |
+
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
446 |
+
as "normal".
|
447 |
+
"""
|
448 |
+
|
449 |
+
init_cutoff_factor: Optional[float] = None
|
450 |
+
"""
|
451 |
+
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
452 |
+
as "normal". Setting this to None means values are not cutoff.
|
453 |
+
"""
|
454 |
+
|
455 |
+
precision: Optional[str] = None
|
456 |
+
"""
|
457 |
+
Precision used to train/evaluate with. You shouldn't set this directly.
|
458 |
+
See :data:`TrainConfig.precision` instead.
|
459 |
+
"""
|
460 |
+
|
461 |
+
scale_emb_init: bool = False
|
462 |
+
"""
|
463 |
+
If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization.
|
464 |
+
Currently this is only used with `full_megatron` init when ``emb_init_std`` is unset.
|
465 |
+
"""
|
466 |
+
|
467 |
+
emb_init_std: Optional[float] = None
|
468 |
+
"""
|
469 |
+
Override the standard deviation to use when initializing the embedding weights.
|
470 |
+
"""
|
471 |
+
|
472 |
+
norm_after: bool = False
|
473 |
+
"""
|
474 |
+
Apply norm after the attention/feedforward layers rather than before, as introduced in the Swin Transformer paper (Liu et al.).
|
475 |
+
"""
|
476 |
+
|
477 |
+
use_ATF: Optional[bool] = False
|
478 |
+
|
479 |
+
p_ratio: float = 0.25
|
480 |
+
|
481 |
+
attention_activation: Optional[str] = None
|
482 |
+
|
483 |
+
@property
|
484 |
+
def effective_n_kv_heads(self) -> int:
|
485 |
+
if self.n_kv_heads is None:
|
486 |
+
if self.multi_query_attention is True:
|
487 |
+
return 1
|
488 |
+
else:
|
489 |
+
return self.n_heads
|
490 |
+
else:
|
491 |
+
if self.multi_query_attention is None:
|
492 |
+
return self.n_kv_heads
|
493 |
+
if self.multi_query_attention:
|
494 |
+
n_kv_heads_should_be = 1
|
495 |
+
else:
|
496 |
+
n_kv_heads_should_be = self.n_heads
|
497 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
498 |
+
return n_kv_heads_should_be
|
499 |
+
else:
|
500 |
+
raise OLMoConfigurationError(
|
501 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
502 |
+
)
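The interaction between ``n_kv_heads`` and the number of query heads is easiest to see with a few concrete configurations. This is a small sketch using the ``ModelConfig`` defined above; the values are illustrative only:

```python
# Multi-head attention: every query head gets its own KV head.
mha = ModelConfig(n_heads=12, n_kv_heads=None)
assert mha.effective_n_kv_heads == 12

# Multi-query attention: all 12 query heads share a single KV head.
mqa = ModelConfig(n_heads=12, n_kv_heads=1)
assert mqa.effective_n_kv_heads == 1

# Grouped-query attention (Llama2-style): 12 query heads share 4 KV heads.
gqa = ModelConfig(n_heads=12, n_kv_heads=4)
assert gqa.effective_n_kv_heads == 4
```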
|
503 |
+
|
504 |
+
|
505 |
+
class OptimizerType(StrEnum):
|
506 |
+
lionw = "lionw"
|
507 |
+
adamw = "adamw"
|
508 |
+
|
509 |
+
|
510 |
+
@dataclass
|
511 |
+
class OptimizerConfig(BaseConfig):
|
512 |
+
name: OptimizerType = OptimizerType.lionw
|
513 |
+
learning_rate: float = 1.0e-4
|
514 |
+
weight_decay: float = 0.01
|
515 |
+
betas: Tuple[float, float] = (0.9, 0.95)
|
516 |
+
eps: float = 1e-5
|
517 |
+
|
518 |
+
no_decay_norm_and_bias: Optional[bool] = None
|
519 |
+
"""
|
520 |
+
Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
|
521 |
+
"""
|
522 |
+
|
523 |
+
selective_updates: bool = False
|
524 |
+
"""
|
525 |
+
If ``True``, optimizer parameter and state updates are skipped when the corresponding gradient is 0.
|
526 |
+
"""
|
527 |
+
|
528 |
+
decay_norm_and_bias: bool = False
|
529 |
+
decay_embeddings: bool = False
|
530 |
+
metrics_log_interval: Optional[int] = None
|
531 |
+
"""
|
532 |
+
The interval with which to collect and log detailed parameter-specific metrics.
|
533 |
+
This only applies when logging to W&B, since these metrics won't be logged to the console.
|
534 |
+
If not set, defaults to the wandb `log_interval`.
|
535 |
+
"""
|
536 |
+
|
537 |
+
record_update_metrics: bool = False
|
538 |
+
"""
|
539 |
+
Whether to record detailed metrics about the optimizer's parameter updates, like the norm and max
|
540 |
+
of the update with AdamW.
|
541 |
+
"""
|
542 |
+
|
543 |
+
def __post_init__(self):
|
544 |
+
self.betas = tuple(self.betas) # type: ignore[assignment]
|
545 |
+
|
546 |
+
@classmethod
|
547 |
+
def update_legacy_settings(cls, config: D) -> D:
|
548 |
+
new_config = config.copy()
|
549 |
+
if om.is_dict(new_config):
|
550 |
+
assert isinstance(new_config, DictConfig)
|
551 |
+
|
552 |
+
if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
|
553 |
+
new_config.name = "lionw"
|
554 |
+
if hasattr(new_config, "eps"):
|
555 |
+
del new_config.eps
|
556 |
+
|
557 |
+
return new_config
|
558 |
+
|
559 |
+
|
560 |
+
class SchedulerType(StrEnum):
|
561 |
+
cosine_with_warmup = "cosine_with_warmup"
|
562 |
+
linear_with_warmup = "linear_with_warmup"
|
563 |
+
inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
|
564 |
+
max_scheduler = "max_scheduler"
|
565 |
+
constant = "constant"
|
566 |
+
cosine_linear_envelope = "cosine_linear_envelope"
|
567 |
+
constant_with_warmup = "constant_with_warmup"
|
568 |
+
|
569 |
+
|
570 |
+
class SchedulerUnits(StrEnum):
|
571 |
+
steps = "steps"
|
572 |
+
tokens = "tokens"
|
573 |
+
|
574 |
+
|
575 |
+
@dataclass
|
576 |
+
class SchedulerConfig(BaseConfig):
|
577 |
+
name: SchedulerType = SchedulerType.cosine_with_warmup
|
578 |
+
units: SchedulerUnits = SchedulerUnits.steps
|
579 |
+
t_warmup: Union[int, float] = 100
|
580 |
+
t_max: Optional[Union[int, float]] = None
|
581 |
+
alpha_f: float = 0.1
|
582 |
+
|
583 |
+
grad_clip_warmup_steps: Optional[Union[int, float]] = None
|
584 |
+
"""
|
585 |
+
The warmup period for which the max grad norm (or norm ratio) will be set to its
|
586 |
+
warmup value of `max_grad_norm * grad_clip_warmup_factor`.
|
587 |
+
"""
|
588 |
+
|
589 |
+
grad_clip_warmup_factor: Optional[float] = None
|
590 |
+
"""
|
591 |
+
The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
|
592 |
+
vs after the warmup period.
|
593 |
+
"""
|
594 |
+
|
595 |
+
warmup_min_lr: Optional[float] = None
|
596 |
+
"""
|
597 |
+
The starting LR during the warmup period. If not set this defaults to 10% of
|
598 |
+
the target LR.
|
599 |
+
"""
|
600 |
+
|
601 |
+
|
602 |
+
class PaddingDirection(StrEnum):
|
603 |
+
right = "right"
|
604 |
+
left = "left"
|
605 |
+
|
606 |
+
|
607 |
+
@dataclass
|
608 |
+
class InstanceFilterConfig(BaseConfig):
|
609 |
+
repetition_max_period: int = 13
|
610 |
+
repetition_min_period: int = 1
|
611 |
+
repetition_max_count: int = 32
|
612 |
+
|
613 |
+
|
614 |
+
@dataclass
|
615 |
+
class DataConfig(BaseConfig):
|
616 |
+
paths: Optional[List[str]] = None
|
617 |
+
memmap_dtype: str = "uint16"
|
618 |
+
datasets: Optional[Dict[str, List[str]]] = None
|
619 |
+
label_mask_paths: Optional[List[str]] = None
|
620 |
+
pad_direction: PaddingDirection = PaddingDirection.right
|
621 |
+
generate_attention_mask: bool = False
|
622 |
+
generate_doc_lengths: bool = False
|
623 |
+
num_workers: int = 0
|
624 |
+
drop_last: bool = False
|
625 |
+
pin_memory: bool = False
|
626 |
+
prefetch_factor: Optional[int] = None
|
627 |
+
persistent_workers: bool = False
|
628 |
+
timeout: int = 0
|
629 |
+
seed: Optional[int] = None
|
630 |
+
instance_filter: Optional[InstanceFilterConfig] = None
|
631 |
+
custom_dataset: Optional[CustomDatasetConfig] = None
|
632 |
+
|
633 |
+
@property
|
634 |
+
def effective_memmap_dtype(self):
|
635 |
+
try:
|
636 |
+
# getattr will check this is part of numpy module, while np.dtype will check
|
637 |
+
# if this is a valid numpy dtype.
|
638 |
+
np.dtype(dtype := getattr(np, self.memmap_dtype))
|
639 |
+
except (AttributeError, TypeError) as e:
|
640 |
+
raise TypeError(f"Value {self.memmap_dtype} is not a valid numpy type") from e
|
641 |
+
return dtype
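A short sketch of what the property above accepts and rejects, assuming the ``DataConfig`` defined here and ``numpy`` imported as ``np`` (as it is in this module):

```python
cfg = DataConfig(memmap_dtype="uint16")
assert cfg.effective_memmap_dtype is np.uint16  # resolved to the numpy scalar type

try:
    DataConfig(memmap_dtype="not_a_dtype").effective_memmap_dtype
except TypeError as err:
    print(err)  # Value not_a_dtype is not a valid numpy type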
|
642 |
+
|
643 |
+
|
644 |
+
@dataclass
|
645 |
+
class CustomDatasetCollatorConfig(BaseConfig):
|
646 |
+
input_id_field: str = "input_ids" #: The field in the dataset items that contains the input token IDs.
|
647 |
+
attention_mask_field: Optional[str] = None #: The field in the dataset items that contains the attention mask.
|
648 |
+
attention_bias_field: Optional[str] = None #: The field in the dataset items that contains the attention bias.
|
649 |
+
label_mask_field: Optional[str] = None #: The field in the dataset items that contains the label mask.
|
650 |
+
index_field: Optional[str] = None #: The field in the dataset items that contains the index of the item.
|
651 |
+
instance_mask_field: Optional[str] = None #: The field in the dataset items that contains the instance mask.
|
652 |
+
doc_lens_field: Optional[str] = None #: The field in the dataset items that contains the document lengths.
|
653 |
+
metadata_field: Optional[str] = None #: The field in the dataset items that contains the metadata.
|
654 |
+
|
655 |
+
|
656 |
+
@dataclass
|
657 |
+
class CustomDatasetConfig(BaseConfig):
|
658 |
+
name: str #: The name of the custom dataset class or function that will be used to load the dataset.
|
659 |
+
module: Optional[
|
660 |
+
str
|
661 |
+
] = None #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
|
662 |
+
args: Optional[Dict[str, Any]] = None #: The arguments to pass to the custom dataset class or function
|
663 |
+
collate_fn: Optional[
|
664 |
+
str
|
665 |
+
] = None #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
|
666 |
+
token_field: Optional[str] = None #: The field in the dataset items that contains the tokenized text.
|
667 |
+
collate_config: Optional[CustomDatasetCollatorConfig] = field(
|
668 |
+
default_factory=CustomDatasetCollatorConfig
|
669 |
+
) #: The configuration for the collate function to use for the custom dataset.
|
670 |
+
|
671 |
+
|
672 |
+
class EvaluatorType(StrEnum):
|
673 |
+
downstream = "downstream"
|
674 |
+
lm = "lm"
|
675 |
+
|
676 |
+
|
677 |
+
@dataclass
|
678 |
+
class EvaluatorConfig(BaseConfig):
|
679 |
+
label: str
|
680 |
+
type: EvaluatorType = EvaluatorType.lm
|
681 |
+
data: DataConfig = field(default_factory=DataConfig)
|
682 |
+
device_eval_batch_size: Optional[int] = None
|
683 |
+
subset_num_batches: Optional[int] = None
|
684 |
+
|
685 |
+
|
686 |
+
class TruncationDirection(StrEnum):
|
687 |
+
right = "right"
|
688 |
+
left = "left"
|
689 |
+
|
690 |
+
|
691 |
+
@dataclass
|
692 |
+
class TokenizerConfig(BaseConfig):
|
693 |
+
identifier: str = "gpt2"
|
694 |
+
truncate_direction: TruncationDirection = TruncationDirection.right
|
695 |
+
|
696 |
+
|
697 |
+
@dataclass
|
698 |
+
class WandbConfig(BaseConfig):
|
699 |
+
project: Optional[str] = None
|
700 |
+
entity: Optional[str] = "ai2-llm"
|
701 |
+
group: Optional[str] = None
|
702 |
+
name: Optional[str] = None
|
703 |
+
tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
|
704 |
+
log_artifacts: bool = False
|
705 |
+
rank_zero_only: bool = True
|
706 |
+
log_interval: int = 1
|
707 |
+
|
708 |
+
|
709 |
+
@dataclass
|
710 |
+
class SpeedMonitorConfig(BaseConfig):
|
711 |
+
window_size: int = 100
|
712 |
+
gpu_flops_available: Optional[Union[float, int]] = None
|
713 |
+
|
714 |
+
|
715 |
+
@dataclass
|
716 |
+
class CompilerConfig(BaseConfig):
|
717 |
+
mode: Optional[str] = None
|
718 |
+
"""
|
719 |
+
The mode to compile the model in. At the moment this can be "default",
|
720 |
+
"reduce-overhead" (useful for smaller models/batches), or "max-autotune"
|
721 |
+
(the fastest for larger models, but takes a long time to compile).
|
722 |
+
"""
|
723 |
+
|
724 |
+
fullgraph: bool = False
|
725 |
+
"""
|
726 |
+
If ``True``, require the model to compile as a single graph, i.e. disallow breaking it into several subgraphs.
|
727 |
+
Note that this is not compatible with FSDP.
|
728 |
+
"""
|
729 |
+
|
730 |
+
backend: str = "inductor"
|
731 |
+
"""
|
732 |
+
The backend to use.
|
733 |
+
"""
|
734 |
+
|
735 |
+
dynamic: Optional[bool] = None
|
736 |
+
"""
|
737 |
+
From the torch docs:
|
738 |
+
|
739 |
+
Use dynamic shape tracing. When this is True, we will up-front attempt to generate a kernel that is as dynamic
|
740 |
+
as possible to avoid recompilations when sizes change. This may not always work as some
|
741 |
+
operations/optimizations will force specialization; use TORCH_LOGS=dynamic to debug overspecialization. When
|
742 |
+
this is False, we will NEVER generate dynamic kernels, we will always specialize. By default (None), we
|
743 |
+
automatically detect if dynamism has occurred and compile a more dynamic kernel upon recompile.
|
744 |
+
"""
|
745 |
+
|
746 |
+
|
747 |
+
class DistributedStrategy(StrEnum):
|
748 |
+
ddp = "ddp"
|
749 |
+
"""
|
750 |
+
Wrap OLMo in torch.nn.parallel.DistributedDataParallel to train across ranks.
|
751 |
+
"""
|
752 |
+
|
753 |
+
fsdp = "fsdp"
|
754 |
+
"""
|
755 |
+
Wrap OLMo in torch.distributed.fsdp.FullyShardedDataParallel to train across ranks.
|
756 |
+
"""
|
757 |
+
|
758 |
+
single = "single"
|
759 |
+
"""
|
760 |
+
Train on a single device, i.e., do not distribute training. For development and debugging.
|
761 |
+
"""
|
762 |
+
|
763 |
+
|
764 |
+
class DDPGradSyncMode(StrEnum):
|
765 |
+
batch = "batch"
|
766 |
+
"""
|
767 |
+
Synchronize gradients after computation at each bucket only at the last micro-batch.
|
768 |
+
This is slightly faster than gradient syncs across each micro-batch but will consume more memory.
|
769 |
+
This mode can only be used when `find_unused_params` is set to False.
|
770 |
+
"""
|
771 |
+
|
772 |
+
micro_batch = "micro_batch"
|
773 |
+
"""
|
774 |
+
Synchronize gradients after computation at each bucket per micro-batch.
|
775 |
+
This will be slightly slower than gradient sync at the last micro-batch, but will consume less memory.
|
776 |
+
This mode can be used with either setting of `find_unused_params`, but it is specifically recommended when `find_unused_params` is
|
777 |
+
set to True, to prevent errors.
|
778 |
+
"""
|
779 |
+
|
780 |
+
|
781 |
+
@dataclass
|
782 |
+
class DDPConfig(BaseConfig):
|
783 |
+
grad_sync_mode: DDPGradSyncMode = DDPGradSyncMode.batch
|
784 |
+
"""
|
785 |
+
Gradient sync mode for DDP
|
786 |
+
|
787 |
+
Note: When `find_unused_params` is set, set `grad_sync_mode` to `micro_batch` as different micro-batches might activate
|
788 |
+
different parts of the model, e.g. MoEs.
|
789 |
+
"""
|
790 |
+
|
791 |
+
find_unused_params: bool = False
|
792 |
+
"""
|
793 |
+
(from torch documentation)
|
794 |
+
|
795 |
+
This mode allows running backward on a subgraph of the model, and DDP finds out which parameters
|
796 |
+
are involved in the backward pass by traversing the autograd graph from the model output and marking
|
797 |
+
all unused parameters as ready for reduction. Note that traversing the autograd graph introduces extra overheads,
|
798 |
+
so applications should only set find_unused_parameters to True when necessary.
|
799 |
+
"""
|
800 |
+
|
801 |
+
|
802 |
+
class FSDPWrapStrategy(StrEnum):
|
803 |
+
by_block = "by_block"
|
804 |
+
"""
|
805 |
+
Wrap each OLMo block with its own FSDP instance.
|
806 |
+
"""
|
807 |
+
|
808 |
+
by_block_and_size = "by_block_and_size"
|
809 |
+
"""
|
810 |
+
Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
|
811 |
+
"""
|
812 |
+
|
813 |
+
by_block_group = "by_block_group"
|
814 |
+
"""
|
815 |
+
Wrap each block group together into its own FSDP instance.
|
816 |
+
This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
|
817 |
+
"""
|
818 |
+
|
819 |
+
by_block_group_and_size = "by_block_group_and_size"
|
820 |
+
"""
|
821 |
+
Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
|
822 |
+
"""
|
823 |
+
|
824 |
+
size_based = "size_based"
|
825 |
+
"""
|
826 |
+
Use PyTorch's default size-based auto wrap policy.
|
827 |
+
"""
|
828 |
+
|
829 |
+
one_in_two = "one_in_two"
|
830 |
+
one_in_three = "one_in_three"
|
831 |
+
one_in_four = "one_in_four"
|
832 |
+
one_in_five = "one_in_five"
|
833 |
+
|
834 |
+
|
835 |
+
class FSDPPrecision(StrEnum):
|
836 |
+
pure = "pure"
|
837 |
+
"""
|
838 |
+
Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
|
839 |
+
and ``buffer_dtype`` all set to the autocast precision data type.
|
840 |
+
"""
|
841 |
+
|
842 |
+
mixed = "mixed"
|
843 |
+
"""
|
844 |
+
Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
|
845 |
+
set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
|
846 |
+
"""
|
847 |
+
|
848 |
+
|
849 |
+
@dataclass
|
850 |
+
class FSDPConfig(BaseConfig):
|
851 |
+
use_orig_params: bool = True
|
852 |
+
"""
|
853 |
+
This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
|
854 |
+
"""
|
855 |
+
|
856 |
+
sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
|
857 |
+
|
858 |
+
wrapping_strategy: Optional[FSDPWrapStrategy] = None
|
859 |
+
"""
|
860 |
+
The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
|
861 |
+
FSDP instance.
|
862 |
+
"""
|
863 |
+
|
864 |
+
precision: Optional[FSDPPrecision] = FSDPPrecision.pure
|
865 |
+
|
866 |
+
hybrid_sharding_num_model_replicas: Optional[int] = None
|
867 |
+
"""
|
868 |
+
The number of model instances, when using a hybrid sharding strategy.
|
869 |
+
If not ``None``, this must divide the total number of nodes. If ``None``, the default,
|
870 |
+
a model instance is used per node (as determined by ``get_world_size() // get_local_world_size()``).
|
871 |
+
PyTorch's default HSDP behavior matches this default behavior.
|
872 |
+
"""
|
873 |
+
|
874 |
+
|
875 |
+
@dataclass
|
876 |
+
class SingleGPUConfig(BaseConfig):
|
877 |
+
device: str = "auto"
|
878 |
+
"""
|
879 |
+
Device to run single-device training.
|
880 |
+
"""
|
881 |
+
|
882 |
+
def get_device(self):
|
883 |
+
if self.device == "auto":
|
884 |
+
if torch.backends.mps.is_available():
|
885 |
+
return torch.device("mps")
|
886 |
+
elif torch.cuda.is_available():
|
887 |
+
return torch.device("cuda")
|
888 |
+
else:
|
889 |
+
return torch.device("cpu")
|
890 |
+
elif self.device == "mps" and not torch.backends.mps.is_available():
|
891 |
+
raise OLMoConfigurationError("MPS not available.")
|
892 |
+
elif self.device == "cuda" and not torch.cuda.is_available():
|
893 |
+
raise OLMoConfigurationError("CUDA not available.")
|
894 |
+
else:
|
895 |
+
return torch.device(self.device)
|
896 |
+
|
897 |
+
|
898 |
+
class CheckpointType(StrEnum):
|
899 |
+
sharded = "sharded"
|
900 |
+
unsharded = "unsharded"
|
901 |
+
sharded_ephemeral = "sharded_ephemeral"
|
902 |
+
|
903 |
+
|
904 |
+
class ShardedCheckpointerType(StrEnum):
|
905 |
+
torch_new = "torch_new"
|
906 |
+
torch_legacy = "torch_legacy"
|
907 |
+
local = "local"
|
908 |
+
olmo_core = "olmo_core"
|
909 |
+
|
910 |
+
|
911 |
+
class ActivationCheckpointingStrategy(StrEnum):
|
912 |
+
whole_layer = "whole_layer"
|
913 |
+
"""
|
914 |
+
Checkpoint every transformer layer.
|
915 |
+
"""
|
916 |
+
|
917 |
+
one_in_two = "one_in_two"
|
918 |
+
"""
|
919 |
+
Checkpoint one in two transformer layers.
|
920 |
+
"""
|
921 |
+
|
922 |
+
one_in_three = "one_in_three"
|
923 |
+
"""
|
924 |
+
Checkpoint one in three transformer layers.
|
925 |
+
"""
|
926 |
+
|
927 |
+
one_in_four = "one_in_four"
|
928 |
+
"""
|
929 |
+
Checkpoint one in four transformer layers.
|
930 |
+
"""
|
931 |
+
|
932 |
+
one_in_eight = "one_in_eight"
|
933 |
+
"""
|
934 |
+
Checkpoint one in eight transformer layers.
|
935 |
+
"""
|
936 |
+
|
937 |
+
two_in_three = "two_in_three"
|
938 |
+
"""
|
939 |
+
Checkpoint two out of every three transformer layers.
|
940 |
+
"""
|
941 |
+
|
942 |
+
three_in_four = "three_in_four"
|
943 |
+
"""
|
944 |
+
Checkpoint three out of every four transformer layers.
|
945 |
+
"""
|
946 |
+
|
947 |
+
fine_grained = "fine_grained"
|
948 |
+
"""
|
949 |
+
Focus checkpointing where recomputation is cheap and the memory savings are greatest.
|
950 |
+
"""
|
951 |
+
|
952 |
+
|
953 |
+
@dataclass
|
954 |
+
class TrainConfig(BaseConfig):
|
955 |
+
"""
|
956 |
+
OLMo training configuration.
|
957 |
+
"""
|
958 |
+
|
959 |
+
run_name: Optional[str] = None
|
960 |
+
"""
|
961 |
+
The name of the run.
|
962 |
+
"""
|
963 |
+
|
964 |
+
seed: int = 6198
|
965 |
+
"""
|
966 |
+
Used to seed all initial RNG states.
|
967 |
+
"""
|
968 |
+
|
969 |
+
epoch: Optional[int] = None
|
970 |
+
"""
|
971 |
+
Increment this when starting a new epoch.
|
972 |
+
"""
|
973 |
+
|
974 |
+
dry_run: bool = False
|
975 |
+
"""
|
976 |
+
If ``True``, don't actually train.
|
977 |
+
"""
|
978 |
+
|
979 |
+
model: ModelConfig = field(default_factory=ModelConfig)
|
980 |
+
"""
|
981 |
+
OLMo Model configuration.
|
982 |
+
"""
|
983 |
+
|
984 |
+
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
|
985 |
+
"""
|
986 |
+
Optimizer configuration.
|
987 |
+
"""
|
988 |
+
|
989 |
+
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
|
990 |
+
"""
|
991 |
+
Learning rate scheduler configuration.
|
992 |
+
"""
|
993 |
+
|
994 |
+
data: DataConfig = field(default_factory=DataConfig)
|
995 |
+
"""
|
996 |
+
Training data configuration.
|
997 |
+
"""
|
998 |
+
|
999 |
+
restore_dataloader: bool = True
|
1000 |
+
"""
|
1001 |
+
When restarting, restore the data loader to where it left off.
|
1002 |
+
If you are restarting in order to train on a different dataset, set this to ``False``.
|
1003 |
+
"""
|
1004 |
+
|
1005 |
+
fast_forward_batches: Optional[int] = None
|
1006 |
+
"""
|
1007 |
+
When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
|
1008 |
+
This can be useful when restarting due to a loss spike in order to skip the data that
|
1009 |
+
corresponded to the spike.
|
1010 |
+
"""
|
1011 |
+
|
1012 |
+
evaluators: List[EvaluatorConfig] = field(default_factory=list)
|
1013 |
+
"""
|
1014 |
+
Evaluation configurations.
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
eval_interval: int = 1000
|
1018 |
+
"""
|
1019 |
+
How often (in terms of batches) to run evaluations.
|
1020 |
+
"""
|
1021 |
+
|
1022 |
+
tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
|
1023 |
+
"""
|
1024 |
+
Tokenizer configuration.
|
1025 |
+
"""
|
1026 |
+
|
1027 |
+
save_folder: str = "./"
|
1028 |
+
"""
|
1029 |
+
The directory to save checkpoints to.
|
1030 |
+
"""
|
1031 |
+
|
1032 |
+
remote_save_folder: Optional[str] = None
|
1033 |
+
"""
|
1034 |
+
A folder in a cloud bucket to upload saved checkpoints to.
|
1035 |
+
"""
|
1036 |
+
|
1037 |
+
canceled_check_interval: int = 50
|
1038 |
+
"""
|
1039 |
+
How often (in batches) to check if the run has been canceled or reached its time limit.
|
1040 |
+
"""
|
1041 |
+
|
1042 |
+
save_interval: Optional[int] = 1000
|
1043 |
+
"""
|
1044 |
+
How often (in terms of steps) to save sharded training state checkpoints.
|
1045 |
+
"""
|
1046 |
+
|
1047 |
+
save_interval_unsharded: Optional[int] = None
|
1048 |
+
"""
|
1049 |
+
How often (if at all) to save unsharded training state checkpoint.
|
1050 |
+
For large models it can be costly to save these, so it usually makes sense to save
|
1051 |
+
these less often than regular (sharded) training checkpoints.
|
1052 |
+
"""
|
1053 |
+
|
1054 |
+
save_interval_ephemeral: Optional[int] = None
|
1055 |
+
"""
|
1056 |
+
How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
|
1057 |
+
as those saved every `save_interval` except that at most only the most recent one of these is kept.
|
1058 |
+
This is useful when you want to checkpoint often for restarts in case of failures, but don't
|
1059 |
+
want to keep the majority of these checkpoints.
|
1060 |
+
|
1061 |
+
For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
|
1062 |
+
a temporary checkpoint every 100 steps in case your job fails. In that case you would
|
1063 |
+
set `save_interval=1000` and `save_interval_ephemeral=100`.
|
1064 |
+
"""
|
1065 |
+
|
1066 |
+
save_num_checkpoints_to_keep: int = -1
|
1067 |
+
"""
|
1068 |
+
How many sharded checkpoints to keep.
|
1069 |
+
"""
|
1070 |
+
|
1071 |
+
save_num_unsharded_checkpoints_to_keep: int = -1
|
1072 |
+
"""
|
1073 |
+
How many unsharded checkpoints to keep.
|
1074 |
+
"""
|
1075 |
+
|
1076 |
+
save_overwrite: bool = False
|
1077 |
+
"""
|
1078 |
+
If ``True``, overwrite any conflicting checkpoint files.
|
1079 |
+
"""
|
1080 |
+
|
1081 |
+
force_save_unsharded: bool = False
|
1082 |
+
"""
|
1083 |
+
Save an unsharded checkpoint before training (even during a dry run).
|
1084 |
+
Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
|
1085 |
+
checkpoint into an unsharded checkpoint.
|
1086 |
+
"""
|
1087 |
+
|
1088 |
+
no_pre_train_checkpoint: bool = False
|
1089 |
+
"""
|
1090 |
+
Skip saving pre-train checkpoint.
|
1091 |
+
"""
|
1092 |
+
|
1093 |
+
load_path: Optional[str] = None
|
1094 |
+
"""
|
1095 |
+
The path to a training checkpoint to restore/resume from. If not set, then training begins from scratch.
|
1096 |
+
|
1097 |
+
Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
|
1098 |
+
a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
|
1099 |
+
For example,
|
1100 |
+
|
1101 |
+
```bash
|
1102 |
+
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
|
1103 |
+
```
|
1104 |
+
|
1105 |
+
If `try_load_latest_save` is set and saved checkpoints exist, then `load_path` will be overridden
|
1106 |
+
by the latest saved checkpoint.
|
1107 |
+
"""
|
1108 |
+
|
1109 |
+
load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
|
1110 |
+
"""
|
1111 |
+
The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
|
1112 |
+
"""
|
1113 |
+
|
1114 |
+
try_load_latest_save: bool = False
|
1115 |
+
"""
|
1116 |
+
If set, then training will be resumed from the latest checkpoint in the local save folder, falling
|
1117 |
+
back to the latest checkpoint in the remote save folder if none exists. If there are no checkpoints
|
1118 |
+
in the local and remote save folders, then checkpoint loading will fall back to `load_path`.
|
1119 |
+
"""
|
1120 |
+
|
1121 |
+
reset_optimizer_state: bool = False
|
1122 |
+
"""
|
1123 |
+
When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
|
1124 |
+
We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
|
1125 |
+
curve (according to the current learning rate schedule settings), and continues from there.
|
1126 |
+
"""
|
1127 |
+
|
1128 |
+
reset_trainer_state: bool = False
|
1129 |
+
"""
|
1130 |
+
When this is set, we don't restore the trainer state from a checkpoint.
|
1131 |
+
"""
|
1132 |
+
|
1133 |
+
sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
|
1134 |
+
"""
|
1135 |
+
The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
|
1136 |
+
"""
|
1137 |
+
|
1138 |
+
new_style_checkpoints: Optional[bool] = None
|
1139 |
+
"""
|
1140 |
+
Deprecated. Use ``sharded_checkpointer`` instead.
|
1141 |
+
"""
|
1142 |
+
|
1143 |
+
max_duration: Union[int, str] = 10000
|
1144 |
+
"""
|
1145 |
+
How long to train for.
|
1146 |
+
|
1147 |
+
If specified without a unit (the default), the units are assumed to be steps.
|
1148 |
+
You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
|
1149 |
+
2 trillion tokens.
|
1150 |
+
"""
|
1151 |
+
|
1152 |
+
global_train_batch_size: int = 512
|
1153 |
+
"""
|
1154 |
+
The effective global batch size.
|
1155 |
+
"""
|
1156 |
+
|
1157 |
+
device_train_batch_size: Optional[int] = None # calculated automatically
|
1158 |
+
"""
|
1159 |
+
Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
|
1160 |
+
"""
|
1161 |
+
|
1162 |
+
device_train_microbatch_size: int = 16
|
1163 |
+
"""
|
1164 |
+
The number of instances passed to the model in a single forward-backward pass. You should set
|
1165 |
+
this as large as you can based on available GPU memory.
|
1166 |
+
"""
|
1167 |
+
|
1168 |
+
device_eval_batch_size: int = 16
|
1169 |
+
"""
|
1170 |
+
The number of evaluation instances passed to the model in a single forward pass on each device.
|
1171 |
+
"""
|
1172 |
+
|
1173 |
+
eval_subset_num_batches: int = -1
|
1174 |
+
"""
|
1175 |
+
The number of batches to use for downstream evaluation from each dataset.
|
1176 |
+
"""
|
1177 |
+
|
1178 |
+
eval_on_load: bool = False
|
1179 |
+
"""
|
1180 |
+
When resuming from a checkpoint, run the evaluation loop right away.
|
1181 |
+
"""
|
1182 |
+
|
1183 |
+
device_train_grad_accum: Optional[int] = None # calculated automatically
|
1184 |
+
"""
|
1185 |
+
Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
|
1186 |
+
"""
|
1187 |
+
|
1188 |
+
max_grad_norm: Optional[float] = None
|
1189 |
+
"""
|
1190 |
+
Clip gradient norms to this value if set.
|
1191 |
+
"""
|
1192 |
+
|
1193 |
+
max_grad_norm_ratio: Optional[float] = None
|
1194 |
+
"""
|
1195 |
+
If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
|
1196 |
+
This takes priority over `max_grad_norm` when set.
|
1197 |
+
"""
|
1198 |
+
|
1199 |
+
precision: Optional[str] = None
|
1200 |
+
"""
|
1201 |
+
Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
|
1202 |
+
"""
|
1203 |
+
|
1204 |
+
wandb: Optional[WandbConfig] = None
|
1205 |
+
"""
|
1206 |
+
Weights & Biases configuration.
|
1207 |
+
"""
|
1208 |
+
|
1209 |
+
speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
|
1210 |
+
"""
|
1211 |
+
Speed monitor configuration.
|
1212 |
+
"""
|
1213 |
+
|
1214 |
+
console_log_interval: int = 1
|
1215 |
+
"""
|
1216 |
+
How often to log to the console.
|
1217 |
+
"""
|
1218 |
+
|
1219 |
+
gen1_gc_interval: Optional[int] = 1
|
1220 |
+
"""
|
1221 |
+
How often (in steps) to run generation 1 garbage collection.
|
1222 |
+
Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it).
|
1223 |
+
"""
|
1224 |
+
|
1225 |
+
compile: Optional[CompilerConfig] = None
|
1226 |
+
"""
|
1227 |
+
Settings for compiling the model with ``torch.compile()``.
|
1228 |
+
"""
|
1229 |
+
|
1230 |
+
distributed_strategy: Optional[DistributedStrategy] = DistributedStrategy.fsdp
|
1231 |
+
"""
|
1232 |
+
Distributed strategy for OLMo model (eg. single GPU, DDP, FSDP).
|
1233 |
+
"""
|
1234 |
+
|
1235 |
+
fsdp: Optional[FSDPConfig] = field(default_factory=FSDPConfig)
|
1236 |
+
"""
|
1237 |
+
Fully sharded data parallel settings.
|
1238 |
+
"""
|
1239 |
+
|
1240 |
+
ddp: Optional[DDPConfig] = None
|
1241 |
+
"""
|
1242 |
+
DDP settings.
|
1243 |
+
"""
|
1244 |
+
|
1245 |
+
single: SingleGPUConfig = field(default_factory=lambda: SingleGPUConfig(device="auto"))
|
1246 |
+
"""
|
1247 |
+
Single device settings for GPU/CPU/MPS. Defaults to auto-detect the best device.
|
1248 |
+
"""
|
1249 |
+
|
1250 |
+
softmax_auxiliary_loss: bool = False
|
1251 |
+
"""
|
1252 |
+
If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
|
1253 |
+
normalizing term to be close to 0.
|
1254 |
+
"""
|
1255 |
+
|
1256 |
+
auxiliary_loss_multiplier: Optional[float] = 1e-4
|
1257 |
+
"""
|
1258 |
+
Used with `softmax_auxiliary_loss`. PaLM uses 1e-4, Chameleon uses 1e-5.
|
1259 |
+
"""
|
1260 |
+
|
1261 |
+
time_limit: Optional[float] = None
|
1262 |
+
"""
|
1263 |
+
The maximum amount of time to train for before saving a checkpoint and ending early.
|
1264 |
+
"""
|
1265 |
+
|
1266 |
+
extra_steps_after_cancel: int = 10
|
1267 |
+
"""
|
1268 |
+
Under certain conditions when a run is canceled we train for a few extra steps after saving
|
1269 |
+
the final checkpoint so that when the run is restarted from the latest checkpoint we have some
|
1270 |
+
overlap in metrics.
|
1271 |
+
"""
|
1272 |
+
|
1273 |
+
early_stopping_factor: Optional[float] = None
|
1274 |
+
|
1275 |
+
save_data_indices: bool = True
|
1276 |
+
"""
|
1277 |
+
Save training data indices from each batch for each worker.
|
1278 |
+
"""
|
1279 |
+
|
1280 |
+
python_profiling: bool = False
|
1281 |
+
"""
|
1282 |
+
Whether to run the Python profiler on batches 6, 7, and 8.
|
1283 |
+
"""
|
1284 |
+
|
1285 |
+
torch_profiling: bool = False
|
1286 |
+
"""
|
1287 |
+
Whether to run the PyTorch profiler on batches 6, 7, and 8.
|
1288 |
+
"""
|
1289 |
+
|
1290 |
+
stop_at: Optional[int] = None
|
1291 |
+
"""
|
1292 |
+
Stop at a specific step.
|
1293 |
+
"""
|
1294 |
+
|
1295 |
+
stop_after: Optional[int] = None
|
1296 |
+
"""
|
1297 |
+
Stop after a specific number of steps.
|
1298 |
+
"""
|
1299 |
+
|
1300 |
+
activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
|
1301 |
+
"""
|
1302 |
+
The activation checkpointing strategy to use.
|
1303 |
+
"""
|
1304 |
+
|
1305 |
+
fused_loss: Optional[bool] = None
|
1306 |
+
"""
|
1307 |
+
Whether to use the fused CE loss function from `flash-attn`.
|
1308 |
+
"""
|
1309 |
+
|
1310 |
+
hf_datasets_cache_dir: Optional[str] = None
|
1311 |
+
"""
|
1312 |
+
Deprecated, HF datasets are now stored in `olmo_data.hf_datasets`.
|
1313 |
+
|
1314 |
+
Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
|
1315 |
+
"""
|
1316 |
+
|
1317 |
+
module_outputs_save_steps: Optional[List[int]] = None
|
1318 |
+
"""
|
1319 |
+
Outputs of model submodules are saved during the provided steps. Submodule outputs
|
1320 |
+
can be compared using `scripts/compare_module_outputs.py`.
|
1321 |
+
"""
|
1322 |
+
|
1323 |
+
@property
|
1324 |
+
def autocast_precision(self) -> torch.dtype:
|
1325 |
+
if self.precision == "amp_bf16":
|
1326 |
+
return torch.bfloat16
|
1327 |
+
elif self.precision == "amp_fp16":
|
1328 |
+
return torch.float16
|
1329 |
+
elif self.precision == "fp32":
|
1330 |
+
return torch.float32
|
1331 |
+
else:
|
1332 |
+
raise ValueError(f"Unexpected precision type '{self.precision}'")
|
1333 |
+
|
1334 |
+
@property
|
1335 |
+
def fsdp_precision(self) -> Optional[MixedPrecision]:
|
1336 |
+
if self.fsdp is not None:
|
1337 |
+
if self.fsdp.precision is None:
|
1338 |
+
return None
|
1339 |
+
elif self.fsdp.precision == FSDPPrecision.pure:
|
1340 |
+
return MixedPrecision(
|
1341 |
+
param_dtype=self.autocast_precision,
|
1342 |
+
reduce_dtype=self.autocast_precision,
|
1343 |
+
buffer_dtype=self.autocast_precision,
|
1344 |
+
)
|
1345 |
+
elif self.fsdp.precision == FSDPPrecision.mixed:
|
1346 |
+
return MixedPrecision(
|
1347 |
+
param_dtype=self.autocast_precision,
|
1348 |
+
reduce_dtype=torch.float32,
|
1349 |
+
buffer_dtype=self.autocast_precision,
|
1350 |
+
)
|
1351 |
+
else:
|
1352 |
+
raise NotImplementedError(f"{self.fsdp.precision}")
|
1353 |
+
else:
|
1354 |
+
raise ValueError("self.fsdp is None!")
|
1355 |
+
|
1356 |
+
@classmethod
|
1357 |
+
def update_legacy_settings(cls, config: D) -> D:
|
1358 |
+
new_config = config.copy()
|
1359 |
+
if om.is_dict(new_config):
|
1360 |
+
assert isinstance(new_config, DictConfig)
|
1361 |
+
|
1362 |
+
if hasattr(new_config, "activation_checkpointing"):
|
1363 |
+
if new_config.activation_checkpointing is False:
|
1364 |
+
new_config.activation_checkpointing = None
|
1365 |
+
if new_config.activation_checkpointing is True:
|
1366 |
+
new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
|
1367 |
+
|
1368 |
+
if hasattr(new_config, "optimizer"):
|
1369 |
+
new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
|
1370 |
+
|
1371 |
+
return new_config
|
exceptions.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
__all__ = [
|
2 |
+
"OLMoError",
|
3 |
+
"OLMoConfigurationError",
|
4 |
+
"OLMoCliError",
|
5 |
+
"OLMoEnvironmentError",
|
6 |
+
"OLMoNetworkError",
|
7 |
+
"OLMoCheckpointError",
|
8 |
+
]
|
9 |
+
|
10 |
+
|
11 |
+
class OLMoError(Exception):
|
12 |
+
"""
|
13 |
+
Base class for all custom OLMo exceptions.
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
class OLMoConfigurationError(OLMoError):
|
18 |
+
"""
|
19 |
+
An error with a configuration file.
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
class OLMoCliError(OLMoError):
|
24 |
+
"""
|
25 |
+
An error from incorrect CLI usage.
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
class OLMoEnvironmentError(OLMoError):
|
30 |
+
"""
|
31 |
+
An error from incorrect environment variables.
|
32 |
+
"""
|
33 |
+
|
34 |
+
|
35 |
+
class OLMoNetworkError(OLMoError):
|
36 |
+
"""
|
37 |
+
An error with a network request.
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
class OLMoCheckpointError(OLMoError):
|
42 |
+
"""
|
43 |
+
An error occurred reading or writing from a checkpoint.
|
44 |
+
"""
|
45 |
+
|
46 |
+
|
47 |
+
class OLMoThreadError(Exception):
|
48 |
+
"""
|
49 |
+
Raised when a thread fails.
|
50 |
+
"""
|
initialization.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
__all__ = ["init_normal"]
|
6 |
+
|
7 |
+
|
8 |
+
def init_normal(
|
9 |
+
module: Union[nn.Linear, nn.Embedding],
|
10 |
+
std: float,
|
11 |
+
init_cutoff_factor: Optional[float] = None,
|
12 |
+
):
|
13 |
+
# weights
|
14 |
+
if init_cutoff_factor is not None:
|
15 |
+
cutoff_value = init_cutoff_factor * std
|
16 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
17 |
+
else:
|
18 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
19 |
+
|
20 |
+
# biases
|
21 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
22 |
+
nn.init.zeros_(module.bias)
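A quick sketch of the two branches above (truncated normal when a cutoff factor is given, plain normal otherwise), using hypothetical sizes:

```python
linear = nn.Linear(16, 16)
init_normal(linear, std=0.02, init_cutoff_factor=3.0)
# Weights were drawn from a normal truncated at +/- 3 * 0.02, and the bias was zeroed.
assert float(linear.weight.abs().max()) <= 3.0 * 0.02
assert float(linear.bias.abs().sum()) == 0.0
```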
|
model.py
ADDED
@@ -0,0 +1,1959 @@
1 |
+
"""
|
2 |
+
Adapted from
|
3 |
+
[MosaiclML](https://github.com/mosaicml/examples.git) and
|
4 |
+
[minGPT](https://github.com/karpathy/minGPT.git)
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import annotations
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import sys
|
12 |
+
from abc import abstractmethod
|
13 |
+
from collections import defaultdict
|
14 |
+
from functools import partial
|
15 |
+
from typing import (
|
16 |
+
Callable,
|
17 |
+
Dict,
|
18 |
+
Iterable,
|
19 |
+
List,
|
20 |
+
NamedTuple,
|
21 |
+
Optional,
|
22 |
+
Sequence,
|
23 |
+
Set,
|
24 |
+
Tuple,
|
25 |
+
cast,
|
26 |
+
)
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.backends.cuda
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
from torch import einsum
|
33 |
+
|
34 |
+
from .aliases import PathOrStr
|
35 |
+
from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
|
36 |
+
from .config import (
|
37 |
+
ActivationCheckpointingStrategy,
|
38 |
+
ActivationType,
|
39 |
+
BlockType,
|
40 |
+
CheckpointType,
|
41 |
+
FSDPWrapStrategy,
|
42 |
+
InitFnType,
|
43 |
+
LayerNormType,
|
44 |
+
ModelConfig,
|
45 |
+
ShardedCheckpointerType,
|
46 |
+
TrainConfig,
|
47 |
+
)
|
48 |
+
from .exceptions import OLMoConfigurationError
|
49 |
+
from .initialization import init_normal
|
50 |
+
from .torch_util import ensure_finite_, get_cumulative_document_lengths
|
51 |
+
|
52 |
+
if sys.version_info.minor > 8:
|
53 |
+
from collections.abc import MutableMapping
|
54 |
+
elif sys.version_info.minor == 8:
|
55 |
+
from typing import MutableMapping
|
56 |
+
else:
|
57 |
+
raise SystemExit("This script supports Python 3.8 or higher")
|
58 |
+
|
59 |
+
__all__ = [
|
60 |
+
"LayerNormBase",
|
61 |
+
"LayerNorm",
|
62 |
+
"RMSLayerNorm",
|
63 |
+
"RotaryEmbedding",
|
64 |
+
"Activation",
|
65 |
+
"GELU",
|
66 |
+
"ReLU",
|
67 |
+
"SwiGLU",
|
68 |
+
"OLMoBlock",
|
69 |
+
"OLMoSequentialBlock",
|
70 |
+
"OLMo",
|
71 |
+
"OLMoOutput",
|
72 |
+
"OLMoGenerateOutput",
|
73 |
+
]
|
74 |
+
|
75 |
+
log = logging.getLogger(__name__)
|
76 |
+
|
77 |
+
class FANLayer(nn.Module):
|
78 |
+
"""
|
79 |
+
FANLayer: The layer used in FAN (https://arxiv.org/abs/2410.02675).
|
80 |
+
|
81 |
+
Args:
|
82 |
+
input_dim (int): The number of input features.
|
83 |
+
output_dim (int): The number of output features.
|
84 |
+
p_ratio (float): The ratio of output dimensions used for cosine and sine parts (default: 0.25).
|
85 |
+
activation (str or callable): The activation function to apply to the g component. If a string is passed,
|
86 |
+
the corresponding activation from torch.nn.functional is used (default: 'gelu').
|
87 |
+
use_p_bias (bool): If True, include bias in the linear transformations of p component (default: True).
|
88 |
+
There is almost no difference between bias and non-bias in our experiments.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self, input_dim, output_dim, p_ratio=0.25, activation='gelu', use_p_bias=True):
|
92 |
+
super(FANLayer, self).__init__()
|
93 |
+
|
94 |
+
# Ensure the p_ratio is within a valid range
|
95 |
+
assert 0 <= p_ratio <= 0.5, "p_ratio must be between 0 and 0.5"
|
96 |
+
|
97 |
+
self.p_ratio = p_ratio
|
98 |
+
p_output_dim = int(output_dim * self.p_ratio)
|
99 |
+
g_output_dim = output_dim - p_output_dim * 2 # Account for cosine and sine terms
|
100 |
+
|
101 |
+
|
102 |
+
self.input_linear = nn.Linear(input_dim, p_output_dim+g_output_dim, bias=use_p_bias)
|
103 |
+
|
104 |
+
self.fused_dims = (p_output_dim, g_output_dim)
|
105 |
+
|
106 |
+
# Set the activation function
|
107 |
+
if isinstance(activation, str):
|
108 |
+
self.activation = getattr(F, activation)
|
109 |
+
else:
|
110 |
+
self.activation = activation if activation else lambda x: x
|
111 |
+
|
112 |
+
def forward(self, src, norm_g=None):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
src (Tensor): Input tensor of shape (batch_size, input_dim).
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Tensor: Output tensor of shape (batch_size, output_dim), after applying the FAN layer.
|
119 |
+
"""
|
120 |
+
pg = self.input_linear(src)
|
121 |
+
|
122 |
+
p, g = pg.split(self.fused_dims, dim=-1)
|
123 |
+
|
124 |
+
# Concatenate cos(p), sin(p), and activated g along the last dimension
|
125 |
+
output = torch.cat((torch.cos(p), torch.sin(p), self.activation(g)), dim=-1)
|
126 |
+
|
127 |
+
return output
|
128 |
+
|
129 |
+
class FAN(nn.Module):
|
130 |
+
def __init__(self, input_dim, output_dim, config, activation='gelu'):
|
131 |
+
super(FAN, self).__init__()
|
132 |
+
|
133 |
+
self.fanlayer = FANLayer(input_dim, input_dim, config.p_ratio, activation)
|
134 |
+
self.linear = nn.Linear(input_dim, output_dim, bias=config.include_bias, device=config.init_device)
|
135 |
+
|
136 |
+
def forward(self, src):
|
137 |
+
return self.linear(self.fanlayer(src))
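A shape sketch for the split described in the FANLayer docstring above (hypothetical sizes): with ``output_dim=16`` and ``p_ratio=0.25``, the fused linear layer produces a 4-dim ``p`` and an 8-dim ``g``, and ``cos(p)``, ``sin(p)``, and the activated ``g`` concatenate back to 16 dims.

```python
layer = FANLayer(input_dim=32, output_dim=16, p_ratio=0.25)  # input_linear maps 32 -> 12 (4 + 8)
x = torch.randn(2, 32)
out = layer(x)
assert out.shape == (2, 16)  # cos(p): 4, sin(p): 4, gelu(g): 8
```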
|
138 |
+
|
139 |
+
|
140 |
+
def activation_checkpoint_function(cfg: ModelConfig):
|
141 |
+
preserve_rng_state = not (
|
142 |
+
(cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
|
143 |
+
)
|
144 |
+
from torch.utils.checkpoint import checkpoint
|
145 |
+
|
146 |
+
return partial(
|
147 |
+
checkpoint,
|
148 |
+
preserve_rng_state=preserve_rng_state,
|
149 |
+
use_reentrant=False,
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
|
154 |
+
if strategy is None:
|
155 |
+
return False
|
156 |
+
elif (
|
157 |
+
(strategy == ActivationCheckpointingStrategy.whole_layer)
|
158 |
+
or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
|
159 |
+
or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
|
160 |
+
or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
|
161 |
+
or (strategy == ActivationCheckpointingStrategy.one_in_eight and block_idx % 8 == 0)
|
162 |
+
or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
|
163 |
+
or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
|
164 |
+
):
|
165 |
+
return True
|
166 |
+
else:
|
167 |
+
return False
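For instance, with eight blocks the strategies above select the following block indices (a small sketch using the function as defined):

```python
blocks = range(8)
assert [i for i in blocks if should_checkpoint_block(ActivationCheckpointingStrategy.one_in_four, i)] == [0, 4]
assert [i for i in blocks if should_checkpoint_block(ActivationCheckpointingStrategy.two_in_three, i)] == [1, 2, 4, 5, 7]
```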
|
168 |
+
|
169 |
+
|
170 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
171 |
+
"""
|
172 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
173 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
174 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
175 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
176 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
177 |
+
NaNs when they're synchronized due to casting or some other issue.
|
178 |
+
"""
|
179 |
+
|
180 |
+
|
181 |
+
def _non_meta_init_device(config: ModelConfig) -> torch.device:
|
182 |
+
if config.init_device is not None and config.init_device != "meta":
|
183 |
+
return torch.device(config.init_device)
|
184 |
+
else:
|
185 |
+
if torch.backends.mps.is_available():
|
186 |
+
return torch.device("mps")
|
187 |
+
elif torch.cuda.is_available():
|
188 |
+
return torch.device("cuda")
|
189 |
+
else:
|
190 |
+
return torch.device("cpu")
|
191 |
+
|
192 |
+
|
193 |
+
class Dropout(nn.Dropout):
|
194 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
195 |
+
if self.p == 0.0:
|
196 |
+
return input
|
197 |
+
else:
|
198 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
199 |
+
|
200 |
+
|
201 |
+
class LayerNormBase(nn.Module):
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
config: ModelConfig,
|
205 |
+
*,
|
206 |
+
size: Optional[int] = None,
|
207 |
+
elementwise_affine: Optional[bool] = True,
|
208 |
+
):
|
209 |
+
super().__init__()
|
210 |
+
self.config = config
|
211 |
+
self.eps = config.layer_norm_eps
|
212 |
+
self.normalized_shape = (size or config.d_model,)
|
213 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
214 |
+
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
|
215 |
+
use_bias = self.config.bias_for_layer_norm
|
216 |
+
if use_bias is None:
|
217 |
+
use_bias = self.config.include_bias
|
218 |
+
if use_bias:
|
219 |
+
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
|
220 |
+
else:
|
221 |
+
self.register_parameter("bias", None)
|
222 |
+
else:
|
223 |
+
self.register_parameter("bias", None)
|
224 |
+
self.register_parameter("weight", None)
|
225 |
+
|
226 |
+
@abstractmethod
|
227 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
228 |
+
raise NotImplementedError
|
229 |
+
|
230 |
+
@classmethod
|
231 |
+
def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
|
232 |
+
if config.layer_norm_type == LayerNormType.default:
|
233 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
234 |
+
elif config.layer_norm_type == LayerNormType.low_precision:
|
235 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
236 |
+
elif config.layer_norm_type == LayerNormType.rms:
|
237 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
238 |
+
else:
|
239 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
240 |
+
|
241 |
+
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
242 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
243 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
244 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
245 |
+
if tensor.device.type == "cuda" and torch.is_autocast_enabled():
|
246 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
|
247 |
+
elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
248 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
249 |
+
else:
|
250 |
+
return tensor
|
251 |
+
|
252 |
+
def reset_parameters(self):
|
253 |
+
if self.weight is not None:
|
254 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
255 |
+
if self.bias is not None:
|
256 |
+
torch.nn.init.zeros_(self.bias) # type: ignore
|
257 |
+
|
258 |
+
|
259 |
+
class LayerNorm(LayerNormBase):
|
260 |
+
"""
|
261 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(
|
265 |
+
self,
|
266 |
+
config: ModelConfig,
|
267 |
+
size: Optional[int] = None,
|
268 |
+
low_precision: bool = False,
|
269 |
+
elementwise_affine: Optional[bool] = None,
|
270 |
+
):
|
271 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine)
|
272 |
+
self.low_precision = low_precision
|
273 |
+
|
274 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
275 |
+
if self.low_precision:
|
276 |
+
module_device = x.device
|
277 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
278 |
+
downcast_weight = (
|
279 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
280 |
+
)
|
281 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
282 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
283 |
+
return F.layer_norm(
|
284 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
285 |
+
)
|
286 |
+
else:
|
287 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
288 |
+
|
289 |
+
|
290 |
+
class RMSLayerNorm(LayerNormBase):
|
291 |
+
"""
|
292 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation that rescales by the root mean square of the features without re-centering them.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(
|
296 |
+
self,
|
297 |
+
config: ModelConfig,
|
298 |
+
size: Optional[int] = None,
|
299 |
+
elementwise_affine: Optional[bool] = None,
|
300 |
+
):
|
301 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine)
|
302 |
+
|
303 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
304 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
305 |
+
og_dtype = x.dtype
|
306 |
+
x = x.to(torch.float32)
|
307 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
308 |
+
x = x * torch.rsqrt(variance + self.eps)
|
309 |
+
x = x.to(og_dtype)
|
310 |
+
|
311 |
+
if self.weight is not None:
|
312 |
+
if self.bias is not None:
|
313 |
+
return self.weight * x + self.bias
|
314 |
+
else:
|
315 |
+
return self.weight * x
|
316 |
+
else:
|
317 |
+
return x
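# In formula form (a restatement of the code above, not an addition to it):
#   y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight (+ bias, if configured)
# i.e. RMS norm rescales by the root mean square of the features but, unlike standard LayerNorm, never
# subtracts the mean.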
|
318 |
+
|
319 |
+
|
320 |
+
class RotaryEmbedding(nn.Module):
|
321 |
+
"""
|
322 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
323 |
+
"""
|
324 |
+
|
325 |
+
def __init__(self, config: ModelConfig, cache: BufferCache):
|
326 |
+
super().__init__()
|
327 |
+
self.config = config
|
328 |
+
self.__cache = cache
|
329 |
+
# Warm up cache.
|
330 |
+
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
|
331 |
+
|
332 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
333 |
+
if (
|
334 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
335 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
336 |
+
and pos_sin.shape[-2] >= seq_len
|
337 |
+
and pos_cos.shape[-2] >= seq_len
|
338 |
+
):
|
339 |
+
if pos_sin.device != device:
|
340 |
+
pos_sin = pos_sin.to(device)
|
341 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
342 |
+
if pos_cos.device != device:
|
343 |
+
pos_cos = pos_cos.to(device)
|
344 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
345 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
346 |
+
|
347 |
+
with torch.autocast(device.type, enabled=False):
|
348 |
+
dim = self.config.d_model // self.config.n_heads
|
349 |
+
inv_freq = 1.0 / (
|
350 |
+
self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
|
351 |
+
)
|
352 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
353 |
+
freqs = einsum("i , j -> i j", seq, inv_freq)
|
354 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
355 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
356 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
357 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
358 |
+
return pos_sin, pos_cos
|
359 |
+
|
360 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
361 |
+
B, nh, T, hs = x.size()
|
362 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
363 |
+
x1, x2 = x.unbind(dim=-2)
|
364 |
+
return torch.cat((-x2, x1), dim=-1)
|
365 |
+
|
366 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
367 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
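# A short note on the math (restating the code): with frequencies theta_i = rope_theta ** (-2i / head_dim),
# each feature pair (x1, x2) at position m is rotated by the angle m * theta_i:
#   (x1', x2') = (x1 * cos - x2 * sin, x2 * cos + x1 * sin)
# The pairs are formed from the two halves of the head dimension (the GPT-NeoX-style layout), which is why
# `rotate_half` swaps halves instead of interleaving adjacent elements.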
|
368 |
+
|
369 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
370 |
+
if self.config.rope_full_precision:
|
371 |
+
q_, k_ = q.float(), k.float()
|
372 |
+
else:
|
373 |
+
q_, k_ = q, k
|
374 |
+
|
375 |
+
with torch.autocast(q.device.type, enabled=False):
|
376 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
377 |
+
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
|
378 |
+
pos_sin = pos_sin.type_as(q_)
|
379 |
+
pos_cos = pos_cos.type_as(q_)
|
380 |
+
q_ = self.apply_rotary_pos_emb(
|
381 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
382 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
383 |
+
q_,
|
384 |
+
)
|
385 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
386 |
+
return q_.type_as(q), k_.type_as(k)
|
387 |
+
|
388 |
+
|
389 |
+
class Activation(nn.Module):
|
390 |
+
def __init__(self, config: ModelConfig):
|
391 |
+
super().__init__()
|
392 |
+
self.config = config
|
393 |
+
|
394 |
+
@abstractmethod
|
395 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
396 |
+
raise NotImplementedError
|
397 |
+
|
398 |
+
@property
|
399 |
+
@abstractmethod
|
400 |
+
def output_multiplier(self) -> float:
|
401 |
+
raise NotImplementedError
|
402 |
+
|
403 |
+
@classmethod
|
404 |
+
def build(cls, config: ModelConfig) -> Activation:
|
405 |
+
if config.activation_type == ActivationType.gelu:
|
406 |
+
return cast(Activation, GELU(approximate="none"))
|
407 |
+
elif config.activation_type == ActivationType.relu:
|
408 |
+
return cast(Activation, ReLU(inplace=False))
|
409 |
+
elif config.activation_type == ActivationType.swiglu:
|
410 |
+
return SwiGLU(config)
|
411 |
+
else:
|
412 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
413 |
+
|
414 |
+
|
415 |
+
class GELU(nn.GELU):
|
416 |
+
@property
|
417 |
+
def output_multiplier(self) -> float:
|
418 |
+
return 1.0
|
419 |
+
|
420 |
+
|
421 |
+
class ReLU(nn.ReLU):
|
422 |
+
@property
|
423 |
+
def output_multiplier(self) -> float:
|
424 |
+
return 1.0
|
425 |
+
|
426 |
+
|
427 |
+
class SwiGLU(Activation):
|
428 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
429 |
+
x, gate = x.chunk(2, dim=-1)
|
430 |
+
return F.silu(gate) * x
|
431 |
+
|
432 |
+
@property
|
433 |
+
def output_multiplier(self) -> float:
|
434 |
+
return 0.5
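# Rationale, as implied by the code: `ff_proj` produces `hidden_size` features, SwiGLU chunks them into two
# halves and returns `hidden_size // 2` features (silu(gate) * x), so `ff_out` must be sized with
# `output_multiplier = 0.5`. GELU and ReLU keep the width unchanged, hence their multiplier of 1.0.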
|
435 |
+
|
436 |
+
|
437 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
438 |
+
att_bias = torch.triu(
|
439 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
440 |
+
diagonal=1,
|
441 |
+
)
|
442 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
443 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
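# Illustrative example (not executed): for seq_len = 3 the bias looks like
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# so adding it to the attention scores blocks each position from attending to future positions.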
|
444 |
+
|
445 |
+
|
446 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
447 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
448 |
+
if causal_bias.device != device:
|
449 |
+
causal_bias = causal_bias.to(device)
|
450 |
+
cache["causal_attention_bias"] = causal_bias
|
451 |
+
return causal_bias
|
452 |
+
with torch.autocast(device.type, enabled=False):
|
453 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
454 |
+
cache["causal_attention_bias"] = causal_bias
|
455 |
+
return causal_bias
|
456 |
+
|
457 |
+
|
458 |
+
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
|
459 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
|
460 |
+
|
461 |
+
# shape: (1, 1, seq_len, seq_len)
|
462 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
|
463 |
+
alibi_bias.abs_().mul_(-1)
|
464 |
+
|
465 |
+
# shape: (n_heads,)
|
466 |
+
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
|
467 |
+
m.mul_(config.alibi_bias_max / config.n_heads)
|
468 |
+
|
469 |
+
# shape: (1, n_heads, seq_len, seq_len)
|
470 |
+
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
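# Sketch of the construction (restating the code): `alibi_bias[..., i, j]` ends up as -|i - j|, a penalty
# that grows with relative distance, and each head h (1-indexed) scales it by
#   1 / 2 ** (alibi_bias_max * h / n_heads)
# so later heads get flatter slopes and can attend further back. This bias is added to the causal bias
# above, which still masks out future positions.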
|
471 |
+
|
472 |
+
|
473 |
+
class OLMoBlock(nn.Module):
|
474 |
+
"""
|
475 |
+
A base class for transformer block implementations.
|
476 |
+
"""
|
477 |
+
|
478 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
479 |
+
super().__init__()
|
480 |
+
self.layer_id = layer_id
|
481 |
+
self.config = config
|
482 |
+
self.hidden_size = (
|
483 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
484 |
+
)
|
485 |
+
self.__cache = cache
|
486 |
+
assert config.d_model % config.n_heads == 0
|
487 |
+
|
488 |
+
self._activation_checkpoint_fn: Optional[Callable] = None
|
489 |
+
|
490 |
+
# Dropout.
|
491 |
+
self.dropout = Dropout(config.residual_dropout)
|
492 |
+
|
493 |
+
# Layer norms.
|
494 |
+
self.k_norm: Optional[LayerNormBase] = None
|
495 |
+
self.q_norm: Optional[LayerNormBase] = None
|
496 |
+
if config.attention_layer_norm:
|
497 |
+
assert config.effective_n_kv_heads is not None
|
498 |
+
self.k_norm = LayerNormBase.build(
|
499 |
+
config,
|
500 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
501 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
502 |
+
)
|
503 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
504 |
+
|
505 |
+
# Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
|
506 |
+
if config.clip_qkv is not None:
|
507 |
+
assert config.clip_qkv > 0
|
508 |
+
|
509 |
+
# Activation function.
|
510 |
+
self.act = Activation.build(config)
|
511 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
512 |
+
|
513 |
+
# Attention output projection.
|
514 |
+
self.attn_out = nn.Linear(
|
515 |
+
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
|
516 |
+
)
|
517 |
+
|
518 |
+
# Feed-forward output projection.
|
519 |
+
self.ff_out = nn.Linear(
|
520 |
+
int(self.act.output_multiplier * self.hidden_size),
|
521 |
+
config.d_model,
|
522 |
+
bias=config.include_bias,
|
523 |
+
device=config.init_device,
|
524 |
+
)
|
525 |
+
self.ff_out._is_residual = True # type: ignore
|
526 |
+
|
527 |
+
# Rotary embeddings.
|
528 |
+
if self.config.rope:
|
529 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
530 |
+
|
531 |
+
self.flash_attn_func = None
|
532 |
+
self.flash_attn_varlen_func = None
|
533 |
+
if config.flash_attention:
|
534 |
+
try:
|
535 |
+
from flash_attn import ( # type: ignore
|
536 |
+
flash_attn_func,
|
537 |
+
flash_attn_varlen_func,
|
538 |
+
)
|
539 |
+
|
540 |
+
self.flash_attn_func = flash_attn_func
|
541 |
+
self.flash_attn_varlen_func = flash_attn_varlen_func
|
542 |
+
except ModuleNotFoundError:
|
543 |
+
pass
|
544 |
+
|
545 |
+
def reset_parameters(self):
|
546 |
+
if self.k_norm is not None:
|
547 |
+
self.k_norm.reset_parameters()
|
548 |
+
if self.q_norm is not None:
|
549 |
+
self.q_norm.reset_parameters()
|
550 |
+
|
551 |
+
if self.config.init_fn == InitFnType.normal:
|
552 |
+
attn_out_std = ff_out_std = self.config.init_std
|
553 |
+
cutoff_factor = self.config.init_cutoff_factor
|
554 |
+
|
555 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
556 |
+
attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
|
557 |
+
ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
|
558 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
559 |
+
|
560 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
561 |
+
attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
|
562 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
563 |
+
|
564 |
+
else:
|
565 |
+
raise NotImplementedError(self.config.init_fn)
|
566 |
+
|
567 |
+
init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
|
568 |
+
init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
|
569 |
+
|
570 |
+
def set_activation_checkpointing(
|
571 |
+
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
|
572 |
+
):
|
573 |
+
if strategy == ActivationCheckpointingStrategy.fine_grained:
|
574 |
+
self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
|
575 |
+
else:
|
576 |
+
self._activation_checkpoint_fn = None
|
577 |
+
|
578 |
+
@classmethod
|
579 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
580 |
+
target_dtype = input_dtype
|
581 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
582 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
583 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
584 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
585 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
586 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
587 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
588 |
+
elif bias.device.type == "mps":
|
589 |
+
target_dtype = torch.get_autocast_dtype("mps")
|
590 |
+
if bias.dtype != target_dtype:
|
591 |
+
bias = bias.to(target_dtype)
|
592 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
593 |
+
return bias
|
594 |
+
|
595 |
+
def _scaled_dot_product_attention(
|
596 |
+
self,
|
597 |
+
q: torch.Tensor,
|
598 |
+
k: torch.Tensor,
|
599 |
+
v: torch.Tensor,
|
600 |
+
attn_mask: Optional[torch.Tensor] = None,
|
601 |
+
dropout_p: float = 0.0,
|
602 |
+
is_causal: bool = False,
|
603 |
+
max_doc_len: Optional[int] = None,
|
604 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
605 |
+
) -> torch.Tensor:
|
606 |
+
"""
|
607 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
608 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
609 |
+
"""
|
610 |
+
if max_doc_len is not None and cu_doc_lens is not None:
|
611 |
+
assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
|
612 |
+
assert attn_mask is None, "attn-mask is currently not supported with document masking"
|
613 |
+
B, T, D = q.size(0), q.size(2), q.size(3)
|
614 |
+
r = self.flash_attn_varlen_func(
|
615 |
+
q.transpose(1, 2).view(B * T, -1, D),
|
616 |
+
k.transpose(1, 2).view(B * T, -1, D),
|
617 |
+
v.transpose(1, 2).view(B * T, -1, D),
|
618 |
+
cu_doc_lens,
|
619 |
+
cu_doc_lens,
|
620 |
+
max_doc_len,
|
621 |
+
max_doc_len,
|
622 |
+
dropout_p=dropout_p,
|
623 |
+
causal=is_causal,
|
624 |
+
)
|
625 |
+
return r.view(B, T, -1, D).transpose(1, 2)
|
626 |
+
elif self.flash_attn_func is not None and attn_mask is None:
|
627 |
+
r = self.flash_attn_func(
|
628 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
|
629 |
+
)
|
630 |
+
return r.transpose(1, 2)
|
631 |
+
else:
|
632 |
+
# torch's F.scaled_dot_product_attention doesn't support grouped-query attention directly here, so expand the KV heads to match the query heads
|
633 |
+
assert k.size(1) == v.size(1)
|
634 |
+
num_kv_heads = k.size(1)
|
635 |
+
num_q_heads = q.size(1)
|
636 |
+
if num_q_heads != num_kv_heads:
|
637 |
+
assert num_q_heads % num_kv_heads == 0
|
638 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
639 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
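# Worked example (illustrative): with 32 query heads and 8 KV heads, each KV head is repeated
# 32 // 8 = 4 times along the head dimension so that q, k and v all have 32 heads before calling SDPA.
# flash-attn, when available, handles grouped KV heads natively, which is why the expansion only happens
# on this fallback path.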
|
640 |
+
|
641 |
+
return F.scaled_dot_product_attention(
|
642 |
+
q,
|
643 |
+
k,
|
644 |
+
v,
|
645 |
+
attn_mask=attn_mask,
|
646 |
+
dropout_p=dropout_p,
|
647 |
+
is_causal=is_causal,
|
648 |
+
)
|
649 |
+
|
650 |
+
def attention(
|
651 |
+
self,
|
652 |
+
q: torch.Tensor,
|
653 |
+
k: torch.Tensor,
|
654 |
+
v: torch.Tensor,
|
655 |
+
attention_bias: Optional[torch.Tensor] = None,
|
656 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
657 |
+
use_cache: bool = False,
|
658 |
+
max_doc_len: Optional[int] = None,
|
659 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
660 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
661 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
662 |
+
dtype = k.dtype
|
663 |
+
|
664 |
+
# Optionally apply layer norm to keys and queries.
|
665 |
+
if self.q_norm is not None and self.k_norm is not None:
|
666 |
+
q = self.q_norm(q).to(dtype=dtype)
|
667 |
+
k = self.k_norm(k).to(dtype=dtype)
|
668 |
+
|
669 |
+
# Move head forward to be next to the batch dim.
|
670 |
+
# shape: (B, nh, T, hs)
|
671 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
672 |
+
# shape: (B, n_kv_h, T, hs)
|
673 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
674 |
+
# shape: (B, n_kv_h, T, hs)
|
675 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
676 |
+
|
677 |
+
if layer_past is not None:
|
678 |
+
past_key, past_value = layer_past
|
679 |
+
k = torch.cat((past_key, k), dim=-2)
|
680 |
+
v = torch.cat((past_value, v), dim=-2)
|
681 |
+
|
682 |
+
present = (k, v) if use_cache else None
|
683 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
684 |
+
|
685 |
+
if self.config.rope:
|
686 |
+
# Apply rotary embeddings.
|
687 |
+
q, k = self.rotary_emb(q, k)
|
688 |
+
|
689 |
+
if attention_bias is not None:
|
690 |
+
# Resize and cast attention bias.
|
691 |
+
# The attention bias dtype might not match the dtype the SDP attention function will run in when AMP is
# enabled. That matters when some tokens are masked out (e.g. padding): down-casting the bias to the
# autocast precision can turn its large negative values into -inf, and -inf entries make the SDP
# attention function produce NaNs.
|
695 |
+
attention_bias = self._cast_attn_bias(
|
696 |
+
attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
697 |
+
)
|
698 |
+
|
699 |
+
# Get the attention scores.
|
700 |
+
# shape: (B, nh, T, hs)
|
701 |
+
att = self._scaled_dot_product_attention(
|
702 |
+
q,
|
703 |
+
k,
|
704 |
+
v,
|
705 |
+
attn_mask=attention_bias,
|
706 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
707 |
+
is_causal=attention_bias is None,
|
708 |
+
max_doc_len=max_doc_len,
|
709 |
+
cu_doc_lens=cu_doc_lens,
|
710 |
+
)
|
711 |
+
|
712 |
+
# Re-assemble all head outputs side-by-side.
|
713 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
714 |
+
|
715 |
+
# Apply output projection.
|
716 |
+
return self.attn_out(att), present
|
717 |
+
|
718 |
+
@abstractmethod
|
719 |
+
def forward(
|
720 |
+
self,
|
721 |
+
x: torch.Tensor,
|
722 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
723 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
724 |
+
use_cache: bool = False,
|
725 |
+
max_doc_len: Optional[int] = None,
|
726 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
727 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
728 |
+
raise NotImplementedError
|
729 |
+
|
730 |
+
@classmethod
|
731 |
+
def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
|
732 |
+
if config.block_type == BlockType.sequential:
|
733 |
+
return OLMoSequentialBlock(layer_id, config, cache)
|
734 |
+
elif config.block_type == BlockType.llama:
|
735 |
+
return OLMoLlamaBlock(layer_id, config, cache)
|
736 |
+
else:
|
737 |
+
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
738 |
+
|
739 |
+
|
740 |
+
class OLMoSequentialBlock(OLMoBlock):
|
741 |
+
"""
|
742 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
743 |
+
(plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``,
|
744 |
+
use the flag `norm_after`.
|
745 |
+
"""
|
746 |
+
|
747 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
748 |
+
super().__init__(layer_id, config, cache)
|
749 |
+
# Attention input projection. Projects x -> (q, k, v)
|
750 |
+
self.use_ATF = config.use_ATF
|
751 |
+
|
752 |
+
head_dim = config.d_model // config.n_heads
|
753 |
+
self.fused_dims = (
|
754 |
+
config.d_model,
|
755 |
+
config.effective_n_kv_heads * head_dim,
|
756 |
+
config.effective_n_kv_heads * head_dim,
|
757 |
+
)
|
758 |
+
|
759 |
+
|
760 |
+
if self.use_ATF:
|
761 |
+
self.att_proj = FAN(config.d_model, sum(self.fused_dims), config, activation=config.attention_activation)
|
762 |
+
else:
|
763 |
+
self.att_proj = nn.Linear(
|
764 |
+
config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
|
765 |
+
)
|
766 |
+
|
767 |
+
# Feed-forward input projection.
|
768 |
+
self.ff_proj = nn.Linear(
|
769 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
770 |
+
)
|
771 |
+
|
772 |
+
# Layer norms.
|
773 |
+
self.attn_norm = LayerNorm.build(config, size=config.d_model)
|
774 |
+
self.ff_norm = LayerNorm.build(config, size=config.d_model)
|
775 |
+
|
776 |
+
def reset_parameters(self):
|
777 |
+
super().reset_parameters()
|
778 |
+
self.attn_norm.reset_parameters()
|
779 |
+
self.ff_norm.reset_parameters()
|
780 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
781 |
+
|
782 |
+
if self.config.init_fn == InitFnType.normal:
|
783 |
+
std = self.config.init_std
|
784 |
+
cutoff_factor = self.config.init_cutoff_factor
|
785 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
786 |
+
std = 1 / math.sqrt(self.config.d_model)
|
787 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
788 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
789 |
+
std = self.config.init_std
|
790 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
791 |
+
else:
|
792 |
+
raise NotImplementedError(self.config.init_fn)
|
793 |
+
|
794 |
+
if self.use_ATF:
|
795 |
+
init_normal(self.att_proj.fanlayer.input_linear, std, cutoff_factor)
|
796 |
+
init_normal(self.att_proj.linear, std, cutoff_factor)
|
797 |
+
else:
|
798 |
+
init_normal(self.att_proj, std, cutoff_factor)
|
799 |
+
|
800 |
+
init_normal(self.ff_proj, std, cutoff_factor)
|
801 |
+
|
802 |
+
def forward(
|
803 |
+
self,
|
804 |
+
x: torch.Tensor,
|
805 |
+
attention_bias: Optional[torch.Tensor] = None,
|
806 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
807 |
+
use_cache: bool = False,
|
808 |
+
max_doc_len: Optional[int] = None,
|
809 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
810 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
811 |
+
# Get query, key, value projections.
|
812 |
+
# shape:
|
813 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
814 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
815 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
816 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
817 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
818 |
+
|
819 |
+
# apply norm before
|
820 |
+
if not self.config.norm_after:
|
821 |
+
if self._activation_checkpoint_fn is not None:
|
822 |
+
h = self._activation_checkpoint_fn(self.attn_norm, x)
|
823 |
+
else:
|
824 |
+
h = self.attn_norm(x)
|
825 |
+
else:
|
826 |
+
h = x
|
827 |
+
|
828 |
+
qkv = self.att_proj(h)
|
829 |
+
|
830 |
+
if self.config.clip_qkv is not None:
|
831 |
+
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
832 |
+
|
833 |
+
q, k, v = qkv.split(self.fused_dims, dim=-1)
|
834 |
+
|
835 |
+
# Get attention scores.
|
836 |
+
if self._activation_checkpoint_fn is not None:
|
837 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
838 |
+
self.attention,
|
839 |
+
q,
|
840 |
+
k,
|
841 |
+
v,
|
842 |
+
attention_bias,
|
843 |
+
layer_past=layer_past,
|
844 |
+
use_cache=use_cache,
|
845 |
+
max_doc_len=max_doc_len,
|
846 |
+
cu_doc_lens=cu_doc_lens,
|
847 |
+
)
|
848 |
+
else:
|
849 |
+
att, cache = self.attention(
|
850 |
+
q,
|
851 |
+
k,
|
852 |
+
v,
|
853 |
+
attention_bias,
|
854 |
+
layer_past=layer_past,
|
855 |
+
use_cache=use_cache,
|
856 |
+
max_doc_len=max_doc_len,
|
857 |
+
cu_doc_lens=cu_doc_lens,
|
858 |
+
)
|
859 |
+
|
860 |
+
if self.config.norm_after:
|
861 |
+
if self._activation_checkpoint_fn is not None:
|
862 |
+
att = self._activation_checkpoint_fn(self.attn_norm, att)
|
863 |
+
else:
|
864 |
+
att = self.attn_norm(att)
|
865 |
+
|
866 |
+
# Add the attention output back to the residual stream.
|
867 |
+
# shape: (B, T, C)
|
868 |
+
x = x + self.dropout(att)
|
869 |
+
|
870 |
+
# Add feed-forward projection.
|
871 |
+
# shape: (batch_size, seq_len, d_model)
|
872 |
+
og_x = x
|
873 |
+
|
874 |
+
if not self.config.norm_after:
|
875 |
+
if self._activation_checkpoint_fn is not None:
|
876 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
877 |
+
else:
|
878 |
+
x = self.ff_norm(x)
|
879 |
+
|
880 |
+
x = self.ff_proj(x)
|
881 |
+
|
882 |
+
if self._activation_checkpoint_fn is not None:
|
883 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
884 |
+
else:
|
885 |
+
x = self.act(x)
|
886 |
+
x = self.ff_out(x)
|
887 |
+
|
888 |
+
if self.config.norm_after:
|
889 |
+
if self._activation_checkpoint_fn is not None:
|
890 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
891 |
+
else:
|
892 |
+
x = self.ff_norm(x)
|
893 |
+
|
894 |
+
x = self.dropout(x)
|
895 |
+
x = og_x + x
|
896 |
+
|
897 |
+
return x, cache
|
898 |
+
|
899 |
+
|
900 |
+
class OLMoLlamaBlock(OLMoBlock):
|
901 |
+
"""
|
902 |
+
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
903 |
+
(plus another skip connection). This block is similar to `OLMoSequentialBlock`
|
904 |
+
but some operations have slightly different implementations to imitate the
|
905 |
+
behavior of Llama.
|
906 |
+
"""
|
907 |
+
|
908 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
909 |
+
super().__init__(layer_id, config, cache)
|
910 |
+
# Layer norms.
|
911 |
+
self.use_ATF = config.use_ATF
|
912 |
+
self.attn_norm = LayerNorm.build(config)
|
913 |
+
self.ff_norm = LayerNorm.build(config)
|
914 |
+
self.__cache = cache
|
915 |
+
|
916 |
+
# Attention input projection. Projects x -> (q, k, v)
|
917 |
+
if config.multi_query_attention:
|
918 |
+
q_proj_out_dim = config.d_model
|
919 |
+
k_proj_out_dim = config.d_model // config.n_heads
|
920 |
+
v_proj_out_dim = config.d_model // config.n_heads
|
921 |
+
else:
|
922 |
+
q_proj_out_dim = config.d_model
|
923 |
+
k_proj_out_dim = config.d_model
|
924 |
+
v_proj_out_dim = config.d_model
|
925 |
+
|
926 |
+
if self.use_ATF:
|
927 |
+
self.q_proj = FAN(config.d_model, q_proj_out_dim, config)
|
928 |
+
self.k_proj = FAN(config.d_model, k_proj_out_dim, config)
|
929 |
+
self.v_proj = FAN(config.d_model, v_proj_out_dim, config)
|
930 |
+
else:
|
931 |
+
self.q_proj = nn.Linear(
|
932 |
+
config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
|
933 |
+
)
|
934 |
+
self.k_proj = nn.Linear(
|
935 |
+
config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
|
936 |
+
)
|
937 |
+
self.v_proj = nn.Linear(
|
938 |
+
config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
|
939 |
+
)
|
940 |
+
|
941 |
+
# Feed-forward input projection.
|
942 |
+
self.ff_proj = nn.Linear(
|
943 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
944 |
+
)
|
945 |
+
|
946 |
+
def reset_parameters(self):
|
947 |
+
super().reset_parameters()
|
948 |
+
self.attn_norm.reset_parameters()
|
949 |
+
self.ff_norm.reset_parameters()
|
950 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
951 |
+
|
952 |
+
if self.config.init_fn == InitFnType.normal:
|
953 |
+
std = self.config.init_std
|
954 |
+
cutoff_factor = self.config.init_cutoff_factor
|
955 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
956 |
+
std = 1 / math.sqrt(self.config.d_model)
|
957 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
958 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
959 |
+
std = self.config.init_std
|
960 |
+
cutoff_factor = self.config.init_cutoff_factor or 3.0
|
961 |
+
else:
|
962 |
+
raise NotImplementedError(self.config.init_fn)
|
963 |
+
|
964 |
+
init_normal(self.q_proj, std, cutoff_factor)
|
965 |
+
init_normal(self.k_proj, std, cutoff_factor)
|
966 |
+
init_normal(self.v_proj, std, cutoff_factor)
|
967 |
+
init_normal(self.ff_proj, std, cutoff_factor)
|
968 |
+
|
969 |
+
def _scaled_dot_product_attention(
|
970 |
+
self,
|
971 |
+
q: torch.Tensor,
|
972 |
+
k: torch.Tensor,
|
973 |
+
v: torch.Tensor,
|
974 |
+
attn_mask: Optional[torch.Tensor] = None,
|
975 |
+
dropout_p: float = 0.0,
|
976 |
+
is_causal: bool = False,
|
977 |
+
max_doc_len: Optional[int] = None,
|
978 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
979 |
+
) -> torch.Tensor:
|
980 |
+
if max_doc_len is not None or cu_doc_lens is not None:
|
981 |
+
raise NotImplementedError(
|
982 |
+
f"attention document masking is not implemented for {self.__class__.__name__}"
|
983 |
+
)
|
984 |
+
|
985 |
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
|
986 |
+
|
987 |
+
if is_causal:
|
988 |
+
assert attn_mask is None
|
989 |
+
|
990 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
991 |
+
attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
|
992 |
+
elif attn_mask is not None:
|
993 |
+
attn_bias = attn_mask.to(q.dtype)
|
994 |
+
else:
|
995 |
+
attn_bias = torch.zeros_like(attn_weights)
|
996 |
+
|
997 |
+
attn_weights += attn_bias
|
998 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
|
999 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
|
1000 |
+
return torch.matmul(attn_weights, v)
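# This fallback is plain softmax attention (restating the code): softmax(q @ k^T / sqrt(d) + bias) @ v,
# with the causal bias pulled from the shared BufferCache when `is_causal` is set. It is used in place of
# F.scaled_dot_product_attention to mirror Llama's reference implementation more closely, per the class
# docstring above.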
|
1001 |
+
|
1002 |
+
def forward(
|
1003 |
+
self,
|
1004 |
+
x: torch.Tensor,
|
1005 |
+
attention_bias: Optional[torch.Tensor] = None,
|
1006 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
1007 |
+
use_cache: bool = False,
|
1008 |
+
max_doc_len: Optional[int] = None,
|
1009 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
1010 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
1011 |
+
# Get query, key, value projections.
|
1012 |
+
# shape:
|
1013 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
1014 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
1015 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
1016 |
+
x_normed = self.attn_norm(x)
|
1017 |
+
q = self.q_proj(x_normed)
|
1018 |
+
k = self.k_proj(x_normed)
|
1019 |
+
v = self.v_proj(x_normed)
|
1020 |
+
|
1021 |
+
if self.config.clip_qkv is not None:
|
1022 |
+
q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
1023 |
+
k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
1024 |
+
v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
1025 |
+
|
1026 |
+
# Get attention scores.
|
1027 |
+
att, cache = self.attention(
|
1028 |
+
q,
|
1029 |
+
k,
|
1030 |
+
v,
|
1031 |
+
attention_bias,
|
1032 |
+
layer_past=layer_past,
|
1033 |
+
use_cache=use_cache,
|
1034 |
+
max_doc_len=max_doc_len,
|
1035 |
+
cu_doc_lens=cu_doc_lens,
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
# Add the attention output back to the residual stream.
|
1039 |
+
# shape: (B, T, C)
|
1040 |
+
x = x + self.dropout(att)
|
1041 |
+
|
1042 |
+
# Add feed-forward projection.
|
1043 |
+
# shape: (batch_size, seq_len, d_model)
|
1044 |
+
og_x = x
|
1045 |
+
if self._activation_checkpoint_fn is not None:
|
1046 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
1047 |
+
else:
|
1048 |
+
x = self.ff_norm(x)
|
1049 |
+
x = self.ff_proj(x)
|
1050 |
+
if self._activation_checkpoint_fn is not None:
|
1051 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
1052 |
+
else:
|
1053 |
+
x = self.act(x)
|
1054 |
+
x = self.ff_out(x)
|
1055 |
+
x = self.dropout(x)
|
1056 |
+
x = og_x + x
|
1057 |
+
|
1058 |
+
return x, cache
|
1059 |
+
|
1060 |
+
|
1061 |
+
class OLMoOutput(NamedTuple):
|
1062 |
+
logits: torch.FloatTensor
|
1063 |
+
"""
|
1064 |
+
A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
|
1065 |
+
for the next token *before* normalization via (log) softmax.
|
1066 |
+
"""
|
1067 |
+
|
1068 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
|
1069 |
+
"""
|
1070 |
+
Attention keys and values from each block.
|
1071 |
+
"""
|
1072 |
+
|
1073 |
+
hidden_states: Optional[Tuple[torch.Tensor, ...]]
|
1074 |
+
"""
|
1075 |
+
Hidden states from each block.
|
1076 |
+
"""
|
1077 |
+
|
1078 |
+
|
1079 |
+
class OLMoGenerateOutput(NamedTuple):
|
1080 |
+
token_ids: torch.LongTensor
|
1081 |
+
"""
|
1082 |
+
The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
|
1083 |
+
These do *not* include the original input IDs.
|
1084 |
+
"""
|
1085 |
+
|
1086 |
+
scores: torch.FloatTensor
|
1087 |
+
"""
|
1088 |
+
The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
|
1089 |
+
"""
|
1090 |
+
|
1091 |
+
|
1092 |
+
class OLMoBlockGroup(nn.ModuleList):
|
1093 |
+
def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
|
1094 |
+
super().__init__(modules)
|
1095 |
+
self.config = config
|
1096 |
+
self.layer_offset = layer_offset
|
1097 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
1098 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
1099 |
+
|
1100 |
+
def forward(
|
1101 |
+
self,
|
1102 |
+
x: torch.Tensor,
|
1103 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
1104 |
+
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
1105 |
+
use_cache: bool = False,
|
1106 |
+
max_doc_len: Optional[int] = None,
|
1107 |
+
cu_doc_lens: Optional[torch.Tensor] = None,
|
1108 |
+
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
1109 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
1110 |
+
for block_idx, block in enumerate(self):
|
1111 |
+
layer_past = None if layers_past is None else layers_past[block_idx]
|
1112 |
+
block_idx += self.layer_offset
|
1113 |
+
if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
|
1114 |
+
# shape: (batch_size, seq_len, d_model)
|
1115 |
+
x, cache = self._activation_checkpoint_fn( # type: ignore
|
1116 |
+
block,
|
1117 |
+
x,
|
1118 |
+
attention_bias=attention_bias,
|
1119 |
+
layer_past=layer_past,
|
1120 |
+
use_cache=use_cache,
|
1121 |
+
max_doc_len=max_doc_len,
|
1122 |
+
cu_doc_lens=cu_doc_lens,
|
1123 |
+
)
|
1124 |
+
else:
|
1125 |
+
# shape: (batch_size, seq_len, d_model)
|
1126 |
+
x, cache = block(
|
1127 |
+
x,
|
1128 |
+
attention_bias=attention_bias,
|
1129 |
+
layer_past=layer_past,
|
1130 |
+
use_cache=use_cache,
|
1131 |
+
max_doc_len=max_doc_len,
|
1132 |
+
cu_doc_lens=cu_doc_lens,
|
1133 |
+
)
|
1134 |
+
if attn_key_values is not None:
|
1135 |
+
assert cache is not None
|
1136 |
+
attn_key_values.append(cache)
|
1137 |
+
return x, attn_key_values
|
1138 |
+
|
1139 |
+
def reset_parameters(self):
|
1140 |
+
for block in self:
|
1141 |
+
block.reset_parameters()
|
1142 |
+
|
1143 |
+
def set_activation_checkpointing(
|
1144 |
+
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
|
1145 |
+
):
|
1146 |
+
self.activation_checkpointing_strategy = strategy
|
1147 |
+
for block in self:
|
1148 |
+
block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
|
1149 |
+
|
1150 |
+
|
1151 |
+
class OLMo(nn.Module):
|
1152 |
+
def __init__(self, config: ModelConfig, init_params: bool = True):
|
1153 |
+
super().__init__()
|
1154 |
+
self.config = config
|
1155 |
+
self.__cache = BufferCache()
|
1156 |
+
|
1157 |
+
# Validate config.
|
1158 |
+
if self.config.alibi and self.config.flash_attention:
|
1159 |
+
raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
|
1160 |
+
|
1161 |
+
if self.config.alibi and self.config.rope:
|
1162 |
+
raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
|
1163 |
+
|
1164 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
1165 |
+
if self.config.embedding_size < self.config.vocab_size:
|
1166 |
+
raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
|
1167 |
+
elif self.config.embedding_size % 128 != 0:
|
1168 |
+
import warnings
|
1169 |
+
|
1170 |
+
warnings.warn(
|
1171 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
1172 |
+
)
|
1173 |
+
|
1174 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
1175 |
+
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
|
1176 |
+
|
1177 |
+
if not (
|
1178 |
+
0 < self.config.block_group_size <= self.config.n_layers
|
1179 |
+
and self.config.n_layers % self.config.block_group_size == 0
|
1180 |
+
):
|
1181 |
+
raise OLMoConfigurationError("n layers must be divisible by block group size")
|
1182 |
+
|
1183 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
1184 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
1185 |
+
|
1186 |
+
self.transformer = nn.ModuleDict(
|
1187 |
+
dict(
|
1188 |
+
wte=nn.Embedding(
|
1189 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
1190 |
+
),
|
1191 |
+
emb_drop=Dropout(config.embedding_dropout),
|
1192 |
+
ln_f=LayerNorm.build(config),
|
1193 |
+
)
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
1197 |
+
if self.config.block_group_size > 1:
|
1198 |
+
block_groups = [
|
1199 |
+
OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
|
1200 |
+
for i in range(0, config.n_layers, config.block_group_size)
|
1201 |
+
]
|
1202 |
+
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
|
1203 |
+
else:
|
1204 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
1205 |
+
|
1206 |
+
if not (self.config.alibi or self.config.rope):
|
1207 |
+
self.transformer.update(
|
1208 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
1209 |
+
)
|
1210 |
+
if not config.weight_tying:
|
1211 |
+
self.transformer.update(
|
1212 |
+
{
|
1213 |
+
"ff_out": nn.Linear(
|
1214 |
+
config.d_model,
|
1215 |
+
config.embedding_size or config.vocab_size,
|
1216 |
+
bias=config.include_bias,
|
1217 |
+
device=config.init_device,
|
1218 |
+
)
|
1219 |
+
}
|
1220 |
+
)
|
1221 |
+
if config.embedding_layer_norm:
|
1222 |
+
self.transformer.update({"emb_norm": LayerNorm.build(config)})
|
1223 |
+
|
1224 |
+
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
1225 |
+
if init_params and self.config.init_device != "meta":
|
1226 |
+
self.reset_parameters()
|
1227 |
+
self.__num_fwd_flops: Optional[int] = None
|
1228 |
+
self.__num_bck_flops: Optional[int] = None
|
1229 |
+
|
1230 |
+
# Warm up cache.
|
1231 |
+
if self.config.alibi:
|
1232 |
+
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
|
1233 |
+
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
|
1234 |
+
|
1235 |
+
def set_activation_checkpointing(
|
1236 |
+
self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
|
1237 |
+
):
|
1238 |
+
self.activation_checkpointing_strategy = strategy
|
1239 |
+
if self.config.block_group_size != 1:
|
1240 |
+
for block_group in self.transformer.block_groups:
|
1241 |
+
block_group.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
|
1242 |
+
else:
|
1243 |
+
for block in self.transformer.blocks:
|
1244 |
+
block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
|
1245 |
+
|
1246 |
+
@property
|
1247 |
+
def device(self) -> torch.device:
|
1248 |
+
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
1249 |
+
if device.type == "meta":
|
1250 |
+
return _non_meta_init_device(self.config)
|
1251 |
+
else:
|
1252 |
+
return device
|
1253 |
+
|
1254 |
+
def reset_parameters(self):
|
1255 |
+
log.info("Initializing model parameters...")
|
1256 |
+
# Top-level embeddings / linear layers.
|
1257 |
+
|
1258 |
+
if self.config.init_fn == InitFnType.normal:
|
1259 |
+
# Note: We may potentially want to multiply the std by a factor of sqrt(d) in case of `scale_logits`
|
1260 |
+
# and `weight_tying`. However, we are currently not using either, and may need to rethink the init logic
|
1261 |
+
# if/when we do want it.
|
1262 |
+
wte_std = self.config.emb_init_std or self.config.init_std
|
1263 |
+
wte_cutoff_factor = self.config.init_cutoff_factor
|
1264 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
1265 |
+
wte_std = self.config.emb_init_std or 1.0 / math.sqrt(self.config.d_model)
|
1266 |
+
wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1267 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
1268 |
+
wte_std = self.config.init_std
|
1269 |
+
if self.config.emb_init_std is not None:
|
1270 |
+
wte_std = self.config.emb_init_std
|
1271 |
+
elif self.config.scale_emb_init:
|
1272 |
+
wte_std *= math.sqrt(self.config.d_model)
|
1273 |
+
wte_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1274 |
+
else:
|
1275 |
+
raise NotImplementedError(self.config.init_fn)
|
1276 |
+
|
1277 |
+
init_normal(self.transformer.wte, std=wte_std, init_cutoff_factor=wte_cutoff_factor)
|
1278 |
+
|
1279 |
+
if hasattr(self.transformer, "wpe"):
|
1280 |
+
if self.config.init_fn == InitFnType.normal:
|
1281 |
+
wpe_std = self.config.init_std
|
1282 |
+
wpe_cutoff_factor = self.config.init_cutoff_factor
|
1283 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
1284 |
+
wpe_std = 1 / math.sqrt(self.config.d_model)
|
1285 |
+
wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1286 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
1287 |
+
wpe_std = self.config.init_std
|
1288 |
+
wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1289 |
+
else:
|
1290 |
+
raise NotImplementedError(self.config.init_fn)
|
1291 |
+
|
1292 |
+
init_normal(self.transformer.wpe, std=wpe_std, init_cutoff_factor=wpe_cutoff_factor)
|
1293 |
+
|
1294 |
+
# Top-level layer norm.
|
1295 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
1296 |
+
|
1297 |
+
# Output weights.
|
1298 |
+
if hasattr(self.transformer, "ff_out"):
|
1299 |
+
if self.config.init_fn == InitFnType.normal:
|
1300 |
+
ff_out_std = self.config.init_std
|
1301 |
+
ff_out_cutoff_factor = self.config.init_cutoff_factor
|
1302 |
+
elif self.config.init_fn == InitFnType.mitchell:
|
1303 |
+
ff_out_std = 1 / math.sqrt(self.config.d_model)
|
1304 |
+
ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1305 |
+
elif self.config.init_fn == InitFnType.full_megatron:
|
1306 |
+
ff_out_std = 1 / math.sqrt(self.config.d_model)
|
1307 |
+
ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0
|
1308 |
+
else:
|
1309 |
+
raise NotImplementedError(self.config.init_fn)
|
1310 |
+
|
1311 |
+
init_normal(self.transformer.ff_out, ff_out_std, ff_out_cutoff_factor)
|
1312 |
+
|
1313 |
+
# Let the blocks handle themselves.
|
1314 |
+
if self.config.block_group_size == 1:
|
1315 |
+
for block in self.transformer.blocks:
|
1316 |
+
block.reset_parameters()
|
1317 |
+
else:
|
1318 |
+
for block_group in self.transformer.block_groups:
|
1319 |
+
block_group.reset_parameters()
|
1320 |
+
|
1321 |
+
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
1322 |
+
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
|
1323 |
+
-1
|
1324 |
+
] >= seq_len:
|
1325 |
+
if alibi_bias.device != device:
|
1326 |
+
alibi_bias = alibi_bias.to(device)
|
1327 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
1328 |
+
return alibi_bias
|
1329 |
+
with torch.autocast(device.type, enabled=False):
|
1330 |
+
alibi_bias = alibi_attention_bias(seq_len, self.config, device)
|
1331 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
1332 |
+
return alibi_bias
|
1333 |
+
|
1334 |
+
def forward(
|
1335 |
+
self,
|
1336 |
+
input_ids: torch.LongTensor,
|
1337 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
1338 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1339 |
+
attention_bias: Optional[torch.Tensor] = None,
|
1340 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
1341 |
+
use_cache: bool = False,
|
1342 |
+
last_logits_only: bool = False,
|
1343 |
+
output_hidden_states: Optional[bool] = None,
|
1344 |
+
doc_lens: Optional[torch.Tensor] = None,
|
1345 |
+
max_doc_lens: Optional[Sequence[int]] = None,
|
1346 |
+
) -> OLMoOutput:
|
1347 |
+
"""
|
1348 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
1349 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
1350 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
1351 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
1352 |
+
which input IDs are masked. A `1` value in the mask means that
|
1353 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
1354 |
+
that the corresponding input ID is masked.
|
1355 |
+
|
1356 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
1357 |
+
library.
|
1358 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
1359 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
1360 |
+
to introduce causal or other biases.
|
1361 |
+
|
1362 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
1363 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
1364 |
+
element in the sequence.
|
1365 |
+
|
1366 |
+
If the tensor is a float tensor, it will just be added to the attention
|
1367 |
+
scores before the softmax.
|
1368 |
+
|
1369 |
+
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
|
1370 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
1371 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
1372 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
1373 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
1374 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
1375 |
+
This can speed up decoding when you only care about the next token.
|
1376 |
+
:param doc_lens: Document lengths to use in attention for intra-document masking.
|
1377 |
+
Shape `(batch_size, max_docs)`.
|
1378 |
+
:param max_doc_lens: Maximum document length for each instance in the batch.
|
1379 |
+
"""
|
1380 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
1381 |
+
|
1382 |
+
if past_key_values:
|
1383 |
+
assert len(past_key_values) == self.config.n_layers
|
1384 |
+
|
1385 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
1386 |
+
if past_key_values is None:
|
1387 |
+
past_length = 0
|
1388 |
+
else:
|
1389 |
+
past_length = past_key_values[0][0].size(-2)
|
1390 |
+
|
1391 |
+
max_doc_len: Optional[int] = None
|
1392 |
+
cu_doc_lens: Optional[torch.Tensor] = None
|
1393 |
+
if doc_lens is not None and max_doc_lens is not None:
|
1394 |
+
max_doc_len = max(max_doc_lens)
|
1395 |
+
cu_doc_lens = get_cumulative_document_lengths(doc_lens)
|
1396 |
+
|
1397 |
+
# Get embeddings of input.
|
1398 |
+
# shape: (batch_size, seq_len, d_model)
|
1399 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
1400 |
+
|
1401 |
+
# Apply embedding layer norm.
|
1402 |
+
if self.config.embedding_layer_norm:
|
1403 |
+
x = self.transformer.emb_norm(x)
|
1404 |
+
|
1405 |
+
if not (self.config.alibi or self.config.rope):
|
1406 |
+
# Get positional embeddings.
|
1407 |
+
# shape: (1, seq_len)
|
1408 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
1409 |
+
# shape: (1, seq_len, d_model)
|
1410 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
1411 |
+
x = pos_emb + x
|
1412 |
+
|
1413 |
+
# Apply dropout.
|
1414 |
+
# shape: (batch_size, seq_len, d_model)
|
1415 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
1416 |
+
|
1417 |
+
# Transform the attention mask into what the blocks expect.
|
1418 |
+
if attention_mask is not None:
|
1419 |
+
# shape: (batch_size, 1, 1, seq_len)
|
1420 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
1421 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
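# Illustrative example (not executed): a mask row [1, 1, 0] becomes [0.0, 0.0, min_float], so masked
# positions receive a very large negative additive bias and effectively zero attention weight after the
# softmax.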
|
1422 |
+
|
1423 |
+
# Merge attention mask with attention bias.
|
1424 |
+
if (
|
1425 |
+
attention_bias is not None
|
1426 |
+
or attention_mask is not None
|
1427 |
+
or self.config.alibi
|
1428 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
1429 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
1430 |
+
# scores correctly.
|
1431 |
+
or past_key_values is not None
|
1432 |
+
):
|
1433 |
+
if attention_bias is None and self.config.alibi:
|
1434 |
+
attention_bias = get_causal_attention_bias(
|
1435 |
+
self.__cache, past_length + seq_len, x.device
|
1436 |
+
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
|
1437 |
+
elif attention_bias is None:
|
1438 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
1439 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
1440 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
1441 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
1442 |
+
|
1443 |
+
# Transform to the right shape and data type.
|
1444 |
+
mask_len = seq_len
|
1445 |
+
if attention_mask is not None:
|
1446 |
+
mask_len = attention_mask.shape[-1]
|
1447 |
+
elif past_key_values is not None:
|
1448 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
1449 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
1450 |
+
|
1451 |
+
# Add in the masking bias.
|
1452 |
+
if attention_mask is not None:
|
1453 |
+
attention_bias = attention_bias + attention_mask
|
1454 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
1455 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
1456 |
+
# it can produce NaNs.
|
1457 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
1458 |
+
|
1459 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
1460 |
+
|
1461 |
+
# decoder layers
|
1462 |
+
all_hidden_states = []
|
1463 |
+
|
1464 |
+
# Apply blocks one-by-one.
|
1465 |
+
if self.config.block_group_size == 1:
|
1466 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
1467 |
+
if output_hidden_states:
|
1468 |
+
# add hidden states
|
1469 |
+
all_hidden_states.append(x)
|
1470 |
+
|
1471 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
1472 |
+
if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
|
1473 |
+
# shape: (batch_size, seq_len, d_model)
|
1474 |
+
x, cache = self._activation_checkpoint_fn(
|
1475 |
+
block,
|
1476 |
+
x,
|
1477 |
+
attention_bias=attention_bias,
|
1478 |
+
layer_past=layer_past,
|
1479 |
+
use_cache=use_cache,
|
1480 |
+
max_doc_len=max_doc_len,
|
1481 |
+
cu_doc_lens=cu_doc_lens,
|
1482 |
+
)
|
1483 |
+
else:
|
1484 |
+
# shape: (batch_size, seq_len, d_model)
|
1485 |
+
x, cache = block(
|
1486 |
+
x,
|
1487 |
+
attention_bias=attention_bias,
|
1488 |
+
layer_past=layer_past,
|
1489 |
+
use_cache=use_cache,
|
1490 |
+
max_doc_len=max_doc_len,
|
1491 |
+
cu_doc_lens=cu_doc_lens,
|
1492 |
+
)
|
1493 |
+
|
1494 |
+
if attn_key_values is not None:
|
1495 |
+
assert cache is not None
|
1496 |
+
attn_key_values.append(cache)
|
1497 |
+
else:
|
1498 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
1499 |
+
if output_hidden_states:
|
1500 |
+
# add hidden states
|
1501 |
+
all_hidden_states.append(x)
|
1502 |
+
|
1503 |
+
layers_past = (
|
1504 |
+
None
|
1505 |
+
if past_key_values is None
|
1506 |
+
else past_key_values[
|
1507 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
1508 |
+
]
|
1509 |
+
)
|
1510 |
+
x, cache = block_group(
|
1511 |
+
x,
|
1512 |
+
attention_bias=attention_bias,
|
1513 |
+
layers_past=layers_past,
|
1514 |
+
use_cache=use_cache,
|
1515 |
+
max_doc_len=max_doc_len,
|
1516 |
+
cu_doc_lens=cu_doc_lens,
|
1517 |
+
)
|
1518 |
+
if attn_key_values is not None:
|
1519 |
+
assert cache is not None
|
1520 |
+
attn_key_values.extend(cache)
|
1521 |
+
|
1522 |
+
if last_logits_only:
|
1523 |
+
# shape: (batch_size, 1, d_model)
|
1524 |
+
x = x[:, -1, :].unsqueeze(1)
|
1525 |
+
|
1526 |
+
# Apply final layer norm.
|
1527 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
1528 |
+
x = self.transformer.ln_f(x) # type: ignore
|
1529 |
+
if output_hidden_states:
|
1530 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
1531 |
+
all_hidden_states.append(x)
|
1532 |
+
|
1533 |
+
# Get logits.
|
1534 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
1535 |
+
if self.config.weight_tying:
|
1536 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
1537 |
+
else:
|
1538 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
1539 |
+
if self.config.scale_logits:
|
1540 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
1541 |
+
|
1542 |
+
return OLMoOutput(
|
1543 |
+
logits=logits,
|
1544 |
+
attn_key_values=attn_key_values,
|
1545 |
+
hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
|
1546 |
+
)
|
1547 |
+
|
1548 |
+
    def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
        if wrap_strategy is None:
            return None

        # The 'recurse' mode for the wrap function does not behave like you'd expect.
        # Even if we return False, it may still recurse because PyTorch does what it wants,
        # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
        # but not other linear layers within a block.
        # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
        # return True in 'recurse' mode for simplicity.
        size_based_module_to_wrap = {self.transformer.wte}
        if hasattr(self.transformer, "ff_out"):
            size_based_module_to_wrap.add(self.transformer.ff_out)

        if wrap_strategy == FSDPWrapStrategy.by_block:

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                wrap = isinstance(module, OLMoBlock)
                if recurse:
                    return True
                else:
                    return wrap

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
                if recurse:
                    return True
                else:
                    return wrap

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.by_block_group:
            if self.config.block_group_size <= 1:
                raise OLMoConfigurationError(
                    "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
                )

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                wrap = isinstance(module, OLMoBlockGroup)
                if recurse:
                    return True
                else:
                    return wrap

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
            if self.config.block_group_size <= 1:
                raise OLMoConfigurationError(
                    "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
                )

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
                if recurse:
                    return True
                else:
                    return wrap

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.size_based:
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

            return size_based_auto_wrap_policy
        elif wrap_strategy in {
            FSDPWrapStrategy.one_in_two,
            FSDPWrapStrategy.one_in_three,
            FSDPWrapStrategy.one_in_four,
            FSDPWrapStrategy.one_in_five,
        }:
            c = {
                FSDPWrapStrategy.one_in_two: 2,
                FSDPWrapStrategy.one_in_three: 3,
                FSDPWrapStrategy.one_in_four: 4,
                FSDPWrapStrategy.one_in_five: 5,
            }[wrap_strategy]

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
                if recurse:
                    return True
                else:
                    return wrap

            return fsdp_wrap_fn
        else:
            raise NotImplementedError(wrap_strategy)
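The wrap functions above are meant to be handed to PyTorch FSDP as an auto-wrap policy. A minimal sketch of that hand-off (illustrative only, not part of the uploaded files; it assumes an already-constructed OLMo model and an initialized process group):

# Illustrative sketch: pass the policy returned by get_fsdp_wrap_policy() to FSDP.
# Assumes `model` is an OLMo instance and torch.distributed is already initialized.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

wrap_policy = model.get_fsdp_wrap_policy(FSDPWrapStrategy.by_block)
fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy)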
    def num_params(self, include_embedding: bool = True) -> int:
        """
        Get the total number of parameters.
        """
        params = (np for np in self.named_parameters())
        if not include_embedding:
            params = filter(  # type: ignore
                lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
                params,
            )
        return sum(p.numel() for _, p in params)

    @property
    def num_fwd_flops(self):
        if self.__num_fwd_flops:
            return self.__num_fwd_flops

        # embedding table is just a lookup in the forward pass
        n_params = self.num_params(include_embedding=False)
        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
        # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
        # this gets us FLOPs / token
        params_flops_per_token = 2 * n_params
        # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
        attn_flops_per_token = (
            self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length)
        )
        self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token
        return self.__num_fwd_flops

    @property
    def num_bck_flops(self):
        if self.__num_bck_flops:
            return self.__num_bck_flops

        n_params = self.num_params()
        params_flops_per_token = 4 * n_params
        attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length)
        self.__num_bck_flops = params_flops_per_token + attn_flops_per_token
        return self.__num_bck_flops
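To make the FLOPs-per-token accounting above concrete, here is a small worked example with made-up dimensions (illustrative only; the real values come from the model config):

# Worked example of the forward-FLOPs estimate above, with made-up numbers.
n_params = 1_000_000_000          # non-embedding parameters (assumed)
n_layers, d_model, max_seq_len = 32, 4096, 2048
params_flops_per_token = 2 * n_params                              # 2.0e9
attn_flops_per_token = n_layers * 2 * 2 * (d_model * max_seq_len)  # ~1.07e9
total_fwd_flops_per_token = params_flops_per_token + attn_flops_per_token  # ~3.07e9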
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        max_steps: int = 10,
        beam_size: int = 1,
        per_node_beam_size: Optional[int] = None,
        sampler: Optional[Sampler] = None,
        min_steps: Optional[int] = None,
        final_sequence_scorer: Optional[FinalSequenceScorer] = None,
        constraints: Optional[List[Constraint]] = None,
    ) -> OLMoGenerateOutput:
        """
        Generate token IDs using beam search.

        Note that by default ``beam_size`` is set to 1, which is greedy decoding.

        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        :param attention_mask: An optional tensor of shape `(batch_size, seq_len)`, the same
            as for the forward method.
        :param attention_bias: A tensor of shape
            `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
            the same as for the forward method except only one shape is accepted here.

        For an explanation of the other arguments, see :class:`BeamSearch`.
        """
        beam_search = BeamSearch(
            self.config.eos_token_id,
            max_steps=max_steps,
            beam_size=beam_size,
            per_node_beam_size=per_node_beam_size,
            sampler=sampler,
            min_steps=min_steps,
            final_sequence_scorer=final_sequence_scorer,
            constraints=constraints,
        )

        # Validate inputs.
        batch_size, seq_len = input_ids.shape
        if attention_mask is not None:
            assert attention_mask.shape == (batch_size, seq_len)
        if attention_bias is not None:
            assert len(attention_bias.shape) == 4
            assert attention_bias.shape[:2] == (batch_size, 1)
            assert (
                seq_len + beam_search.max_steps
                <= attention_bias.shape[2]
                == attention_bias.shape[3]
                <= self.config.max_sequence_length
            )

        tokens_generated = 0

        def flatten_past_key_values(
            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        ) -> Dict[str, torch.Tensor]:
            out = {}
            for i, (key, value) in enumerate(past_key_values):
                out[f"past_key_{i}"] = key
                out[f"past_value_{i}"] = value
            return out

        def unflatten_past_key_values(
            past_key_values: Dict[str, torch.Tensor],
        ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
            out = []
            for i in range(self.config.n_layers):
                past_key = past_key_values[f"past_key_{i}"]
                past_value = past_key_values[f"past_value_{i}"]
                out.append((past_key, past_value))
            return out

        def step(
            last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
        ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
            nonlocal tokens_generated

            attention_mask = state.get("attention_mask")
            attention_bias = state.get("attention_bias")

            if tokens_generated > 0:
                past_key_values = unflatten_past_key_values(state)
                input_ids = last_predictions.unsqueeze(1)
                if attention_mask is not None:
                    group_size = input_ids.shape[0]
                    attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
            else:
                past_key_values = None
                input_ids = state["input_ids"]

            tokens_generated += 1

            # Run forward pass of model to get logits, then normalize to get log probs.
            output = self(
                input_ids,
                attention_mask=attention_mask,
                attention_bias=attention_bias,
                past_key_values=past_key_values,
                use_cache=True,
                last_logits_only=True,
            )
            log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)

            # Create new state.
            state = flatten_past_key_values(output.attn_key_values)
            if attention_mask is not None:
                state["attention_mask"] = attention_mask
            if attention_bias is not None:
                state["attention_bias"] = attention_bias

            return log_probs, state

        initial_preds = input_ids.new_zeros((batch_size,))  # This is arbitrary, we won't use this.
        state: dict[str, torch.Tensor] = {"input_ids": input_ids}
        if attention_mask is not None:
            state["attention_mask"] = attention_mask
        if attention_bias is not None:
            state["attention_bias"] = attention_bias
        with torch.no_grad():
            token_ids, scores = beam_search.search(initial_preds, state, step)

        return OLMoGenerateOutput(
            token_ids=token_ids,  # type: ignore[arg-type]
            scores=scores,  # type: ignore[arg-type]
        )

    @classmethod
    def from_checkpoint(
        cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
    ) -> OLMo:
        """
        Load an OLMo model from a checkpoint.
        """
        from .util import resource_path

        # Guess checkpoint type.
        if checkpoint_type is None:
            try:
                if resource_path(checkpoint_dir, "model.pt").is_file():
                    checkpoint_type = CheckpointType.unsharded
                else:
                    checkpoint_type = CheckpointType.sharded
            except FileNotFoundError:
                checkpoint_type = CheckpointType.sharded

        # Load config.
        config_path = resource_path(checkpoint_dir, "config.yaml")
        model_config = ModelConfig.load(config_path, key="model", validate_paths=False)

        if checkpoint_type == CheckpointType.unsharded:
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            model_config.init_device = "cpu"
            model = OLMo(model_config)

            # Load state dict directly to target device.
            state_dict_path = resource_path(checkpoint_dir, "model.pt")
            state_dict = torch.load(state_dict_path, map_location="cpu")
            model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
            model = model.to(torch.device(device))
        else:
            train_config = TrainConfig.load(config_path)
            if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
                from olmo_core.distributed.checkpoint import (  # type: ignore
                    load_model_and_optim_state,
                )

                model_config.init_device = device
                model = OLMo(model_config)
                load_model_and_optim_state(checkpoint_dir, model)
            else:
                # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
                from .checkpoint import load_model_state

                # Initialize model on target device. In this case the state dict is loaded in-place
                # so it's not necessary to start on CPU if the target device is a GPU.
                model_config.init_device = device
                model = OLMo(model_config)

                # Load state dict in place.
                load_model_state(checkpoint_dir, model)

        return model.eval()

    def _make_state_dict_compatible(
        self, state_dict: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
        """
        Handles some cases where the state dict is valid yet may need to be transformed in order to
        be loaded.

        This modifies the state dict in-place and also returns it, along with a mapping of original key
        names to new key names in cases where the keys were simply renamed. That mapping can be used
        to make a corresponding optimizer state dict compatible as well.
        """
        import re
        from fnmatch import fnmatch

        new_keys_to_og_keys: Dict[str, str] = {}

        # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
        # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
        # fine without the prefixes. This also simplifies the other steps below.
        for key in list(state_dict.keys()):
            state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
            new_keys_to_og_keys[new_key] = key

        # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
        if self.config.block_type == BlockType.sequential:
            for key in list(state_dict.keys()):
                if fnmatch(key, "transformer.*.norm.weight"):
                    tensor = state_dict.pop(key)
                    state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
                    state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
                    del new_keys_to_og_keys[key]
                elif fnmatch(key, "transformer.*.norm.bias"):
                    tensor = state_dict.pop(key)
                    state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
                    state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
                    del new_keys_to_og_keys[key]

        # For loading a state dict that was saved with a different `block_group_size`.
        if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
            state_dict_block_group_size = len(
                [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
            )
        else:
            state_dict_block_group_size = 1
        if self.config.block_group_size != state_dict_block_group_size:
            log.info(
                f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
                f"group size {self.config.block_group_size}"
            )
            # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
            # and then (re-)group them into the right block sizes.
            if state_dict_block_group_size > 1:
                for key in list(state_dict.keys()):
                    if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
                        group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
                        block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
                        state_dict[
                            (
                                new_key := key.replace(
                                    f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
                                )
                            )
                        ] = state_dict.pop(key)
                        new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)

            if self.config.block_group_size > 1:
                # Group the state dict blocks into the right block size.
                for key in list(state_dict.keys()):
                    if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
                        block_idx = int(m.group(1))
                        group_idx, group_block_idx = (
                            block_idx // self.config.block_group_size,
                            block_idx % self.config.block_group_size,
                        )
                        state_dict[
                            (
                                new_key := key.replace(
                                    f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
                                )
                            )
                        ] = state_dict.pop(key)
                        new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)

        og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
        for new_key, og_key in new_keys_to_og_keys.items():
            og_keys_to_new[og_key].add(new_key)

        return state_dict, og_keys_to_new
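Putting the public entry points above together, a minimal usage sketch (illustrative only, not part of the uploaded files; the checkpoint directory and prompt token IDs are placeholders):

# Illustrative sketch: load an unsharded checkpoint and run greedy decoding.
# The checkpoint path and prompt tokens below are placeholders.
import torch

model = OLMo.from_checkpoint("path/to/checkpoint_dir", device="cpu")
input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)  # placeholder prompt
output = model.generate(input_ids, max_steps=20, beam_size=1)  # beam_size=1 is greedy
print(output.token_ids, output.scores)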
modeling_fan.py
ADDED
@@ -0,0 +1,271 @@
import logging
from dataclasses import fields
from typing import Callable, List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from .config import ActivationCheckpointingStrategy, ModelConfig
from .model import OLMo

from .configuration_olmo import OLMoConfig
from typing import (
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    cast,
)

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function
    """

    kwargs = {}
    for field in fields(ModelConfig):
        kwargs[field.name] = getattr(config, field.name)

    model_config = ModelConfig(**kwargs)

    # Handle flash attention settings
    if config._attn_implementation == "flash_attention_2":
        model_config.flash_attention = True
    elif config._attn_implementation in ("eager", "sdpa"):
        model_config.flash_attention = False
    else:
        raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")

    return model_config


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig
    base_model_prefix = "model"
    _no_split_modules = ["OLMoBlock"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    supports_gradient_checkpointing = True

    def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
        super().__init__(config)
        self._gradient_checkpointing_func: Optional[Callable] = None
        self._gradient_checkpointing = False

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            model_config.init_device = "cpu"
            self.model = OLMo(model_config, init_params=init_params)
        else:
            self.model = model

    @property
    def gradient_checkpointing(self) -> bool:
        return self._gradient_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, enabled: bool):
        if self._gradient_checkpointing == enabled:
            return

        # HF does not specify a way to pass checkpointing strategies, so we pick
        # whole layer as our strategy. We can make this configurable later if needed.
        checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
        self.model.set_activation_checkpointing(
            checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
        )
        self._gradient_checkpointing = enabled

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        # past_key_values: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[
            Cache
        ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in OLMo")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        logits = outputs.logits
        hidden_states = outputs.hidden_states

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values:
            # This is because we want the model to only process the last generated token.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
        return model_inputs

    # TODO: these are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        """
        This function is intentionally left as a no-op.

        Weight tying is handled as follows:
        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
          See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
          See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.

        Therefore, there is no need to explicitly tie the weights in this function.
        """
        pass

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.

        Note:
            This method differs from the base class implementation by resizing the `embedding_size` attribute of the
            model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
            is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
            embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.embedding_size = model_embeds.weight.shape[0]
        self.model.config.embedding_size = model_embeds.weight.shape[0]

        # Check if the embedding size is less than the vocab size
        if self.config.embedding_size < self.config.vocab_size:
            warning_message = (
                f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
                f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
                "size is less than or equal to the new token embedding size."
            )
            log.warning(warning_message)

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
# may not support the newest architectures we create.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
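Once this module is imported, the AutoModelForCausalLM registration above makes the wrapper available through the standard transformers entry points. A rough sketch (illustrative only, not part of the uploaded files; the repo ID is a placeholder and assumes a checkpoint that ships this remote code):

# Illustrative sketch: load the registered wrapper through the usual HF entry points.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/some-olmo-checkpoint", trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained("some-org/some-olmo-checkpoint", trust_remote_code=True)
inputs = tokenizer("Language modeling is", return_tensors="pt")
generated = hf_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0]))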
optim.py
ADDED
@@ -0,0 +1,1040 @@
import logging
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, replace
from math import cos, pi, sqrt
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.optimizer import Optimizer as OptimizerBase

from . import LayerNormBase
from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
from .torch_util import get_default_device, is_distributed

__all__ = [
    "Optimizer",
    "LionW",
    "AdamW",
    "Scheduler",
    "CosWithWarmup",
    "LinearWithWarmup",
    "InvSqrtWithWarmup",
    "MaxScheduler",
    "ConstantScheduler",
    "CosLinearEnvelope",
    "BoltOnWarmupScheduler",
    "build_optimizer",
    "build_scheduler",
]


log = logging.getLogger(__name__)


class Optimizer(OptimizerBase):
    def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._record_update_metrics = record_update_metrics
        self._collecting_metrics = False
        self._selective_updates = selective_updates

    def _clean_param_name(self, name: str) -> str:
        return name.replace("_fsdp_wrapped_module.", "")

    @torch.no_grad()
    def clip_grads_and_collect_metrics(
        self,
        global_step: int,
        collect_param_metrics: bool = True,
        process_group: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Clips gradients for every group that has the field `max_grad_norm`.
        At the same time collect metrics for each parameter and its gradient.
        """
        self._collecting_metrics = collect_param_metrics
        device = get_default_device() if device is None else device

        # NOTE (epwalsh): during distributed training we're making an assumption that the order of
        # the param groups and the params within each group are the same across all ranks.
        # This is justified since we initialize the parameter groups in every rank by iterating over
        # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which
        # provides a consistent order.
        # For each parameter (with a gradient) we'll collect:
        # - min, max, avg, norm of the param itself
        # - min, max, avg, norm of the param's gradient
        # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from
        #   `self.get_state_for_param()`.
        # Afterwards we'll reduce these all over all ranks.
        per_param_min_metrics: List[torch.Tensor] = []
        per_param_max_metrics: List[torch.Tensor] = []
        per_param_sum_metrics: List[torch.Tensor] = []
        per_param_norm_metrics: List[torch.Tensor] = []
        per_param_numel_metrics: List[torch.Tensor] = []

        per_param_min_metric_names: List[str] = []
        per_param_max_metric_names: List[str] = []
        per_param_avg_metric_names: List[str] = []
        per_param_norm_metric_names: List[str] = []

        dst_rank = 0
        if process_group is not None:
            dst_rank = dist.get_global_rank(process_group, 0)

        #######################################################################
        # part 1: collect metrics locally
        #######################################################################
        for group in self.param_groups:
            for name, p in zip(group["param_names"], group["params"]):
                name = self._clean_param_name(name)
                # Always need to collect the norm of gradients for clipping, even if we're not collecting
                # other metrics.
                tensors: List[Optional[torch.Tensor]] = [p.grad]
                prefixes: List[str] = [f"grad/{name}"]
                if collect_param_metrics:
                    state = self.get_state_for_param(p)
                    sorted_state_keys = sorted([k for k in state.keys()])
                    tensors.extend([p] + [state[key] for key in sorted_state_keys])
                    prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys])
                assert len(tensors) == len(prefixes)

                # Get min, max, avg, and norm for all `tensors` associated with the parameter.
                for x, prefix in zip(tensors, prefixes):
                    # grad or state tensors could be none for params that have their shards completely on
                    # other ranks.
                    if x is not None and x.numel() > 0:
                        if collect_param_metrics:
                            x_abs = x.abs()
                            per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
                            per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
                            per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
                            per_param_numel_metrics.append(
                                torch.tensor([x.numel()], device=device, dtype=torch.float32)
                            )
                        per_param_norm_metrics.append(
                            torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
                        )
                    else:
                        if collect_param_metrics:
                            per_param_min_metrics.append(
                                torch.tensor([float("inf")], device=device, dtype=torch.float32)
                            )
                            per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
                            per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
                            per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
                        per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
                    if collect_param_metrics:
                        per_param_min_metric_names.append(f"{prefix}.min")
                        per_param_max_metric_names.append(f"{prefix}.max")
                        per_param_avg_metric_names.append(f"{prefix}.avg")
                    per_param_norm_metric_names.append(f"{prefix}.norm")

        assert (
            len(per_param_min_metrics)
            == len(per_param_min_metric_names)
            == len(per_param_max_metrics)
            == len(per_param_max_metric_names)
            == len(per_param_sum_metrics)
            == len(per_param_numel_metrics)
            == len(per_param_avg_metric_names)
        )
        assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)

        def is_grad_norm_metric(metric_name: str) -> bool:
            return metric_name.startswith("grad/") and metric_name.endswith(".norm")

        #######################################################################
        # part 2: reduce metrics over ranks
        #######################################################################
        param_group_sharded = False
        for group in self.param_groups:
            param_group_sharded = param_group_sharded or group.get("sharded", False)

        total_grad_norm: torch.Tensor
        per_param_avg_metrics: List[torch.Tensor] = []
        if is_distributed() and param_group_sharded:
            # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
            # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks
            # get the right value for gradient norms so they can clip correctly.
            # Reduce mins.
            if per_param_min_metrics:
                all_mins = torch.cat(per_param_min_metrics).to(device)
                dist.reduce(all_mins, dst_rank, op=dist.ReduceOp.MIN, group=process_group)
                per_param_min_metrics = all_mins.split(1)
            # Reduce maxs.
            if per_param_max_metrics:
                all_maxs = torch.cat(per_param_max_metrics).to(device)
                dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
                per_param_max_metrics = all_maxs.split(1)
            # Reduce sums or just norms.
            all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
            if per_param_sum_metrics and per_param_numel_metrics:
                all_sums = torch.cat(per_param_sum_metrics).to(device)
                all_numels = torch.cat(per_param_numel_metrics).to(device)
                all_sums_norms_numels = torch.cat(
                    [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
                )
                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM, group=process_group)
                all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
                # Get averages.
                # NOTE: could get infs for non-rank0 processes but that's okay.
                per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
            else:
                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM, group=process_group)
            grad_norm_metric_mask = torch.tensor(
                [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
            )
            total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
            per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
        else:
            total_grad_norm = (
                torch.cat(
                    [
                        m
                        for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
                        if is_grad_norm_metric(n)
                    ]
                )
                ** 2.0
            ).sum() ** 0.5
            per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]

        assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)

        # Collect all metrics into a single dict.
        all_metrics: Dict[str, torch.Tensor] = {}
        if collect_param_metrics:
            for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics):
                all_metrics[metric_name] = metric.squeeze(0)
            for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics):
                all_metrics[metric_name] = metric.squeeze(0)
            for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics):
                all_metrics[metric_name] = metric.squeeze(0)

        for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
            all_metrics[metric_name] = metric.squeeze(0)
        all_metrics["total_grad_norm"] = total_grad_norm

        #######################################################################
        # part 3: clip grads
        #######################################################################
        num_grads_clipped = 0
        num_eligible_grads = 0
        for group in self.param_groups:
            if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
                num_clipped = self._do_adaptive_clipping(
                    group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
                )
            elif (max_norm := group.get("max_grad_norm")) is not None:
                num_clipped = self._do_global_fixed_clipping(
                    group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
                )
            else:
                # No clipping needed.
                continue
            num_eligible_grads += len(group["params"])
            if num_clipped is not None:
                num_grads_clipped += num_clipped

        if collect_param_metrics:
            if num_eligible_grads > 0:
                clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
            else:
                clipping_rate = torch.tensor(0.0, device="cpu")
            all_metrics["clipping_rate"] = clipping_rate

        # total_grad_norm is computed at all steps, even when collect_param_metrics is set to False
        return all_metrics

    @torch.no_grad()
    def _do_adaptive_clipping(
        self,
        group: Dict[str, Any],
        max_norm_ratio: float,
        global_step: int,
        all_metrics: Dict[str, torch.Tensor],
        collect_param_metrics: bool = True,
        device: Optional[torch.device] = None,
    ) -> Optional[int]:
        """
        Do adaptive gradient clipping on a param group.

        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
        """
        device = get_default_device() if device is None else device
        num_grads_clipped = 0
        # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
        # the gradient (a scalar), not to be confused with the exponential average of the gradient.
        # TODO (epwalsh): handle optimizers that don't have betas.
        beta1, beta2 = group["betas"]
        beta = max(beta1, beta2)
        for name, p in zip(group["param_names"], group["params"]):
            name = self._clean_param_name(name)
            grad_norm = all_metrics.get(f"grad/{name}.norm")
            if grad_norm is None:
                continue

            # Get or initialize the exponential average of grad norm.
            # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
            # even parameters for which the corresponding local shard is empty. This has the potential to
            # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
            # So we should consider changing how we do this at some point so that we don't add any state
            # to parameters for which the local shard is empty. That would probably add extra distributed
            # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
            state = self.state[p]
            grad_norm_exp_avg = state.get("grad_norm_exp_avg")
            if grad_norm_exp_avg is None:
                grad_norm_exp_avg = grad_norm.clone().to(device)
                # We don't want to add anything to `state` until `state` has been initialized, otherwise
                # this will crash some optimizers which rely on checking `len(state)`. The downside here
                # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
                if global_step > 1:
                    state["grad_norm_exp_avg"] = grad_norm_exp_avg

            max_allowed_norm = max_norm_ratio * grad_norm_exp_avg
            clip_coef = max_allowed_norm / (grad_norm + 1e-6)

            # Clip the gradients and update the exponential average.
            # Note that multiplying by the clamped coefficient is meaningless when it is
            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
            clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
            if p.grad is not None:
                # p.grad could be none for some ranks when using FSDP.
                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))

            # Update the exponential average of the norm of the gradient with the clipped norm of the gradient.
            grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta)
            # Alternative: update with the *unclipped* norm of the gradient.
            # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)

            if collect_param_metrics:
                # Can't avoid host-device sync here.
                if clip_coef_clamped < 1.0:
                    num_grads_clipped += 1
                all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
        return num_grads_clipped if collect_param_metrics else None

    @torch.no_grad()
    def _do_global_fixed_clipping(
        self,
        group: Dict[str, Any],
        max_norm: float,
        all_metrics: Dict[str, torch.Tensor],
        collect_param_metrics: bool = True,
        device: Optional[torch.device] = None,
    ) -> Optional[int]:
        """
        Do global fixed gradient clipping on a param group.

        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
        """
        device = get_default_device() if device is None else device
        total_grad_norm = all_metrics["total_grad_norm"]
        clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        num_grads_clipped: Optional[int] = None
        if collect_param_metrics:
            # Can't avoid host-device sync here.
            if clip_coef_clamped < 1.0:
                num_grads_clipped = len(group["params"])
        for p in group["params"]:
            # Clip the gradients.
            # Note that multiplying by the clamped coefficient is meaningless when it is
            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
            if p.grad is not None:
                # p.grad could be none for some ranks when using FSDP.
                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
        return num_grads_clipped

    def get_post_step_metrics(
        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
    ) -> Dict[str, torch.Tensor]:
        del module, process_group
        return {}

    def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
        del param
        return {}

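A toy sketch of how a training step might drive the clipping/metrics hook above, using the LionW subclass defined just below (illustrative only, not part of the uploaded files; param groups must carry "param_names", and here also "max_grad_norm" to trigger global fixed clipping):

# Illustrative sketch: single-process usage of clip_grads_and_collect_metrics() + step().
import torch
import torch.nn as nn

toy_model = nn.Linear(8, 2)
group = {
    "params": [p for _, p in toy_model.named_parameters()],
    "param_names": [n for n, _ in toy_model.named_parameters()],
    "max_grad_norm": 1.0,
}
opt = LionW([group], lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)
loss = toy_model(torch.randn(4, 8)).sum()
loss.backward()
metrics = opt.clip_grads_and_collect_metrics(global_step=1, device=torch.device("cpu"))
print(float(metrics["total_grad_norm"]))
opt.step()
opt.zero_grad()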
class LionW(Optimizer):
    """
    Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        record_update_metrics: bool = False,
        selective_updates: bool = False,
        device: Optional[torch.device] = None,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(
            params, defaults, record_update_metrics=record_update_metrics, selective_updates=selective_updates
        )
        for group in self.param_groups:
            group["initial_lr"] = group["lr"]
        self._update_total_dot_prod: Optional[torch.Tensor] = None
        self._update_total_norm: Optional[torch.Tensor] = None
        self._signed_update_total_norm: Optional[torch.Tensor] = None
        self._device: Optional[torch.device] = device

    def get_post_step_metrics(
        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
    ) -> Dict[str, torch.Tensor]:
        assert isinstance(
            module, FSDP
        ), "`get_post_step_metrics` expects module to be FSDP and will not work with other `distributed_strategy`."

        update_total_dot_prod = self._update_total_dot_prod
        update_total_norm = self._update_total_norm
        signed_update_total_norm = self._signed_update_total_norm
        if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
            return {}

        self._update_total_dot_prod = None
        self._update_total_norm = None
        self._signed_update_total_norm = None

        if is_distributed() and isinstance(module, FullyShardedDataParallel):
            # Reduce total dot prod and norms across all ranks.
            update_total_norm = update_total_norm**2.0
            signed_update_total_norm = signed_update_total_norm**2.0
            # Reduce all together to avoid multiple communication calls.
            all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
            # Only need the final result on rank0, since that's where we log from.
            dist.reduce(
                all_together,
                0 if process_group is None else dist.get_global_rank(process_group, 0),
                group=process_group,
            )
            update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
            update_total_norm = update_total_norm**0.5
            signed_update_total_norm = signed_update_total_norm**0.5

        update_cos_sim = update_total_dot_prod / torch.max(
            update_total_norm * signed_update_total_norm,
            torch.tensor(1e-8, device=get_default_device() if self._device is None else self._device),
        )
        return {"update_cos_sim": update_cos_sim}

    @torch.no_grad()
    def step(self, closure=None) -> None:
        if closure is not None:
            with torch.enable_grad():
                closure()

        update_total_dot_prod: Optional[torch.Tensor] = None
        update_norms: Optional[List[torch.Tensor]] = None
        signed_update_norms: Optional[List[torch.Tensor]] = None
        if self._collecting_metrics and self._record_update_metrics:
            update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
            update_norms = []
            signed_update_norms = []

        for group in self.param_groups:
            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue

                state = self.state[p]

                # Perform step weight decay
                mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
                p.data.mul_(1 - mask * (group["lr"] * group["weight_decay"]))

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]
                beta1, beta2 = group["betas"]

                # Weight update
                update = exp_avg * beta1 + grad * (1 - beta1)
                if isinstance(mask, torch.Tensor):
                    # When mask isn't a tensor it's just a literal `1` (python int), so there's
                    # no point in calling this op.
                    update.mul_(mask)
                signed_update = torch.sign(update)
                p.add_(signed_update, alpha=-group["lr"])

                # Decay the momentum running average coefficient
                exp_avg.mul_(1 - mask * (1 - beta2)).add_(grad, alpha=1 - beta2)

                # Track dot product and norms of update vs signed update in order to calculate
|
479 |
+
# their cosine similarity.
|
480 |
+
if (
|
481 |
+
update_total_dot_prod is not None
|
482 |
+
and update_norms is not None
|
483 |
+
and signed_update_norms is not None
|
484 |
+
):
|
485 |
+
update_total_dot_prod = update_total_dot_prod.to(update.device)
|
486 |
+
update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
|
487 |
+
update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
|
488 |
+
signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
|
489 |
+
|
490 |
+
# Compute cosine similarity between update and signed update.
|
491 |
+
if update_total_dot_prod is not None and update_norms is not None and signed_update_norms is not None:
|
492 |
+
device = get_default_device() if self._device is None else self._device
|
493 |
+
self._update_total_dot_prod = update_total_dot_prod.to(device)
|
494 |
+
self._update_total_norm = torch.linalg.vector_norm(
|
495 |
+
torch.stack(update_norms),
|
496 |
+
2.0,
|
497 |
+
dtype=torch.float32,
|
498 |
+
).to(device)
|
499 |
+
self._signed_update_total_norm = torch.linalg.vector_norm(
|
500 |
+
torch.stack(signed_update_norms),
|
501 |
+
2.0,
|
502 |
+
dtype=torch.float32,
|
503 |
+
).to(device)
|
504 |
+
|
505 |
+
|
506 |
+
class AdamW(torch.optim.AdamW, Optimizer):
|
507 |
+
def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
|
508 |
+
super().__init__(*args, **kwargs)
|
509 |
+
|
510 |
+
# Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
|
511 |
+
# won't be called.
|
512 |
+
self._record_update_metrics = record_update_metrics
|
513 |
+
self._collecting_metrics = False
|
514 |
+
self._selective_updates = selective_updates
|
515 |
+
|
516 |
+
self._step_size_param_names: Optional[List[str]] = None
|
517 |
+
self._step_size_norms: Optional[List[torch.Tensor]] = None
|
518 |
+
self._step_size_maxs: Optional[List[torch.Tensor]] = None
|
519 |
+
|
520 |
+
@torch.no_grad()
|
521 |
+
def step(self, closure=None) -> None:
|
522 |
+
if not (self._record_update_metrics and self._collecting_metrics) and not self._selective_updates:
|
523 |
+
return super().step(closure=closure)
|
524 |
+
|
525 |
+
device = get_default_device()
|
526 |
+
param_names = []
|
527 |
+
step_size_norms = []
|
528 |
+
step_size_maxs = []
|
529 |
+
for group in self.param_groups:
|
530 |
+
beta1, beta2 = group["betas"]
|
531 |
+
lr = group["lr"]
|
532 |
+
weight_decay = group["weight_decay"]
|
533 |
+
eps = group["eps"]
|
534 |
+
amsgrad = group["amsgrad"]
|
535 |
+
for name, param in zip(group["param_names"], group["params"]):
|
536 |
+
name = self._clean_param_name(name)
|
537 |
+
param_names.append(name)
|
538 |
+
grad = param.grad
|
539 |
+
if grad is None:
|
540 |
+
step_size_norms.append(torch.tensor([0.0], device=device))
|
541 |
+
step_size_maxs.append(torch.tensor([0.0], device=device))
|
542 |
+
continue
|
543 |
+
|
544 |
+
state = self.state[param]
|
545 |
+
# init state if needed
|
546 |
+
if len(state) == 0:
|
547 |
+
state["step"] = (
|
548 |
+
torch.zeros((), dtype=torch.float32, device=param.device)
|
549 |
+
if group["capturable"] or group["fused"]
|
550 |
+
else torch.tensor(0.0, dtype=torch.float32)
|
551 |
+
)
|
552 |
+
# Exponential moving average of gradient values
|
553 |
+
state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
|
554 |
+
# Exponential moving average of squared gradient values
|
555 |
+
state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
|
556 |
+
if amsgrad:
|
557 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
558 |
+
state["max_exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
|
559 |
+
|
560 |
+
exp_avg = state["exp_avg"]
|
561 |
+
exp_avg_sq = state["exp_avg_sq"]
|
562 |
+
step_t = state["step"]
|
563 |
+
|
564 |
+
# Update step.
|
565 |
+
step_t += 1
|
566 |
+
|
567 |
+
# Perform step weight decay.
|
568 |
+
mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
|
569 |
+
param.mul_(1 - mask * (lr * weight_decay))
|
570 |
+
|
571 |
+
# Decay the first and second moment running average coefficient.
|
572 |
+
exp_avg.lerp_(grad, mask * (1 - beta1))
|
573 |
+
exp_avg_sq.mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)
|
574 |
+
|
575 |
+
step = step_t.item()
|
576 |
+
|
577 |
+
bias_correction1 = 1 - beta1**step
|
578 |
+
bias_correction2 = 1 - beta2**step
|
579 |
+
|
580 |
+
step_size = lr / bias_correction1
|
581 |
+
|
582 |
+
bias_correction2_sqrt = sqrt(bias_correction2)
|
583 |
+
|
584 |
+
if amsgrad:
|
585 |
+
max_exp_avg_sq = state["max_exp_avg_sq"]
|
586 |
+
# Maintains the maximum of all 2nd moment running avg. till now
|
587 |
+
torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
588 |
+
|
589 |
+
# Use the max. for normalizing running avg. of gradient
|
590 |
+
denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
|
591 |
+
else:
|
592 |
+
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
|
593 |
+
|
594 |
+
update = -step_size * torch.div(exp_avg, denom)
|
595 |
+
if isinstance(mask, torch.Tensor):
|
596 |
+
# When mask isn't a tensor it's just a literal `1` (python int), so there's
|
597 |
+
# no point in calling this op.
|
598 |
+
update.mul_(mask)
|
599 |
+
param.add_(update)
|
600 |
+
step_size_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32).unsqueeze(0))
|
601 |
+
step_size_maxs.append(update.abs().max().unsqueeze(0))
|
602 |
+
|
603 |
+
self._step_size_param_names = param_names
|
604 |
+
self._step_size_norms = step_size_norms
|
605 |
+
self._step_size_maxs = step_size_maxs
|
606 |
+
|
607 |
+
def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
|
608 |
+
return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")} # type: ignore
|
609 |
+
|
610 |
+
def get_post_step_metrics(
|
611 |
+
self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
|
612 |
+
) -> Dict[str, torch.Tensor]:
|
613 |
+
if not (self._record_update_metrics and self._collecting_metrics):
|
614 |
+
return {}
|
615 |
+
else:
|
616 |
+
device = get_default_device()
|
617 |
+
dst_rank = 0
|
618 |
+
if process_group is not None:
|
619 |
+
dst_rank = dist.get_global_rank(process_group, 0)
|
620 |
+
param_names = self._step_size_param_names
|
621 |
+
step_size_norms = self._step_size_norms
|
622 |
+
step_size_maxs = self._step_size_maxs
|
623 |
+
assert param_names is not None
|
624 |
+
assert step_size_norms is not None
|
625 |
+
assert step_size_maxs is not None
|
626 |
+
|
627 |
+
# Reduce metrics if needed.
|
628 |
+
if is_distributed() and isinstance(module, FullyShardedDataParallel):
|
629 |
+
# Reduce norms.
|
630 |
+
all_norms = torch.cat(step_size_norms).to(device) ** 2.0
|
631 |
+
dist.reduce(all_norms, dst_rank, op=dist.ReduceOp.SUM, group=process_group)
|
632 |
+
step_size_norms = (all_norms ** (0.5)).squeeze(0).split(1)
|
633 |
+
|
634 |
+
# Reduce maxs.
|
635 |
+
all_maxs = torch.cat(step_size_maxs).to(device)
|
636 |
+
dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
|
637 |
+
step_size_maxs = all_maxs.split(1)
|
638 |
+
|
639 |
+
metrics = {}
|
640 |
+
for param_name, step_size_norm, step_size_max in zip(param_names, step_size_norms, step_size_maxs): # type: ignore[arg-type]
|
641 |
+
metrics[f"step/{param_name}.norm"] = step_size_norm.squeeze(0)
|
642 |
+
metrics[f"step/{param_name}.max"] = step_size_max.squeeze(0)
|
643 |
+
|
644 |
+
self._step_size_param_names = None
|
645 |
+
self._step_size_norms = None
|
646 |
+
self._step_size_maxs = None
|
647 |
+
return metrics
|
648 |
+
|
649 |
+
|
650 |
+
@dataclass
|
651 |
+
class Scheduler(metaclass=ABCMeta):
|
652 |
+
# NOTE: these fields are not given default values because otherwise dataclasses complains
|
653 |
+
# about how the scheduler subclasses are defined.
|
654 |
+
grad_clip_warmup_steps: Optional[int]
|
655 |
+
grad_clip_warmup_factor: Optional[float]
|
656 |
+
warmup_min_lr: Optional[float]
|
657 |
+
|
658 |
+
@abstractmethod
|
659 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
660 |
+
raise NotImplementedError
|
661 |
+
|
662 |
+
def _get_max_grad_norm_coeff(
|
663 |
+
self, initial_value: Optional[float], step: int, max_steps: int
|
664 |
+
) -> Optional[float]:
|
665 |
+
del max_steps # might need this in the future, but for now I just wanted to match the API of `get_lr()`.
|
666 |
+
if initial_value is None:
|
667 |
+
return None
|
668 |
+
elif (
|
669 |
+
self.grad_clip_warmup_steps is None
|
670 |
+
or self.grad_clip_warmup_factor is None
|
671 |
+
or step > self.grad_clip_warmup_steps
|
672 |
+
):
|
673 |
+
return initial_value
|
674 |
+
else:
|
675 |
+
return self.grad_clip_warmup_factor * initial_value
|
676 |
+
|
677 |
+
def get_max_grad_norm(
|
678 |
+
self, initial_max_grad_norm: Optional[float], step: int, max_steps: int
|
679 |
+
) -> Optional[float]:
|
680 |
+
return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps)
|
681 |
+
|
682 |
+
def get_max_grad_norm_ratio(
|
683 |
+
self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int
|
684 |
+
) -> Optional[float]:
|
685 |
+
return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps)
|
686 |
+
|
687 |
+
def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float:
|
688 |
+
warmup_min_lr = self.warmup_min_lr if self.warmup_min_lr is not None else initial_lr * 0.10
|
689 |
+
assert 0 <= warmup_min_lr < initial_lr
|
690 |
+
return warmup_min_lr + (initial_lr - warmup_min_lr) * min(step, warmup_steps) / warmup_steps
|
691 |
+
|
692 |
+
|
693 |
+
@dataclass
|
694 |
+
class CosWithWarmup(Scheduler):
|
695 |
+
warmup_steps: int
|
696 |
+
alpha_f: float = 0.1
|
697 |
+
t_max: Optional[int] = None
|
698 |
+
|
699 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
700 |
+
max_steps = max_steps if self.t_max is None else self.t_max
|
701 |
+
eta_min = initial_lr * self.alpha_f
|
702 |
+
if step < self.warmup_steps:
|
703 |
+
return self._linear_warmup(initial_lr, step, self.warmup_steps)
|
704 |
+
elif step >= max_steps:
|
705 |
+
return eta_min
|
706 |
+
else:
|
707 |
+
step = step - self.warmup_steps
|
708 |
+
max_steps = max_steps - self.warmup_steps
|
709 |
+
return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
|
710 |
+
|
711 |
+
|
712 |
+
@dataclass
|
713 |
+
class LinearWithWarmup(Scheduler):
|
714 |
+
warmup_steps: int
|
715 |
+
alpha_f: float = 0.1
|
716 |
+
t_max: Optional[int] = None
|
717 |
+
|
718 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
719 |
+
max_steps = max_steps if self.t_max is None else self.t_max
|
720 |
+
eta_min = initial_lr * self.alpha_f
|
721 |
+
if step < self.warmup_steps:
|
722 |
+
return self._linear_warmup(initial_lr, step, self.warmup_steps)
|
723 |
+
elif step >= max_steps:
|
724 |
+
return eta_min
|
725 |
+
else:
|
726 |
+
step = step - self.warmup_steps
|
727 |
+
max_steps = max_steps - self.warmup_steps
|
728 |
+
return initial_lr - (initial_lr - eta_min) * (step / max_steps)
|
729 |
+
|
730 |
+
|
731 |
+
@dataclass
|
732 |
+
class InvSqrtWithWarmup(Scheduler):
|
733 |
+
warmup_steps: int
|
734 |
+
|
735 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
736 |
+
if step < self.warmup_steps:
|
737 |
+
return self._linear_warmup(initial_lr, step, self.warmup_steps)
|
738 |
+
del max_steps
|
739 |
+
return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step))
|
740 |
+
|
741 |
+
|
742 |
+
@dataclass
|
743 |
+
class MaxScheduler(Scheduler):
|
744 |
+
sched1: Scheduler
|
745 |
+
sched2: Scheduler
|
746 |
+
|
747 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
748 |
+
return max(
|
749 |
+
self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps)
|
750 |
+
)
|
751 |
+
|
752 |
+
|
753 |
+
@dataclass
|
754 |
+
class BoltOnWarmupScheduler(Scheduler):
|
755 |
+
inner: Scheduler
|
756 |
+
warmup_start: int
|
757 |
+
warmup_end: int
|
758 |
+
|
759 |
+
@classmethod
|
760 |
+
def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler":
|
761 |
+
return cls(
|
762 |
+
grad_clip_warmup_steps=None,
|
763 |
+
grad_clip_warmup_factor=None,
|
764 |
+
inner=scheduler,
|
765 |
+
warmup_start=warmup_start,
|
766 |
+
warmup_end=warmup_end,
|
767 |
+
warmup_min_lr=None,
|
768 |
+
)
|
769 |
+
|
770 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
771 |
+
if step < self.warmup_start:
|
772 |
+
return 0.0
|
773 |
+
if step < self.warmup_end:
|
774 |
+
lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
|
775 |
+
return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
|
776 |
+
else:
|
777 |
+
return self.inner.get_lr(initial_lr, step, max_steps)
|
778 |
+
|
779 |
+
def _get_max_grad_norm_coeff(
|
780 |
+
self, initial_value: Optional[float], step: int, max_steps: int
|
781 |
+
) -> Optional[float]:
|
782 |
+
return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps)
|
783 |
+
|
784 |
+
|
785 |
+
@dataclass
|
786 |
+
class ConstantScheduler(Scheduler):
|
787 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
788 |
+
del step, max_steps
|
789 |
+
return initial_lr
|
790 |
+
|
791 |
+
|
792 |
+
@dataclass
|
793 |
+
class CosLinearEnvelope(Scheduler):
|
794 |
+
"Pointwise product of cosine schedule and linear decay; useful during annealing."
|
795 |
+
warmup_steps: int
|
796 |
+
alpha_f: float = 0.1
|
797 |
+
t_max: Optional[int] = None
|
798 |
+
|
799 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
800 |
+
max_steps = max_steps if self.t_max is None else self.t_max
|
801 |
+
eta_min = initial_lr * self.alpha_f
|
802 |
+
|
803 |
+
if step < self.warmup_steps:
|
804 |
+
return self._linear_warmup(initial_lr, step, self.warmup_steps)
|
805 |
+
if step >= max_steps:
|
806 |
+
return eta_min
|
807 |
+
else:
|
808 |
+
step = step - self.warmup_steps
|
809 |
+
max_steps = max_steps - self.warmup_steps
|
810 |
+
linear_envelope = 1 - (step / max_steps)
|
811 |
+
cosine_schedule = (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
|
812 |
+
return eta_min + linear_envelope * cosine_schedule
|
813 |
+
|
814 |
+
|
815 |
+
@dataclass
|
816 |
+
class ConstantWithWarmupScheduler(Scheduler):
|
817 |
+
warmup_steps: int
|
818 |
+
|
819 |
+
def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
|
820 |
+
if step < self.warmup_steps:
|
821 |
+
return self._linear_warmup(initial_lr, step, self.warmup_steps)
|
822 |
+
del max_steps
|
823 |
+
return initial_lr
|
824 |
+
|
825 |
+
|
826 |
+
PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
|
827 |
+
|
828 |
+
|
829 |
+
def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
|
830 |
+
"""
|
831 |
+
Separate parameters into weight decay and non weight decay groups.
|
832 |
+
"""
|
833 |
+
param_groups: List[Dict[str, Any]]
|
834 |
+
param_group_defaults = {
|
835 |
+
"sharded": isinstance(model, FullyShardedDataParallel),
|
836 |
+
"max_grad_norm": cfg.max_grad_norm,
|
837 |
+
"max_grad_norm_ratio": cfg.max_grad_norm_ratio,
|
838 |
+
}
|
839 |
+
|
840 |
+
# Separate out parameters that we don't want to apply weight decay to, like norms and biases.
|
841 |
+
decay = set()
|
842 |
+
no_decay = set()
|
843 |
+
all_params = {}
|
844 |
+
for mn, m in model.named_modules():
|
845 |
+
for pn, p in m.named_parameters():
|
846 |
+
# NOTE: because named_modules and named_parameters are recursive
|
847 |
+
# we will see the same tensors p many many times, but doing it this way
|
848 |
+
# allows us to know which parent module any tensor p belongs to...
|
849 |
+
if not p.requires_grad:
|
850 |
+
continue
|
851 |
+
|
852 |
+
fpn = f"{mn}.{pn}" if mn else pn
|
853 |
+
all_params[fpn] = p
|
854 |
+
|
855 |
+
if pn.endswith("bias"):
|
856 |
+
if cfg.optimizer.decay_norm_and_bias:
|
857 |
+
decay.add(fpn)
|
858 |
+
else:
|
859 |
+
no_decay.add(fpn)
|
860 |
+
elif pn.endswith("weight") and isinstance(m, nn.Linear):
|
861 |
+
decay.add(fpn)
|
862 |
+
elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)):
|
863 |
+
if cfg.optimizer.decay_norm_and_bias:
|
864 |
+
decay.add(fpn)
|
865 |
+
else:
|
866 |
+
no_decay.add(fpn)
|
867 |
+
elif pn.endswith("weight") and isinstance(m, nn.Embedding):
|
868 |
+
if cfg.optimizer.decay_embeddings:
|
869 |
+
decay.add(fpn)
|
870 |
+
else:
|
871 |
+
no_decay.add(fpn)
|
872 |
+
|
873 |
+
# Validate that we've considered every parameter
|
874 |
+
inter_params = decay & no_decay
|
875 |
+
union_params = decay | no_decay
|
876 |
+
assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
|
877 |
+
assert (
|
878 |
+
len(all_params.keys() - union_params) == 0
|
879 |
+
), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
|
880 |
+
|
881 |
+
# Create the pytorch optimizer groups.
|
882 |
+
decay_sorted = sorted(list(decay))
|
883 |
+
no_decay_sorted = sorted(list(no_decay))
|
884 |
+
param_groups = []
|
885 |
+
if len(decay_sorted) > 0:
|
886 |
+
param_groups.append(
|
887 |
+
{
|
888 |
+
"params": [all_params[pn] for pn in decay_sorted],
|
889 |
+
"param_names": decay_sorted,
|
890 |
+
**param_group_defaults,
|
891 |
+
}
|
892 |
+
)
|
893 |
+
if len(no_decay_sorted) > 0:
|
894 |
+
param_groups.append(
|
895 |
+
{
|
896 |
+
"params": [all_params[pn] for pn in no_decay_sorted],
|
897 |
+
"param_names": no_decay_sorted,
|
898 |
+
"weight_decay": 0.0,
|
899 |
+
**param_group_defaults,
|
900 |
+
}
|
901 |
+
)
|
902 |
+
|
903 |
+
# Validate fields.
|
904 |
+
for group in param_groups:
|
905 |
+
for key in PARAM_GROUP_FIELDS:
|
906 |
+
assert key in group
|
907 |
+
|
908 |
+
return param_groups
|
909 |
+
|
910 |
+
|
911 |
+
def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
912 |
+
"""
|
913 |
+
Make sure old optim state dicts are compatible with new versions.
|
914 |
+
"""
|
915 |
+
if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2:
|
916 |
+
assert optimizer.param_groups[1]["weight_decay"] == 0.0
|
917 |
+
|
918 |
+
# Decay
|
919 |
+
decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
|
920 |
+
decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"]
|
921 |
+
|
922 |
+
# No decay.
|
923 |
+
no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
|
924 |
+
no_decay_param_group["weight_decay"] = 0.0
|
925 |
+
no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"]
|
926 |
+
|
927 |
+
state_dict["param_groups"] = [decay_param_group, no_decay_param_group]
|
928 |
+
|
929 |
+
assert len(optimizer.param_groups) == len(state_dict["param_groups"])
|
930 |
+
|
931 |
+
# Make sure:
|
932 |
+
# - All required fields are included in the state dict,
|
933 |
+
# - And that the values of those fields doesn't change from what's currently set in the optimizer,
|
934 |
+
# since we might have changed those fields on purpose after a restart.
|
935 |
+
for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]):
|
936 |
+
for key in PARAM_GROUP_FIELDS:
|
937 |
+
sd_group[key] = group[key]
|
938 |
+
|
939 |
+
return state_dict
|
940 |
+
|
941 |
+
|
942 |
+
def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
|
943 |
+
param_groups = get_param_groups(cfg, model)
|
944 |
+
log.info(f"Constructing optimizer with {len(param_groups)} param groups")
|
945 |
+
if cfg.optimizer.name == OptimizerType.lionw:
|
946 |
+
return LionW(
|
947 |
+
param_groups,
|
948 |
+
lr=cfg.optimizer.learning_rate,
|
949 |
+
betas=cfg.optimizer.betas,
|
950 |
+
weight_decay=cfg.optimizer.weight_decay,
|
951 |
+
record_update_metrics=cfg.optimizer.record_update_metrics,
|
952 |
+
selective_updates=cfg.optimizer.selective_updates,
|
953 |
+
)
|
954 |
+
elif cfg.optimizer.name == OptimizerType.adamw:
|
955 |
+
return AdamW(
|
956 |
+
param_groups,
|
957 |
+
lr=cfg.optimizer.learning_rate,
|
958 |
+
betas=cfg.optimizer.betas,
|
959 |
+
weight_decay=cfg.optimizer.weight_decay,
|
960 |
+
record_update_metrics=cfg.optimizer.record_update_metrics,
|
961 |
+
selective_updates=cfg.optimizer.selective_updates,
|
962 |
+
eps=cfg.optimizer.eps,
|
963 |
+
)
|
964 |
+
else:
|
965 |
+
raise NotImplementedError
|
966 |
+
|
967 |
+
|
968 |
+
def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler:
|
969 |
+
sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
|
970 |
+
if sched_cfg.name == SchedulerType.cosine_with_warmup:
|
971 |
+
return CosWithWarmup(
|
972 |
+
grad_clip_warmup_steps=(
|
973 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
974 |
+
),
|
975 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
976 |
+
warmup_steps=int(sched_cfg.t_warmup),
|
977 |
+
alpha_f=sched_cfg.alpha_f,
|
978 |
+
t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
|
979 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
980 |
+
)
|
981 |
+
elif sched_cfg.name == SchedulerType.linear_with_warmup:
|
982 |
+
return LinearWithWarmup(
|
983 |
+
grad_clip_warmup_steps=(
|
984 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
985 |
+
),
|
986 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
987 |
+
warmup_steps=int(sched_cfg.t_warmup),
|
988 |
+
alpha_f=sched_cfg.alpha_f,
|
989 |
+
t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
|
990 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
991 |
+
)
|
992 |
+
elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
|
993 |
+
return InvSqrtWithWarmup(
|
994 |
+
grad_clip_warmup_steps=(
|
995 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
996 |
+
),
|
997 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
998 |
+
warmup_steps=int(sched_cfg.t_warmup),
|
999 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
1000 |
+
)
|
1001 |
+
elif sched_cfg.name == SchedulerType.max_scheduler:
|
1002 |
+
return MaxScheduler(
|
1003 |
+
grad_clip_warmup_steps=(
|
1004 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
1005 |
+
),
|
1006 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
1007 |
+
sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
|
1008 |
+
sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
|
1009 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
1010 |
+
)
|
1011 |
+
elif sched_cfg.name == SchedulerType.constant:
|
1012 |
+
return ConstantScheduler(
|
1013 |
+
grad_clip_warmup_steps=(
|
1014 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
1015 |
+
),
|
1016 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
1017 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
1018 |
+
)
|
1019 |
+
elif sched_cfg.name == SchedulerType.cosine_linear_envelope:
|
1020 |
+
return CosLinearEnvelope(
|
1021 |
+
grad_clip_warmup_steps=(
|
1022 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
1023 |
+
),
|
1024 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
1025 |
+
warmup_steps=int(sched_cfg.t_warmup),
|
1026 |
+
alpha_f=sched_cfg.alpha_f,
|
1027 |
+
t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
|
1028 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
1029 |
+
)
|
1030 |
+
elif sched_cfg.name == SchedulerType.constant_with_warmup:
|
1031 |
+
return ConstantWithWarmupScheduler(
|
1032 |
+
grad_clip_warmup_steps=(
|
1033 |
+
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
|
1034 |
+
),
|
1035 |
+
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
|
1036 |
+
warmup_min_lr=sched_cfg.warmup_min_lr,
|
1037 |
+
warmup_steps=int(sched_cfg.t_warmup),
|
1038 |
+
)
|
1039 |
+
else:
|
1040 |
+
raise NotImplementedError
|
safetensors_util.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import pickle
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Dict, Optional, Tuple
|
5 |
+
|
6 |
+
import safetensors.torch
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from olmo.aliases import PathOrStr
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
"state_dict_to_safetensors_file",
|
13 |
+
"safetensors_file_to_state_dict",
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass(eq=True, frozen=True)
|
18 |
+
class STKey:
|
19 |
+
keys: Tuple
|
20 |
+
value_is_pickled: bool
|
21 |
+
|
22 |
+
|
23 |
+
def encode_key(key: STKey) -> str:
|
24 |
+
b = pickle.dumps((key.keys, key.value_is_pickled))
|
25 |
+
b = base64.urlsafe_b64encode(b)
|
26 |
+
return str(b, "ASCII")
|
27 |
+
|
28 |
+
|
29 |
+
def decode_key(key: str) -> STKey:
|
30 |
+
b = base64.urlsafe_b64decode(key)
|
31 |
+
keys, value_is_pickled = pickle.loads(b)
|
32 |
+
return STKey(keys, value_is_pickled)
|
33 |
+
|
34 |
+
|
35 |
+
def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
|
36 |
+
result = {}
|
37 |
+
for key, value in d.items():
|
38 |
+
if isinstance(value, torch.Tensor):
|
39 |
+
result[STKey((key,), False)] = value
|
40 |
+
elif isinstance(value, dict):
|
41 |
+
value = flatten_dict(value)
|
42 |
+
for inner_key, inner_value in value.items():
|
43 |
+
result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
|
44 |
+
else:
|
45 |
+
pickled = bytearray(pickle.dumps(value))
|
46 |
+
pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
|
47 |
+
result[STKey((key,), True)] = pickled_tensor
|
48 |
+
return result
|
49 |
+
|
50 |
+
|
51 |
+
def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
|
52 |
+
result: Dict = {}
|
53 |
+
|
54 |
+
for key, value in d.items():
|
55 |
+
if key.value_is_pickled:
|
56 |
+
value = pickle.loads(value.numpy().data)
|
57 |
+
|
58 |
+
target_dict = result
|
59 |
+
for k in key.keys[:-1]:
|
60 |
+
new_target_dict = target_dict.get(k)
|
61 |
+
if new_target_dict is None:
|
62 |
+
new_target_dict = {}
|
63 |
+
target_dict[k] = new_target_dict
|
64 |
+
target_dict = new_target_dict
|
65 |
+
target_dict[key.keys[-1]] = value
|
66 |
+
|
67 |
+
return result
|
68 |
+
|
69 |
+
|
70 |
+
def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
|
71 |
+
state_dict = flatten_dict(state_dict)
|
72 |
+
state_dict = {encode_key(k): v for k, v in state_dict.items()}
|
73 |
+
safetensors.torch.save_file(state_dict, filename)
|
74 |
+
|
75 |
+
|
76 |
+
def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
|
77 |
+
if map_location is None:
|
78 |
+
map_location = "cpu"
|
79 |
+
state_dict = safetensors.torch.load_file(filename, device=map_location)
|
80 |
+
state_dict = {decode_key(k): v for k, v in state_dict.items()}
|
81 |
+
return unflatten_dict(state_dict)
|
torch_util.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
from typing import Optional, TypeVar
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
T = TypeVar("T")
|
9 |
+
|
10 |
+
|
11 |
+
def seed_all(seed: int):
|
12 |
+
"""Seed all rng objects."""
|
13 |
+
import random
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
if seed < 0 or seed > 2**32 - 1:
|
18 |
+
raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
|
19 |
+
random.seed(seed)
|
20 |
+
np.random.seed(seed)
|
21 |
+
torch.manual_seed(seed)
|
22 |
+
# torch.manual_seed may call manual_seed_all but calling it again here
|
23 |
+
# to make sure it gets called at least once
|
24 |
+
torch.cuda.manual_seed_all(seed)
|
25 |
+
|
26 |
+
|
27 |
+
def is_distributed() -> bool:
|
28 |
+
return dist.is_available() and dist.is_initialized()
|
29 |
+
|
30 |
+
|
31 |
+
def get_node_rank() -> int:
|
32 |
+
return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
|
33 |
+
|
34 |
+
|
35 |
+
def get_world_size() -> int:
|
36 |
+
if is_distributed():
|
37 |
+
return dist.get_world_size()
|
38 |
+
else:
|
39 |
+
return 1
|
40 |
+
|
41 |
+
|
42 |
+
def get_local_world_size() -> int:
|
43 |
+
return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
|
44 |
+
|
45 |
+
|
46 |
+
def get_global_rank() -> int:
|
47 |
+
if is_distributed():
|
48 |
+
return int(os.environ.get("RANK") or dist.get_rank())
|
49 |
+
else:
|
50 |
+
return 0
|
51 |
+
|
52 |
+
|
53 |
+
def get_local_rank() -> int:
|
54 |
+
return int(os.environ.get("LOCAL_RANK") or 0)
|
55 |
+
|
56 |
+
|
57 |
+
def get_fs_local_rank() -> int:
|
58 |
+
"""Get the local rank per filesystem, meaning that, regardless of the number of nodes,
|
59 |
+
if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
|
60 |
+
but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
|
61 |
+
"""
|
62 |
+
if os.environ.get("OLMO_SHARED_FS"):
|
63 |
+
return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
|
64 |
+
else:
|
65 |
+
return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
|
66 |
+
|
67 |
+
|
68 |
+
def move_to_device(o: T, device: torch.device) -> T:
|
69 |
+
if isinstance(o, torch.Tensor):
|
70 |
+
return o.to(device) # type: ignore[return-value]
|
71 |
+
elif isinstance(o, dict):
|
72 |
+
return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
|
73 |
+
elif isinstance(o, list):
|
74 |
+
return [move_to_device(x, device) for x in o] # type: ignore[return-value]
|
75 |
+
elif isinstance(o, tuple):
|
76 |
+
return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
|
77 |
+
else:
|
78 |
+
return o
|
79 |
+
|
80 |
+
|
81 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
82 |
+
"""
|
83 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
84 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
85 |
+
"""
|
86 |
+
if check_neg_inf:
|
87 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
88 |
+
if check_pos_inf:
|
89 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
90 |
+
|
91 |
+
|
92 |
+
def get_default_device() -> torch.device:
|
93 |
+
if torch.cuda.is_available() and torch.cuda.is_initialized():
|
94 |
+
return torch.device("cuda")
|
95 |
+
else:
|
96 |
+
return torch.device("cpu")
|
97 |
+
|
98 |
+
|
99 |
+
def barrier() -> None:
|
100 |
+
if is_distributed():
|
101 |
+
dist.barrier()
|
102 |
+
|
103 |
+
|
104 |
+
def peak_gpu_memory(reset: bool = False) -> Optional[float]:
|
105 |
+
"""
|
106 |
+
Get the peak GPU memory usage in MB across all ranks.
|
107 |
+
Only rank 0 will get the final result.
|
108 |
+
"""
|
109 |
+
if not torch.cuda.is_available():
|
110 |
+
return None
|
111 |
+
|
112 |
+
device = torch.device("cuda")
|
113 |
+
peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
|
114 |
+
if is_distributed():
|
115 |
+
peak_mb_tensor = torch.tensor(peak_mb, device=device)
|
116 |
+
dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
|
117 |
+
peak_mb = peak_mb_tensor.item()
|
118 |
+
|
119 |
+
if reset:
|
120 |
+
# Reset peak stats.
|
121 |
+
torch.cuda.reset_max_memory_allocated(device)
|
122 |
+
|
123 |
+
return peak_mb
|
124 |
+
|
125 |
+
|
126 |
+
V = TypeVar("V", bool, int, float)
|
127 |
+
|
128 |
+
|
129 |
+
def synchronize_value(value: V, device: torch.device) -> V:
|
130 |
+
if dist.is_available() and dist.is_initialized():
|
131 |
+
value_tensor = torch.tensor(value, device=device)
|
132 |
+
dist.broadcast(value_tensor, 0)
|
133 |
+
return value_tensor.item() # type: ignore
|
134 |
+
else:
|
135 |
+
return value
|
136 |
+
|
137 |
+
|
138 |
+
def synchronize_flag(flag: bool, device: torch.device) -> bool:
|
139 |
+
return synchronize_value(flag, device)
|
140 |
+
|
141 |
+
|
142 |
+
def gc_cuda():
|
143 |
+
gc.collect()
|
144 |
+
if torch.cuda.is_available():
|
145 |
+
torch.cuda.empty_cache()
|
146 |
+
|
147 |
+
|
148 |
+
def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor:
|
149 |
+
"""
|
150 |
+
Transform a batched tensor of document lengths into a 1D tensor of cumulative document
|
151 |
+
lengths for the whole batch.
|
152 |
+
"""
|
153 |
+
return torch.cat(
|
154 |
+
[
|
155 |
+
torch.tensor([0], dtype=torch.int32, device=doc_lens.device),
|
156 |
+
torch.cumsum(doc_lens.masked_select(doc_lens != 0), 0, dtype=torch.int32),
|
157 |
+
]
|
158 |
+
)
|
train.py
ADDED
@@ -0,0 +1,1384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import cProfile
|
4 |
+
import functools
|
5 |
+
import gc
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import random
|
10 |
+
import shutil
|
11 |
+
import time
|
12 |
+
from collections import deque
|
13 |
+
from contextlib import nullcontext
|
14 |
+
from dataclasses import dataclass, field
|
15 |
+
from itertools import islice
|
16 |
+
from pathlib import Path
|
17 |
+
from pstats import SortKey
|
18 |
+
from typing import Any, Callable, Deque, Dict, List, Optional, TextIO, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils
|
25 |
+
import torch.utils.hooks
|
26 |
+
import wandb
|
27 |
+
from packaging import version
|
28 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
29 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
30 |
+
from torch.utils.data import DataLoader
|
31 |
+
|
32 |
+
from .aliases import PathOrStr
|
33 |
+
from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
|
34 |
+
from .config import (
|
35 |
+
CheckpointType,
|
36 |
+
DDPGradSyncMode,
|
37 |
+
DistributedStrategy,
|
38 |
+
SchedulerUnits,
|
39 |
+
ShardedCheckpointerType,
|
40 |
+
SpeedMonitorConfig,
|
41 |
+
TrainConfig,
|
42 |
+
)
|
43 |
+
from .data import IterableDataset
|
44 |
+
from .eval import Evaluator
|
45 |
+
from .exceptions import OLMoConfigurationError
|
46 |
+
from .model import OLMo
|
47 |
+
from .optim import Optimizer, Scheduler
|
48 |
+
from .torch_util import (
|
49 |
+
barrier,
|
50 |
+
gc_cuda,
|
51 |
+
get_fs_local_rank,
|
52 |
+
get_global_rank,
|
53 |
+
get_world_size,
|
54 |
+
move_to_device,
|
55 |
+
peak_gpu_memory,
|
56 |
+
synchronize_flag,
|
57 |
+
synchronize_value,
|
58 |
+
)
|
59 |
+
from .util import upload
|
60 |
+
|
61 |
+
__all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
|
62 |
+
|
63 |
+
log = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
@dataclass
|
67 |
+
class SpeedMonitor:
|
68 |
+
cfg: SpeedMonitorConfig
|
69 |
+
start_times: Deque[float] = field(default_factory=lambda: deque([]))
|
70 |
+
global_total_tokens: int = 0
|
71 |
+
total_training_Gflops: float = 0
|
72 |
+
device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))
|
73 |
+
|
74 |
+
def batch_start(
|
75 |
+
self,
|
76 |
+
global_total_tokens: int,
|
77 |
+
device_batch_num_tokens: int,
|
78 |
+
num_fwd_flops: int,
|
79 |
+
num_bck_flops: int,
|
80 |
+
record: bool = True,
|
81 |
+
) -> None:
|
82 |
+
self.global_total_tokens = global_total_tokens
|
83 |
+
# num_fwd_flops and num_bck_flops from the OLMo model computes flops per token
|
84 |
+
# converting to GFLOPs here prevents numerical issues while logging
|
85 |
+
self.total_training_Gflops = (num_fwd_flops + num_bck_flops) * global_total_tokens / 1e9
|
86 |
+
|
87 |
+
if record:
|
88 |
+
if len(self.start_times) >= self.cfg.window_size:
|
89 |
+
self.start_times.popleft()
|
90 |
+
self.device_interval_tokens.popleft()
|
91 |
+
self.start_times.append(time.monotonic())
|
92 |
+
self.device_interval_tokens.append(device_batch_num_tokens)
|
93 |
+
|
94 |
+
def reset(self) -> None:
|
95 |
+
self.start_times.clear()
|
96 |
+
self.device_interval_tokens.clear()
|
97 |
+
|
98 |
+
def check(self) -> Dict[str, float]:
|
99 |
+
metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
|
100 |
+
|
101 |
+
# plot flops related metrics
|
102 |
+
metrics["throughput/total_training_Gflops"] = self.total_training_Gflops
|
103 |
+
metrics["throughput/total_training_log_Gflops"] = math.log(self.total_training_Gflops)
|
104 |
+
|
105 |
+
if self.start_times:
|
106 |
+
interval_seconds = time.monotonic() - self.start_times[0]
|
107 |
+
interval_batches = len(self.start_times)
|
108 |
+
interval_tokens = sum(self.device_interval_tokens)
|
109 |
+
metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
|
110 |
+
metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
|
111 |
+
return metrics
|
112 |
+
|
113 |
+
|
114 |
+
@dataclass
|
115 |
+
class LRMonitor:
|
116 |
+
optim: torch.optim.Optimizer
|
117 |
+
|
118 |
+
def check(self) -> Dict[str, float]:
|
119 |
+
lrs = [group["lr"] for group in self.optim.param_groups]
|
120 |
+
return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}
|
121 |
+
|
122 |
+
|
123 |
+
def cross_entropy_loss(
|
124 |
+
logits,
|
125 |
+
labels,
|
126 |
+
ignore_index: int = -100,
|
127 |
+
reduction: str = "mean",
|
128 |
+
compute_z_loss: bool = False,
|
129 |
+
z_loss_multiplier: float = 1e-4,
|
130 |
+
):
|
131 |
+
loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
|
132 |
+
|
133 |
+
if not compute_z_loss:
|
134 |
+
return loss, None
|
135 |
+
|
136 |
+
z_squared = logits.logsumexp(-1).pow(2)
|
137 |
+
if reduction == "mean":
|
138 |
+
z_squared = (z_squared * (labels != ignore_index)).mean()
|
139 |
+
elif reduction == "sum":
|
140 |
+
z_squared = (z_squared * (labels != ignore_index)).sum()
|
141 |
+
|
142 |
+
z_loss = z_loss_multiplier * z_squared
|
143 |
+
|
144 |
+
return loss, z_loss
|
145 |
+
|
146 |
+
|
147 |
+
fused_loss_fn: Optional[Callable]
|
148 |
+
|
149 |
+
try:
|
150 |
+
import flash_attn
|
151 |
+
from flash_attn.ops.triton.cross_entropy import (
|
152 |
+
cross_entropy_loss as flash_cross_entropy_loss, # type: ignore
|
153 |
+
)
|
154 |
+
|
155 |
+
def fused_loss_fn(
|
156 |
+
logits,
|
157 |
+
labels,
|
158 |
+
ignore_index: int = -100,
|
159 |
+
reduction: str = "mean",
|
160 |
+
compute_z_loss: bool = False,
|
161 |
+
z_loss_multiplier: float = 1e-4,
|
162 |
+
):
|
163 |
+
# The `ignored_index` parameter of `cross_entropy_loss` was changed to `ignore_index` in v2.5.8 with commit https://github.com/Dao-AILab/flash-attention/commit/ec6d22143b5d375e253b2ebfc563b26a43f43684
|
164 |
+
ce_loss_use_ignore_index_param = version.parse(flash_attn.__version__) >= version.parse("2.5.8")
|
165 |
+
|
166 |
+
if ce_loss_use_ignore_index_param:
|
167 |
+
ignore_index_kwarg = {"ignore_index": ignore_index}
|
168 |
+
else:
|
169 |
+
ignore_index_kwarg = {"ignored_index": ignore_index}
|
170 |
+
|
171 |
+
loss, z_loss = flash_cross_entropy_loss(
|
172 |
+
logits,
|
173 |
+
labels,
|
174 |
+
label_smoothing=0.0,
|
175 |
+
logit_scale=1.0,
|
176 |
+
lse_square_scale=z_loss_multiplier,
|
177 |
+
inplace_backward=False,
|
178 |
+
process_group=None,
|
179 |
+
**ignore_index_kwarg,
|
180 |
+
)
|
181 |
+
|
182 |
+
mask = labels != ignore_index
|
183 |
+
|
184 |
+
if reduction == "mean":
|
185 |
+
loss = loss.sum() / mask.sum()
|
186 |
+
elif reduction == "sum":
|
187 |
+
loss = loss.sum()
|
188 |
+
else:
|
189 |
+
loss = loss
|
190 |
+
|
191 |
+
if not compute_z_loss:
|
192 |
+
return loss, None
|
193 |
+
|
194 |
+
if reduction == "mean":
|
195 |
+
z_loss = z_loss.sum() / mask.sum()
|
196 |
+
elif reduction == "sum":
|
197 |
+
z_loss = z_loss.sum()
|
198 |
+
else:
|
199 |
+
z_loss = z_loss
|
200 |
+
|
201 |
+
return loss, z_loss
|
202 |
+
|
203 |
+
except ImportError:
|
204 |
+
fused_loss_fn = None
|
205 |
+
|
206 |
+
|
207 |
+
@dataclass
|
208 |
+
class Trainer:
|
209 |
+
cfg: TrainConfig
|
210 |
+
model: OLMo
|
211 |
+
dist_model: Union[DDP, FSDP]
|
212 |
+
optim: Optimizer
|
213 |
+
scheduler: Scheduler
|
214 |
+
train_loader: DataLoader
|
215 |
+
device: torch.device
|
216 |
+
evaluators: List[Evaluator]
|
217 |
+
epoch: Optional[int] = None
|
218 |
+
global_step: int = 0
|
219 |
+
global_train_examples_seen_this_epoch: int = 0
|
220 |
+
"""Tracks the global number of training examples seen in the current epoch for the purpose of restoring
|
221 |
+
the data loader position on restarts."""
|
222 |
+
global_train_tokens_seen: int = 0
|
223 |
+
"""Tracks the global total number of tokens trained on."""
|
224 |
+
checkpoints: List[Path] = field(default_factory=list)
|
225 |
+
unsharded_checkpoints: List[Path] = field(default_factory=list)
|
226 |
+
ephemeral_checkpoints: List[Path] = field(default_factory=list)
|
227 |
+
min_train_loss: float = float("inf")
|
228 |
+
cur_train_loss: float = float("inf")
|
229 |
+
indices_file: Optional[TextIO] = None
|
230 |
+
_start_time: float = 0.0
|
231 |
+
_gc_init_state: bool = True
|
232 |
+
loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
|
233 |
+
last_sharded_checkpoint_step: Optional[int] = None
|
234 |
+
last_unsharded_checkpoint_step: Optional[int] = None
|
235 |
+
|
236 |
+
def __post_init__(self):
|
237 |
+
if self.cfg.fused_loss:
|
238 |
+
if fused_loss_fn is not None:
|
239 |
+
self.loss_fn = fused_loss_fn
|
240 |
+
else:
|
241 |
+
raise NameError("`fused_loss_fn` is not defined. Please ensure that `flash_attn` is installed.")
|
242 |
+
|
243 |
+
@property
|
244 |
+
def dataset(self) -> IterableDataset:
|
245 |
+
assert isinstance(self.train_loader.dataset, IterableDataset)
|
246 |
+
return self.train_loader.dataset
|
247 |
+
|
248 |
+
@property
|
249 |
+
def tokens_per_batch(self) -> int:
|
250 |
+
return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
|
251 |
+
|
252 |
+
@property
|
253 |
+
def batches_per_epoch(self) -> int:
|
254 |
+
return self.dataset.total_size // self.cfg.global_train_batch_size
|
255 |
+
|
256 |
+
@property
|
257 |
+
def max_epochs(self) -> int:
|
258 |
+
return math.ceil(self.max_steps / self.batches_per_epoch)
|
259 |
+
|
260 |
+
@property
|
261 |
+
def max_steps(self) -> int:
|
262 |
+
if isinstance(self.cfg.max_duration, int):
|
263 |
+
return self.cfg.max_duration
|
264 |
+
elif isinstance(self.cfg.max_duration, str):
|
265 |
+
if self.cfg.max_duration.endswith("T"):
|
266 |
+
# convert to float *first* to handle scientific notation
|
267 |
+
max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
|
268 |
+
tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
|
269 |
+
steps_remaining = math.ceil(tokens_remaining / self.tokens_per_batch)
|
270 |
+
return self.global_step + steps_remaining
|
271 |
+
elif self.cfg.max_duration.endswith("ep"):
|
272 |
+
max_epochs = int(self.cfg.max_duration[:-2].strip())
|
273 |
+
return max_epochs * self.batches_per_epoch
|
274 |
+
else:
|
275 |
+
# convert to float *first* to handle scientific notation
|
276 |
+
return int(float(self.cfg.max_duration))
|
277 |
+
else:
|
278 |
+
raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
|
279 |
+
|
280 |
+
@property
|
281 |
+
def max_tokens(self) -> int:
|
282 |
+
if isinstance(self.cfg.max_duration, int):
|
283 |
+
return (
|
284 |
+
self.global_train_tokens_seen
|
285 |
+
+ max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
|
286 |
+
)
|
287 |
+
elif isinstance(self.cfg.max_duration, str):
|
288 |
+
if self.cfg.max_duration.endswith("T"):
|
289 |
+
# convert to float *first* to handle scientific notation
|
290 |
+
return int(float(self.cfg.max_duration[:-1].strip()))
|
291 |
+
elif self.cfg.max_duration.endswith("ep"):
|
292 |
+
max_epochs = int(self.cfg.max_duration[:-2].strip())
|
293 |
+
return max_epochs * self.batches_per_epoch * self.tokens_per_batch
|
294 |
+
else:
|
295 |
+
# convert to float *first* to handle scientific notation
|
296 |
+
return (
|
297 |
+
self.global_train_tokens_seen
|
298 |
+
+ max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
|
302 |
+
|
303 |
+
@property
|
304 |
+
def scheduler_current(self) -> int:
|
305 |
+
if self.cfg.scheduler.units == SchedulerUnits.steps:
|
306 |
+
return self.global_step
|
307 |
+
elif self.cfg.scheduler.units == SchedulerUnits.tokens:
|
308 |
+
return self.global_train_tokens_seen
|
309 |
+
else:
|
310 |
+
raise NotImplementedError(self.cfg.scheduler.units)
|
311 |
+
|
312 |
+
@property
|
313 |
+
def scheduler_max(self) -> int:
|
314 |
+
if self.cfg.scheduler.units == SchedulerUnits.steps:
|
315 |
+
return self.max_steps
|
316 |
+
elif self.cfg.scheduler.units == SchedulerUnits.tokens:
|
317 |
+
return self.max_tokens
|
318 |
+
else:
|
319 |
+
raise NotImplementedError(self.cfg.scheduler.units)
|
320 |
+
|
321 |
+
def trainer_state_dict(self) -> Dict[str, Any]:
|
322 |
+
return {
|
323 |
+
"epoch": self.epoch or 0,
|
324 |
+
"global_step": self.global_step,
|
325 |
+
"global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
|
326 |
+
"global_train_tokens_seen": self.global_train_tokens_seen,
|
327 |
+
"world_size": get_world_size(),
|
328 |
+
"checkpoints": self.checkpoints,
|
329 |
+
"unsharded_checkpoints": self.unsharded_checkpoints,
|
330 |
+
"ephemeral_checkpoints": self.ephemeral_checkpoints,
|
331 |
+
"rng": {
|
332 |
+
"python": random.getstate(),
|
333 |
+
"numpy": np.random.get_state(),
|
334 |
+
"torch": torch.random.get_rng_state(),
|
335 |
+
"cuda": torch.cuda.get_rng_state(),
|
336 |
+
},
|
337 |
+
}
|
338 |
+
|
339 |
+
def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
340 |
+
# Checkpoint paths.
|
341 |
+
self.checkpoints = [
|
342 |
+
path
|
343 |
+
for path in state_dict["checkpoints"]
|
344 |
+
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
|
345 |
+
]
|
346 |
+
self.unsharded_checkpoints = [
|
347 |
+
path
|
348 |
+
for path in state_dict["unsharded_checkpoints"]
|
349 |
+
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
|
350 |
+
]
|
351 |
+
self.ephemeral_checkpoints = [
|
352 |
+
path
|
353 |
+
for path in state_dict.get("ephemeral_checkpoints", [])
|
354 |
+
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
|
355 |
+
]
|
356 |
+
|
357 |
+
# Dataset / dataloader position.
|
358 |
+
checkpoint_epoch = state_dict.get("epoch") or 0
|
359 |
+
self.global_step = state_dict["global_step"]
|
360 |
+
self.global_train_examples_seen_this_epoch = state_dict.get(
|
361 |
+
"global_train_examples_seen_this_epoch",
|
362 |
+
state_dict.get( # for backwards compatibility
|
363 |
+
"global_train_examples_seen",
|
364 |
+
state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
|
365 |
+
),
|
366 |
+
)
|
367 |
+
self.global_train_tokens_seen = state_dict.get(
|
368 |
+
"global_train_tokens_seen",
|
369 |
+
state_dict.get("global_data_step", self.global_step) # for backwards compatibility
|
370 |
+
* self.cfg.global_train_batch_size
|
371 |
+
* self.cfg.model.max_sequence_length,
|
372 |
+
)
|
373 |
+
|
374 |
+
if not self.cfg.restore_dataloader:
|
375 |
+
self.epoch = 0
|
376 |
+
self.global_step = 0
|
377 |
+
self.global_train_tokens_seen = 0
|
378 |
+
self.global_train_examples_seen_this_epoch = 0
|
379 |
+
elif self.epoch is None:
|
380 |
+
self.epoch = checkpoint_epoch
|
381 |
+
elif checkpoint_epoch != self.epoch:
|
382 |
+
log.info(f"Starting new epoch (epoch = {self.epoch})")
|
383 |
+
self.global_train_examples_seen_this_epoch = 0
|
384 |
+
|
385 |
+
assert self.epoch is not None
|
386 |
+
# Reshuffle dataset if needed.
|
387 |
+
if self.dataset.epoch != self.epoch:
|
388 |
+
log.info(f"Reshuffling data loader for epoch {self.epoch}...")
|
389 |
+
self.dataset.reshuffle(self.epoch)
|
390 |
+
|
391 |
+
if self.cfg.fast_forward_batches:
|
392 |
+
log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
|
393 |
+
# Technically we don't "see" these batches that we fast-forward through, but we use
|
394 |
+
# this variable to update the position of the dataset so we need to include them here.
|
395 |
+
self.global_train_examples_seen_this_epoch += (
|
396 |
+
self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
|
397 |
+
)
|
398 |
+
# NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
|
399 |
+
# that variable is meant to track the actual number of tokens trained on.
|
400 |
+
|
401 |
+
if self.global_train_examples_seen_this_epoch > 0:
|
402 |
+
assert isinstance(self.dataset, IterableDataset)
|
403 |
+
log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
|
404 |
+
self.dataset.start_index = self.global_train_examples_seen_this_epoch
|
405 |
+
|
406 |
+
# Reset learning rate and weight decay to the values from the config, not the checkpoint.
|
407 |
+
log.info("Resetting learning rate...")
|
408 |
+
new_learning_rate = self.scheduler.get_lr(
|
409 |
+
self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
|
410 |
+
)
|
411 |
+
for group in self.optim.param_groups:
|
412 |
+
group["lr"] = new_learning_rate
|
413 |
+
group["initial_lr"] = self.cfg.optimizer.learning_rate
|
414 |
+
if "weight_decay" in group and group["weight_decay"] > 0.0:
|
415 |
+
group["weight_decay"] = self.cfg.optimizer.weight_decay
|
416 |
+
|
417 |
+
# RNG states.
|
418 |
+
if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
|
419 |
+
log.info("Restoring RNG states...")
|
420 |
+
rng_state = state_dict["rng"]
|
421 |
+
self.restore_rng_state(rng_state)
|
422 |
+
else:
|
423 |
+
log.warning(
|
424 |
+
"Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
|
425 |
+
"This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
|
426 |
+
"with a different world size. If that's the case you can safely ignore this warning."
|
427 |
+
)
|
428 |
+
|
    def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
        random.setstate(rng_state["python"])
        np.random.set_state(rng_state["numpy"])
        torch.set_rng_state(rng_state["torch"])
        torch.cuda.set_rng_state(rng_state["cuda"])

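# Illustrative sketch (not part of train.py): the RNG round trip that trainer_state_dict()
# and restore_rng_state() above implement, shown standalone. The cuda-availability guard
# is an addition for this sketch; in the trainer the caller instead checks that the saved
# 'world_size' matches before restoring.
import random
import numpy as np
import torch

def capture_rng_state() -> dict:
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }

def restore_rng_state(state: dict) -> None:
    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch"])
    if state["cuda"] is not None:
        torch.cuda.set_rng_state(state["cuda"])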
435 |
+
def _save_checkpoint(
|
436 |
+
self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
|
437 |
+
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
|
438 |
+
if checkpoint_type == CheckpointType.sharded:
|
439 |
+
suffix = ""
|
440 |
+
current_checkpoints = self.checkpoints
|
441 |
+
link_latest = get_fs_local_rank() == 0
|
442 |
+
num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
|
443 |
+
elif checkpoint_type == CheckpointType.unsharded:
|
444 |
+
suffix = "-unsharded"
|
445 |
+
current_checkpoints = self.unsharded_checkpoints
|
446 |
+
link_latest = get_global_rank() == 0
|
447 |
+
num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
|
448 |
+
elif checkpoint_type == CheckpointType.sharded_ephemeral:
|
449 |
+
suffix = ""
|
450 |
+
current_checkpoints = self.ephemeral_checkpoints
|
451 |
+
link_latest = get_fs_local_rank() == 0
|
452 |
+
num_checkpoints_to_keep = 1
|
453 |
+
else:
|
454 |
+
raise NotImplementedError(checkpoint_type)
|
455 |
+
|
456 |
+
# Zero-gradients to avoid gathering them.
|
457 |
+
self.optim.zero_grad(set_to_none=True)
|
458 |
+
|
459 |
+
# Flush data indices file.
|
460 |
+
# TODO: upload the indices files?
|
461 |
+
if self.indices_file is not None:
|
462 |
+
self.indices_file.flush()
|
463 |
+
|
464 |
+
checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
|
465 |
+
remote_checkpoint_dir: Optional[str] = None
|
466 |
+
if self.cfg.remote_save_folder is not None:
|
467 |
+
remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
|
468 |
+
current_checkpoints.append(checkpoint_dir)
|
469 |
+
|
470 |
+
# Save the checkpoint.
|
471 |
+
try:
|
472 |
+
checkpointer.save_checkpoint(
|
473 |
+
checkpoint_dir,
|
474 |
+
self.dist_model,
|
475 |
+
self.optim,
|
476 |
+
self.trainer_state_dict(),
|
477 |
+
upload_to=remote_checkpoint_dir,
|
478 |
+
)
|
479 |
+
except FileExistsError:
|
480 |
+
raise OLMoConfigurationError(
|
481 |
+
f"Checkpoint for step {self.global_step} already exists, use --save_overwrite to overwrite it"
|
482 |
+
)
|
483 |
+
|
484 |
+
if link_latest:
|
485 |
+
# Link to 'latest'.
|
486 |
+
latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
|
487 |
+
latest_path.unlink(missing_ok=True)
|
488 |
+
try:
|
489 |
+
latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
|
490 |
+
except FileExistsError:
|
491 |
+
# Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
|
492 |
+
# This can happen when nodes are saving to a common NFS drive but otherwise have distinct
|
493 |
+
# file-systems.
|
494 |
+
if latest_path.resolve().name != checkpoint_dir.name:
|
495 |
+
raise
|
496 |
+
|
497 |
+
# Remove old checkpoints.
|
498 |
+
# For DDP, checkpoint_type being passed to remove_checkpoint is always `unsharded`.
|
499 |
+
if num_checkpoints_to_keep > 0:
|
500 |
+
while len(current_checkpoints) > num_checkpoints_to_keep:
|
501 |
+
self.remove_checkpoint(0, checkpoint_type)
|
502 |
+
|
503 |
+
barrier()
|
504 |
+
|
505 |
+
if remote_checkpoint_dir is not None:
|
506 |
+
return remote_checkpoint_dir, checkpoint_dir
|
507 |
+
else:
|
508 |
+
return checkpoint_dir, None
|
509 |
+
|
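# Illustrative sketch (not part of train.py): the 'latest' symlink update performed inside
# _save_checkpoint() above, shown standalone. Racing writers on a shared filesystem may
# both try to create the link, so FileExistsError is only re-raised when the existing link
# points at a different checkpoint. The helper name is hypothetical.
from pathlib import Path

def update_latest_link(save_folder: Path, checkpoint_dir: Path, suffix: str = "") -> None:
    latest_path = save_folder / f"latest{suffix}"
    latest_path.unlink(missing_ok=True)
    try:
        latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
    except FileExistsError:
        if latest_path.resolve().name != checkpoint_dir.name:
            raise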
510 |
+
def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
|
511 |
+
checkpointer = build_sharded_checkpointer(self.cfg)
|
512 |
+
result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
|
513 |
+
self.last_sharded_checkpoint_step = self.global_step
|
514 |
+
return result
|
515 |
+
|
516 |
+
def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
|
517 |
+
checkpointer = build_sharded_checkpointer(self.cfg)
|
518 |
+
result = self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)
|
519 |
+
self.last_sharded_checkpoint_step = self.global_step
|
520 |
+
return result
|
521 |
+
|
522 |
+
def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
|
523 |
+
oldest_checkpoint = checkpoints.pop(idx)
|
524 |
+
barrier()
|
525 |
+
if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
|
526 |
+
shutil.rmtree(oldest_checkpoint, ignore_errors=True)
|
527 |
+
latest_path = Path(self.cfg.save_folder) / "latest"
|
528 |
+
if latest_path.resolve() == oldest_checkpoint.resolve():
|
529 |
+
latest_path.unlink()
|
530 |
+
barrier()
|
531 |
+
|
532 |
+
def remove_sharded_checkpoint(self, idx: int = 0):
|
533 |
+
self._remove_sharded_checkpoint(idx, self.checkpoints)
|
534 |
+
|
535 |
+
def remove_ephemeral_checkpoint(self, idx: int = 0):
|
536 |
+
self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)
|
537 |
+
|
538 |
+
def restore_sharded_checkpoint(
|
539 |
+
self,
|
540 |
+
load_path: PathOrStr,
|
541 |
+
local_cache: Optional[PathOrStr] = None,
|
542 |
+
*,
|
543 |
+
load_optimizer_state: bool = True,
|
544 |
+
load_trainer_state: bool = True,
|
545 |
+
sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
|
546 |
+
):
|
547 |
+
# Zero-gradients to avoid gathering them.
|
548 |
+
self.optim.zero_grad(set_to_none=True)
|
549 |
+
checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
|
550 |
+
trainer_state = checkpointer.restore_checkpoint(
|
551 |
+
load_path,
|
552 |
+
self.dist_model,
|
553 |
+
self.optim,
|
554 |
+
local_cache=local_cache,
|
555 |
+
load_optimizer_state=load_optimizer_state,
|
556 |
+
)
|
557 |
+
if load_trainer_state:
|
558 |
+
self.load_trainer_state_dict(trainer_state)
|
559 |
+
barrier()
|
560 |
+
|
561 |
+
def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
|
562 |
+
checkpointer = FullCheckpointer(self.cfg)
|
563 |
+
result = self._save_checkpoint(checkpointer, CheckpointType.unsharded)
|
564 |
+
self.last_unsharded_checkpoint_step = self.global_step
|
565 |
+
return result
|
566 |
+
|
567 |
+
def remove_unsharded_checkpoint(self, idx: int = 0):
|
568 |
+
barrier()
|
569 |
+
oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
|
570 |
+
if get_global_rank() == 0 and oldest_checkpoint.is_dir():
|
571 |
+
shutil.rmtree(oldest_checkpoint, ignore_errors=True)
|
572 |
+
latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
|
573 |
+
if latest_path.resolve() == oldest_checkpoint.resolve():
|
574 |
+
latest_path.unlink()
|
575 |
+
barrier()
|
576 |
+
|
577 |
+
def restore_unsharded_checkpoint(
|
578 |
+
self,
|
579 |
+
load_path: PathOrStr,
|
580 |
+
local_cache: Optional[PathOrStr] = None,
|
581 |
+
*,
|
582 |
+
load_optimizer_state: bool = True,
|
583 |
+
load_trainer_state: bool = True,
|
584 |
+
):
|
585 |
+
# Zero-gradients to avoid gathering them.
|
586 |
+
self.optim.zero_grad(set_to_none=True)
|
587 |
+
checkpointer = FullCheckpointer(self.cfg)
|
588 |
+
trainer_state = checkpointer.restore_checkpoint(
|
589 |
+
load_path,
|
590 |
+
self.dist_model,
|
591 |
+
self.optim,
|
592 |
+
local_cache=local_cache,
|
593 |
+
load_optimizer_state=load_optimizer_state,
|
594 |
+
)
|
595 |
+
if load_trainer_state:
|
596 |
+
self.load_trainer_state_dict(trainer_state)
|
597 |
+
barrier()
|
598 |
+
|
599 |
+
def save_checkpoint(
|
600 |
+
self, checkpoint_type: CheckpointType = CheckpointType.sharded
|
601 |
+
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
|
602 |
+
result: Tuple[PathOrStr, Optional[PathOrStr]]
|
603 |
+
if checkpoint_type == CheckpointType.sharded:
|
604 |
+
result = self.save_sharded_checkpoint()
|
605 |
+
elif checkpoint_type == CheckpointType.unsharded:
|
606 |
+
result = self.save_unsharded_checkpoint()
|
607 |
+
elif checkpoint_type == CheckpointType.sharded_ephemeral:
|
608 |
+
result = self.save_ephemeral_checkpoint()
|
609 |
+
else:
|
610 |
+
raise NotImplementedError(checkpoint_type)
|
611 |
+
|
612 |
+
gc_cuda()
|
613 |
+
return result
|
614 |
+
|
615 |
+
def restore_checkpoint(
|
616 |
+
self,
|
617 |
+
load_path: PathOrStr,
|
618 |
+
*,
|
619 |
+
checkpoint_type: Optional[CheckpointType] = None,
|
620 |
+
local_cache: Optional[PathOrStr] = None,
|
621 |
+
load_optimizer_state: bool = True,
|
622 |
+
load_trainer_state: bool = True,
|
623 |
+
sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
|
624 |
+
):
|
625 |
+
if checkpoint_type == CheckpointType.unsharded or (
|
626 |
+
checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
|
627 |
+
):
|
628 |
+
self.restore_unsharded_checkpoint(
|
629 |
+
load_path,
|
630 |
+
local_cache=local_cache,
|
631 |
+
load_optimizer_state=load_optimizer_state,
|
632 |
+
load_trainer_state=load_trainer_state,
|
633 |
+
)
|
634 |
+
elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
|
635 |
+
self.restore_sharded_checkpoint(
|
636 |
+
load_path,
|
637 |
+
local_cache=local_cache,
|
638 |
+
load_optimizer_state=load_optimizer_state,
|
639 |
+
load_trainer_state=load_trainer_state,
|
640 |
+
sharded_checkpointer=sharded_checkpointer,
|
641 |
+
)
|
642 |
+
elif checkpoint_type is not None:
|
643 |
+
raise NotImplementedError(checkpoint_type)
|
644 |
+
|
645 |
+
gc_cuda()
|
646 |
+
|
647 |
+
def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
|
648 |
+
if checkpoint_type == CheckpointType.sharded:
|
649 |
+
self.remove_sharded_checkpoint(idx=idx)
|
650 |
+
elif checkpoint_type == CheckpointType.unsharded:
|
651 |
+
self.remove_unsharded_checkpoint(idx=idx)
|
652 |
+
elif checkpoint_type == CheckpointType.sharded_ephemeral:
|
653 |
+
self.remove_ephemeral_checkpoint(idx=idx)
|
654 |
+
else:
|
655 |
+
raise NotImplementedError(checkpoint_type)
|
656 |
+
|
657 |
+
def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
|
658 |
+
if (
|
659 |
+
self.cfg.module_outputs_save_steps is None
|
660 |
+
or self.global_step not in self.cfg.module_outputs_save_steps
|
661 |
+
):
|
662 |
+
return []
|
663 |
+
|
664 |
+
if micro_batch_idx != 0 or get_global_rank() != 0:
|
665 |
+
# Hook is currently only used on the first microbatch of rank 0
|
666 |
+
return []
|
667 |
+
|
668 |
+
trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
|
669 |
+
if trace_save_folder.exists():
|
670 |
+
if self.cfg.save_overwrite:
|
671 |
+
shutil.rmtree(trace_save_folder)
|
672 |
+
else:
|
673 |
+
raise OLMoConfigurationError(
|
674 |
+
f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
|
675 |
+
)
|
676 |
+
trace_save_folder.mkdir(parents=True)
|
677 |
+
|
678 |
+
def trace_outputs_hook(
|
679 |
+
module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
|
680 |
+
) -> None:
|
681 |
+
if len(args) == 0:
|
682 |
+
log.info("No input args for module %s, output %s", module_name, output)
|
683 |
+
|
684 |
+
module_input = args[0] if len(args) > 0 else torch.tensor(())
|
685 |
+
trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
|
686 |
+
trace_save_folder.mkdir(parents=True, exist_ok=True)
|
687 |
+
|
688 |
+
module_occurence_num = 0
|
689 |
+
while (
|
690 |
+
module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
|
691 |
+
).exists():
|
692 |
+
module_occurence_num += 1
|
693 |
+
torch.save(module_input, module_input_filepath)
|
694 |
+
|
695 |
+
module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
|
696 |
+
torch.save(output, module_output_filepath)
|
697 |
+
|
698 |
+
output_hooks = []
|
699 |
+
for module_name, module in self.model.named_modules(prefix="model"):
|
700 |
+
output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
|
701 |
+
|
702 |
+
return output_hooks
|
703 |
+
|
    def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        # Labels are just input IDs shifted to the left (first item is ignored).
        labels, label_mask, attention_mask, instance_mask = (
            batch["input_ids"].clone(),
            batch.get("label_mask"),
            batch.get("attention_mask"),
            batch.get("instance_mask"),
        )
        if label_mask is not None:
            labels.masked_fill_(~label_mask, -100)
        if attention_mask is not None:
            labels.masked_fill_(attention_mask == 0.0, -100)
        if instance_mask is not None:
            labels.masked_fill_(~instance_mask.unsqueeze(-1), value=-100)
        return labels[..., 1:].contiguous()

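# Illustrative sketch (not part of train.py): what the left shift in get_labels() does on
# a tiny made-up example. Labels drop the first position while model_forward() below drops
# the last position of the logits, so position i's logits are scored against token i+1;
# -100 marks positions the loss should ignore.
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
label_mask = torch.tensor([[True, True, False, True]])

labels = input_ids.clone()
labels.masked_fill_(~label_mask, -100)
labels = labels[..., 1:].contiguous()  # -> tensor([[   6, -100,    8]])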
720 |
+
def model_forward(
|
721 |
+
self, batch: Dict[str, Any], loss_reduction: str = "mean", compute_z_loss: bool = False
|
722 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
|
723 |
+
# shape: (batch_size, seq_len, vocab_size)
|
724 |
+
logits = self.dist_model(
|
725 |
+
input_ids=batch["input_ids"],
|
726 |
+
attention_mask=batch.get("attention_mask"),
|
727 |
+
attention_bias=batch.get("attention_bias"),
|
728 |
+
doc_lens=batch.get("doc_lens"),
|
729 |
+
max_doc_lens=batch.get("max_doc_lens"),
|
730 |
+
).logits
|
731 |
+
logits_for_loss = logits[..., :-1, :].contiguous()
|
732 |
+
# shape: (batch_size * seq_len, vocab_size)
|
733 |
+
logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
|
734 |
+
# shape: (batch_size, seq_len)
|
735 |
+
labels = self.get_labels(batch)
|
736 |
+
# shape: (batch_size * seq_len,)
|
737 |
+
labels = labels.view(-1)
|
738 |
+
ce_loss, z_loss = self.loss_fn(
|
739 |
+
logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction, compute_z_loss=compute_z_loss
|
740 |
+
)
|
741 |
+
if loss_reduction == "none":
|
742 |
+
# Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
|
743 |
+
ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
|
744 |
+
if z_loss is not None:
|
745 |
+
z_loss = z_loss.view(batch["input_ids"].shape[0], -1)
|
746 |
+
return ce_loss, z_loss, logits
|
747 |
+
|
748 |
+
def train_micro_batch(
|
749 |
+
self, micro_batch: Dict[str, Any], batch_size_in_tokens: int
|
750 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
751 |
+
ce_loss, z_loss, logits = self.model_forward(
|
752 |
+
micro_batch, compute_z_loss=self.cfg.softmax_auxiliary_loss, loss_reduction="sum"
|
753 |
+
)
|
754 |
+
ce_loss = ce_loss / batch_size_in_tokens
|
755 |
+
|
756 |
+
# In case this helps with memory utilization.
|
757 |
+
del micro_batch
|
758 |
+
|
759 |
+
# Get loss to optimize for.
|
760 |
+
if self.cfg.softmax_auxiliary_loss:
|
761 |
+
assert z_loss is not None
|
762 |
+
z_loss = z_loss / batch_size_in_tokens
|
763 |
+
loss = ce_loss + z_loss
|
764 |
+
else:
|
765 |
+
loss = ce_loss
|
766 |
+
|
767 |
+
del logits
|
768 |
+
|
769 |
+
return loss, ce_loss, z_loss
|
770 |
+
|
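# Illustrative sketch (not part of train.py): how the per-micro-batch losses above combine.
# Because model_forward() is called with loss_reduction="sum", dividing each micro-batch's
# summed CE by the token count of the full (pre-split) batch means the micro-batch losses
# add up to the mean token loss of that batch; the optional z-loss is scaled the same way
# before being added. The tensor values here are made up.
import torch

batch_size_in_tokens = 8          # tokens in the full, un-split batch
ce_sum = torch.tensor(19.2)       # summed token CE over one micro-batch
z_sum = torch.tensor(0.4)         # summed auxiliary z-loss over the same tokens

ce_loss = ce_sum / batch_size_in_tokens
loss = ce_loss + z_sum / batch_size_in_tokens  # what .backward() is called on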
771 |
+
def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
772 |
+
# Split into micro-batches.
|
773 |
+
micro_batches = self.split_batch(batch)
|
774 |
+
batch_size_in_tokens = batch["input_ids"].numel()
|
775 |
+
|
776 |
+
# In case this helps with memory utilization.
|
777 |
+
del batch
|
778 |
+
|
779 |
+
ce_batch_loss = torch.tensor(0.0, device=self.device)
|
780 |
+
z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
|
781 |
+
num_micro_batches = len(micro_batches)
|
782 |
+
|
783 |
+
for micro_batch_idx, micro_batch in enumerate(micro_batches):
|
784 |
+
# setup sync context for DDP for all micro-batches except the last
|
785 |
+
grad_sync_context = nullcontext
|
786 |
+
if (
|
787 |
+
self.cfg.distributed_strategy == DistributedStrategy.ddp
|
788 |
+
and self.cfg.ddp is not None
|
789 |
+
and self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch
|
790 |
+
):
|
791 |
+
if micro_batch_idx != num_micro_batches - 1:
|
792 |
+
grad_sync_context = self.dist_model.no_sync
|
793 |
+
|
794 |
+
# Register output hooks
|
795 |
+
output_hooks: List[torch.utils.hooks.RemovableHandle] = []
|
796 |
+
output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)
|
797 |
+
|
798 |
+
with grad_sync_context():
|
799 |
+
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
|
800 |
+
# Run forward pass.
|
801 |
+
loss, ce_loss, z_loss = self.train_micro_batch(micro_batch, batch_size_in_tokens)
|
802 |
+
|
803 |
+
# Update overall CE batch loss.
|
804 |
+
ce_batch_loss += ce_loss.detach()
|
805 |
+
|
806 |
+
# Update overall Z batch loss.
|
807 |
+
if z_loss is not None:
|
808 |
+
assert z_batch_loss is not None
|
809 |
+
z_batch_loss += z_loss.detach()
|
810 |
+
|
811 |
+
# Run backward pass.
|
812 |
+
loss.backward()
|
813 |
+
|
814 |
+
# Remove output hooks
|
815 |
+
for hook in output_hooks:
|
816 |
+
hook.remove()
|
817 |
+
|
818 |
+
return ce_batch_loss, z_batch_loss
|
819 |
+
|
820 |
+
def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
|
821 |
+
metrics: Dict[str, float] = {}
|
822 |
+
|
823 |
+
# Write data-indices to file.
|
824 |
+
if self.indices_file is not None and "index" in batch:
|
825 |
+
indices = "\t".join(str(int(i)) for i in batch["index"])
|
826 |
+
self.indices_file.write(f"{self.global_step}\t{indices}\n")
|
827 |
+
|
828 |
+
# Record how many instances are going to be skipped (masked out).
|
829 |
+
if (instance_mask := batch.get("instance_mask")) is not None:
|
830 |
+
metrics["train/masked_instances_local_rank"] = (~instance_mask).sum().item()
|
831 |
+
|
832 |
+
# Zero-gradients.
|
833 |
+
self.optim.zero_grad(set_to_none=True)
|
834 |
+
|
835 |
+
# Move tensors to the right device.
|
836 |
+
batch = move_to_device(batch, self.device)
|
837 |
+
|
838 |
+
# Run forward-backward pass.
|
839 |
+
ce_batch_loss, z_batch_loss = self.train_batch(batch)
|
840 |
+
|
841 |
+
# Collect loss, potentially reducing over all ranks.
|
842 |
+
if reduce_global_loss:
|
843 |
+
dist.reduce(ce_batch_loss, 0)
|
844 |
+
ce_batch_loss.div_(get_world_size())
|
845 |
+
if z_batch_loss is not None:
|
846 |
+
dist.reduce(z_batch_loss, 0)
|
847 |
+
z_batch_loss.div_(get_world_size())
|
848 |
+
|
849 |
+
# Clip gradient norms and collect param/gradient/optim metrics.
|
850 |
+
should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
|
851 |
+
optim_metrics = self.optim.clip_grads_and_collect_metrics(
|
852 |
+
self.global_step,
|
853 |
+
collect_param_metrics=should_log_optim_metrics_this_step,
|
854 |
+
# passing this process group here ensures metrics are reduced correctly when we're using
|
855 |
+
# HYBRID sharding.
|
856 |
+
process_group=self.dist_model.process_group,
|
857 |
+
)
|
858 |
+
|
859 |
+
# Adjust the learning rate.
|
860 |
+
for group in self.optim.param_groups:
|
861 |
+
# TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
|
862 |
+
# we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
|
863 |
+
# the corresponding values from `self.cfg`.
|
864 |
+
group["lr"] = self.scheduler.get_lr(
|
865 |
+
self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
|
866 |
+
)
|
867 |
+
group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
|
868 |
+
self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
|
869 |
+
)
|
870 |
+
group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
|
871 |
+
self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
|
872 |
+
)
|
873 |
+
|
874 |
+
# Optimizer step.
|
875 |
+
self.optim.step()
|
876 |
+
|
877 |
+
# Collect metrics and check for NaN loss.
|
878 |
+
# NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
|
879 |
+
if torch.isnan(ce_batch_loss):
|
880 |
+
raise ValueError("nan loss encountered")
|
881 |
+
if z_batch_loss is not None and torch.isnan(z_batch_loss):
|
882 |
+
raise ValueError("nan loss encountered")
|
883 |
+
for key, value in optim_metrics.items():
|
884 |
+
metrics[f"optim/{key}"] = value.item()
|
885 |
+
self.cur_train_loss = ce_batch_loss.item()
|
886 |
+
self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
|
887 |
+
metrics["train/CrossEntropyLoss"] = self.cur_train_loss
|
888 |
+
metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
|
889 |
+
if z_batch_loss is not None:
|
890 |
+
metrics["train/ZLoss"] = z_batch_loss.item()
|
891 |
+
|
892 |
+
# Maybe collect post-step optimizer-specific metrics.
|
893 |
+
if should_log_optim_metrics_this_step:
|
894 |
+
optim_metrics = self.optim.get_post_step_metrics(
|
895 |
+
self.dist_model, process_group=self.dist_model.process_group
|
896 |
+
)
|
897 |
+
for key, value in optim_metrics.items():
|
898 |
+
metrics[f"optim/{key}"] = value.item()
|
899 |
+
|
900 |
+
return metrics
|
901 |
+
|
902 |
+
def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
|
903 |
+
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
|
904 |
+
ce_loss, _, logits = self.model_forward(batch, loss_reduction="none")
|
905 |
+
return ce_loss.mean(dim=-1), logits
|
906 |
+
|
907 |
+
def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
|
908 |
+
# Move tensors to the right device.
|
909 |
+
batch = move_to_device(batch, self.device)
|
910 |
+
|
911 |
+
# Run forward pass.
|
912 |
+
with torch.no_grad(): # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
|
913 |
+
ce_loss, logits = self.eval_batch(batch)
|
914 |
+
|
915 |
+
# Update metrics.
|
916 |
+
evaluator.update_metrics(
|
917 |
+
batch, ce_loss, logits
|
918 |
+
) # batch includes all keys that the downstream evaluation needs
|
919 |
+
|
920 |
+
barrier()
|
921 |
+
|
    def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
        microbatch_size = self.cfg.device_train_microbatch_size
        batch_size = batch["input_ids"].shape[0]
        if batch_size <= microbatch_size:
            return [batch]
        else:
            micro_batches = {}
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    micro_batches[key] = value.split(microbatch_size, dim=0)
                elif isinstance(value, list):
                    micro_batches[key] = [
                        value[microbatch_size * i : microbatch_size * i + microbatch_size]
                        for i in range(math.ceil(batch_size / microbatch_size))
                    ]
                else:
                    raise ValueError(f"unexpected item in batch: '{key}={value}'")
            return [
                {key: value[i] for key, value in micro_batches.items()}  # type: ignore
                for i in range(len(micro_batches["input_ids"]))
            ]

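# Illustrative sketch (not part of train.py): what a split_batch-style helper produces for
# gradient accumulation. With device_train_microbatch_size=2, a batch of 4 sequences yields
# two micro-batches whose tensors are views created by torch.split(); the values below are
# made up.
import torch

batch = {
    "input_ids": torch.arange(12).reshape(4, 3),
    "index": [10, 11, 12, 13],
}
microbatch_size = 2
micro_input_ids = batch["input_ids"].split(microbatch_size, dim=0)
micro_indices = [batch["index"][i * microbatch_size : (i + 1) * microbatch_size] for i in range(2)]
micro_batches = [{"input_ids": micro_input_ids[i], "index": micro_indices[i]} for i in range(2)]
assert micro_batches[0]["input_ids"].shape == (2, 3)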
    def system_metrics(self) -> Dict[str, float]:
        metrics = {}
        if self.global_step < 3 or self.global_step % 10 == 0:
            peak_gpu_mb = peak_gpu_memory()
            if peak_gpu_mb is not None:
                metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
        return metrics

    def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
        def format_float(value: float) -> str:
            if value < 0.0001:
                return str(value)  # scientific notation
            elif value > 1000:
                return f"{int(value):,d}"
            elif value > 100:
                return f"{value:.1f}"
            elif value > 10:
                return f"{value:.2f}"
            elif value > 1:
                return f"{value:.3f}"
            else:
                return f"{value:.4f}"

        log.info(
            f"{prefix}\n"
            + "\n".join(
                [
                    f"    {name}={format_float(value)}"
                    for name, value in metrics.items()
                    if name == "optim/total_grad_norm"
                    or not name.startswith("optim/")  # there are too many optimizer metrics
                ]
            )
        )

    def should_log_optim_metrics_this_step(self) -> bool:
        if self.cfg.wandb is None:
            # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
            # to log to the console.
            return False
        optim_log_interval = self.cfg.optimizer.metrics_log_interval
        if optim_log_interval is None:
            optim_log_interval = self.cfg.wandb.log_interval
        else:
            optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
        return self.global_step % optim_log_interval == 0

    def should_log_this_step(self) -> bool:
        if self.global_step % self.cfg.console_log_interval == 0:
            return True
        elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
            return True
        else:
            return False
998 |
+
|
999 |
+
def eval(self) -> Dict[str, Any]:
|
1000 |
+
# Zero gradients and set model to 'eval' mode.
|
1001 |
+
self.optim.zero_grad(set_to_none=True)
|
1002 |
+
self.dist_model.eval()
|
1003 |
+
|
1004 |
+
eval_metrics = {}
|
1005 |
+
for evaluator in self.evaluators:
|
1006 |
+
log.info(f"Running evaluation for '{evaluator.label}'...")
|
1007 |
+
|
1008 |
+
# Reset metrics.
|
1009 |
+
evaluator.reset_metrics()
|
1010 |
+
|
1011 |
+
# Initialize data loader iterator.
|
1012 |
+
eval_batches = iter(evaluator.eval_loader)
|
1013 |
+
|
1014 |
+
# Adjust how many batches to evaluate on.
|
1015 |
+
num_eval_batches = (
|
1016 |
+
evaluator.subset_num_batches
|
1017 |
+
if evaluator.subset_num_batches is not None
|
1018 |
+
else self.cfg.eval_subset_num_batches
|
1019 |
+
)
|
1020 |
+
if num_eval_batches > 0:
|
1021 |
+
num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
|
1022 |
+
eval_batches = islice(eval_batches, num_eval_batches)
|
1023 |
+
|
1024 |
+
# Run model over batches.
|
1025 |
+
for eval_step, eval_batch in enumerate(eval_batches):
|
1026 |
+
self.eval_step(eval_batch, evaluator)
|
1027 |
+
|
1028 |
+
# Log to console.
|
1029 |
+
if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
|
1030 |
+
log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")
|
1031 |
+
|
1032 |
+
# Get final metrics.
|
1033 |
+
metrics = evaluator.compute_metrics()
|
1034 |
+
eval_metrics.update(metrics)
|
1035 |
+
self.log_metrics_to_console(f"{evaluator.label}", metrics)
|
1036 |
+
|
1037 |
+
del eval_batches
|
1038 |
+
|
1039 |
+
# Eval triggers a bunch of additional torch.compile graph versions and the resulting performance is terrible, so reset the compiler state to get back to a clean slate.
|
1040 |
+
if self.cfg.compile is not None:
|
1041 |
+
torch.compiler.reset()
|
1042 |
+
|
1043 |
+
return eval_metrics
|
1044 |
+
|
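# Illustrative sketch (not part of train.py): capping an evaluation loop with
# itertools.islice, as eval() does above. islice stops the iterator after the requested
# number of batches without materializing the whole loader; the helper name and signature
# are hypothetical, and the loader is assumed to be sized (supports len()).
from itertools import islice

def iter_limited_eval(eval_loader, num_eval_batches: int):
    eval_batches = iter(eval_loader)
    if num_eval_batches > 0:
        eval_batches = islice(eval_batches, min(num_eval_batches, len(eval_loader)))
    for step, batch in enumerate(eval_batches):
        yield step, batch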
1045 |
+
def check_if_cancelled(self) -> Tuple[bool, int]:
|
1046 |
+
should_cancel = False
|
1047 |
+
cancel_reason: Optional[str] = None
|
1048 |
+
extra_steps = 0
|
1049 |
+
if get_global_rank() == 0:
|
1050 |
+
if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
|
1051 |
+
# First check if we've reached the training time limit.
|
1052 |
+
should_cancel = True
|
1053 |
+
cancel_reason = "time limit reached"
|
1054 |
+
extra_steps = self.cfg.extra_steps_after_cancel
|
1055 |
+
elif (
|
1056 |
+
self.cfg.early_stopping_factor is not None
|
1057 |
+
and self.global_step > self.cfg.scheduler.t_warmup
|
1058 |
+
and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
|
1059 |
+
):
|
1060 |
+
# Next, check if the early-stopping loss criterion is met.
|
1061 |
+
should_cancel = True
|
1062 |
+
cancel_reason = "early stopping from loss increase"
|
1063 |
+
elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
|
1064 |
+
# Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.
|
1065 |
+
# We won't see it in the run object. So we have to use the import/export API to check.
|
1066 |
+
from requests.exceptions import RequestException
|
1067 |
+
from wandb.errors import CommError
|
1068 |
+
|
1069 |
+
try:
|
1070 |
+
api = wandb.Api(api_key=api_key)
|
1071 |
+
run = api.run(wandb.run.path)
|
1072 |
+
for tag in run.tags or []:
|
1073 |
+
if tag.lower() in {"cancel", "canceled", "cancelled"}:
|
1074 |
+
should_cancel = True
|
1075 |
+
cancel_reason = "Weights & Biases tag"
|
1076 |
+
extra_steps = self.cfg.extra_steps_after_cancel
|
1077 |
+
break
|
1078 |
+
except (RequestException, CommError):
|
1079 |
+
log.info("Failed to check if W&B run is cancelled, continuing run.")
|
1080 |
+
|
1081 |
+
run_canceled = synchronize_flag(should_cancel, self.device)
|
1082 |
+
if run_canceled:
|
1083 |
+
extra_steps = synchronize_value(extra_steps, self.device)
|
1084 |
+
if cancel_reason is None:
|
1085 |
+
if extra_steps > 0:
|
1086 |
+
log.warning(f"Run canceled, stopping in {extra_steps} more steps...")
|
1087 |
+
else:
|
1088 |
+
log.warning("Run canceled")
|
1089 |
+
else:
|
1090 |
+
if extra_steps > 0:
|
1091 |
+
log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
|
1092 |
+
else:
|
1093 |
+
log.warning(f"Run canceled due to {cancel_reason}")
|
1094 |
+
|
1095 |
+
return run_canceled, extra_steps
|
1096 |
+
|
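# Illustrative sketch (not part of train.py): the rank-0-decides, everyone-agrees pattern
# behind check_if_cancelled(). synchronize_flag/synchronize_value are the helpers used
# above; a minimal equivalent with plain torch.distributed might look like this, assuming
# the process group is already initialized.
import torch
import torch.distributed as dist

def broadcast_flag(flag: bool, device: torch.device) -> bool:
    t = torch.tensor(1 if flag else 0, device=device)
    dist.broadcast(t, src=0)  # rank 0's decision wins everywhere
    return bool(t.item())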
1097 |
+
def fit(self):
|
1098 |
+
if self.cfg.stop_after is not None:
|
1099 |
+
if self.cfg.stop_at is None:
|
1100 |
+
self.cfg.stop_at = self.global_step + self.cfg.stop_after
|
1101 |
+
else:
|
1102 |
+
self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)
|
1103 |
+
if self.cfg.stop_at is None:
|
1104 |
+
self.cfg.stop_at = self.max_steps + 10
|
1105 |
+
|
1106 |
+
self._start_time = time.time()
|
1107 |
+
self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.
|
1108 |
+
|
1109 |
+
# Disable automatic garbage collection; FSDP doesn't work well with it.
|
1110 |
+
if self.cfg.gen1_gc_interval is not None:
|
1111 |
+
gc.disable()
|
1112 |
+
|
1113 |
+
if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
|
1114 |
+
eval_metrics = self.eval()
|
1115 |
+
if wandb.run is not None:
|
1116 |
+
wandb.log(eval_metrics, step=self.global_step)
|
1117 |
+
|
1118 |
+
# Set model to 'train' mode.
|
1119 |
+
self.dist_model.train()
|
1120 |
+
|
1121 |
+
# Initialize monitors.
|
1122 |
+
assert self.cfg.device_train_batch_size is not None
|
1123 |
+
speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
|
1124 |
+
lr_monitor = LRMonitor(self.optim)
|
1125 |
+
|
1126 |
+
# Log system metrics at the start of training.
|
1127 |
+
sys_metrics = self.system_metrics()
|
1128 |
+
if sys_metrics:
|
1129 |
+
self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
|
1130 |
+
if wandb.run is not None:
|
1131 |
+
wandb.log(sys_metrics, step=0)
|
1132 |
+
|
1133 |
+
# Python Profiler stuff
|
1134 |
+
if self.cfg.python_profiling:
|
1135 |
+
python_profiler = cProfile.Profile()
|
1136 |
+
else:
|
1137 |
+
python_profiler = None
|
1138 |
+
|
1139 |
+
# PyTorch Profiler stuff
|
1140 |
+
if self.cfg.torch_profiling and get_global_rank() == 0:
|
1141 |
+
from torch.profiler import schedule
|
1142 |
+
|
1143 |
+
profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
|
1144 |
+
|
1145 |
+
def on_trace_ready(p):
|
1146 |
+
profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
|
1147 |
+
profiler_output_dir.mkdir(exist_ok=True)
|
1148 |
+
|
1149 |
+
output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
|
1150 |
+
log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
|
1151 |
+
output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
|
1152 |
+
log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
|
1153 |
+
|
1154 |
+
p.export_chrome_trace(
|
1155 |
+
str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
|
1156 |
+
)
|
1157 |
+
if self.cfg.remote_save_folder is not None:
|
1158 |
+
upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
|
1159 |
+
log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
|
1160 |
+
upload(trace_path, f"{upload_folder}/{trace_path.name}")
|
1161 |
+
|
1162 |
+
from torch.profiler import ProfilerActivity
|
1163 |
+
|
1164 |
+
torch_profiler = torch.profiler.profile(
|
1165 |
+
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
1166 |
+
record_shapes=False,
|
1167 |
+
profile_memory=False,
|
1168 |
+
with_stack=True,
|
1169 |
+
schedule=profiling_schedule,
|
1170 |
+
on_trace_ready=on_trace_ready,
|
1171 |
+
)
|
1172 |
+
del profiling_schedule
|
1173 |
+
else:
|
1174 |
+
import contextlib
|
1175 |
+
|
1176 |
+
torch_profiler = contextlib.nullcontext()
|
1177 |
+
|
1178 |
+
# Train.
|
1179 |
+
first_batch: bool = True
|
1180 |
+
cancel_initiated: bool = False
|
1181 |
+
stop_at: int = self.cfg.stop_at
|
1182 |
+
save_checkpoints: bool = True
|
1183 |
+
|
1184 |
+
with torch_profiler as p:
|
1185 |
+
for epoch in range(self.epoch or 0, self.max_epochs):
|
1186 |
+
for batch in self.train_loader:
|
1187 |
+
# Bookkeeping.
|
1188 |
+
# NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
|
1189 |
+
# batches see the same number of tokens, which should be the case for language model pre-training
|
1190 |
+
# (at least when drop_last=True).
|
1191 |
+
# Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
|
1192 |
+
# overhead. So for now I'm putting these assertions here so if the assumption is violated it will
|
1193 |
+
# fail loudly.
|
1194 |
+
batch_size, seq_len = batch["input_ids"].shape
|
1195 |
+
assert seq_len == self.cfg.model.max_sequence_length
|
1196 |
+
assert batch_size == self.cfg.device_train_batch_size
|
1197 |
+
global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks
|
1198 |
+
self.global_step += 1
|
1199 |
+
self.global_train_examples_seen_this_epoch += global_batch_size
|
1200 |
+
self.global_train_tokens_seen += global_batch_size * seq_len
|
1201 |
+
speed_monitor.batch_start(
|
1202 |
+
global_total_tokens=self.global_train_tokens_seen,
|
1203 |
+
device_batch_num_tokens=batch_size * seq_len, # num tokens in batch for this device
|
1204 |
+
# We start monitoring speed after the first batch since the first
|
1205 |
+
# batch might be an outlier due to compiling and other initialization overhead.
|
1206 |
+
num_fwd_flops=self.model.num_fwd_flops, # this is per token
|
1207 |
+
num_bck_flops=self.model.num_bck_flops, # this is per token
|
1208 |
+
record=not first_batch,
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
should_log_this_step = self.should_log_this_step()
|
1212 |
+
|
1213 |
+
# Run train step on batch.
|
1214 |
+
metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
|
1215 |
+
|
1216 |
+
# Maybe collect other metrics.
|
1217 |
+
if should_log_this_step:
|
1218 |
+
# Speed metrics.
|
1219 |
+
metrics.update(speed_monitor.check())
|
1220 |
+
# System metrics.
|
1221 |
+
metrics.update(self.system_metrics())
|
1222 |
+
# Learning rate metrics.
|
1223 |
+
metrics.update(lr_monitor.check())
|
1224 |
+
|
1225 |
+
# Log metrics to console.
|
1226 |
+
if self.global_step % self.cfg.console_log_interval == 0:
|
1227 |
+
if get_global_rank() == 0:
|
1228 |
+
self.log_metrics_to_console(
|
1229 |
+
f"[step={self.global_step}/{self.max_steps},epoch={epoch}]",
|
1230 |
+
metrics,
|
1231 |
+
)
|
1232 |
+
else:
|
1233 |
+
log.info(f"[step={self.global_step}/{self.max_steps},epoch={epoch}]")
|
1234 |
+
|
1235 |
+
# Log metrics to W&B.
|
1236 |
+
if (
|
1237 |
+
wandb.run is not None
|
1238 |
+
and self.cfg.wandb is not None
|
1239 |
+
and self.global_step % self.cfg.wandb.log_interval == 0
|
1240 |
+
):
|
1241 |
+
wandb.log(metrics, step=self.global_step)
|
1242 |
+
|
1243 |
+
# Check if/when run should be canceled.
|
1244 |
+
if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
|
1245 |
+
cancel_initiated, extra_steps = self.check_if_cancelled()
|
1246 |
+
if cancel_initiated:
|
1247 |
+
stop_at = min(stop_at, self.global_step + extra_steps)
|
1248 |
+
|
1249 |
+
# Maybe save sharded checkpoint.
|
1250 |
+
if self.cfg.distributed_strategy != DistributedStrategy.ddp:
|
1251 |
+
if save_checkpoints and (
|
1252 |
+
cancel_initiated
|
1253 |
+
or (
|
1254 |
+
self.cfg.save_interval is not None
|
1255 |
+
and self.global_step % self.cfg.save_interval == 0
|
1256 |
+
and self.cfg.save_num_checkpoints_to_keep != 0
|
1257 |
+
)
|
1258 |
+
):
|
1259 |
+
log.info("Saving checkpoint...")
|
1260 |
+
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
|
1261 |
+
log.info(f"Checkpoint saved to {checkpoint_path}")
|
1262 |
+
|
1263 |
+
# Remove any ephemeral checkpoints.
|
1264 |
+
while self.ephemeral_checkpoints:
|
1265 |
+
self.remove_ephemeral_checkpoint()
|
1266 |
+
|
1267 |
+
# Reset speed monitor so that we don't count the time taken to save checkpoints.
|
1268 |
+
speed_monitor.reset()
|
1269 |
+
|
1270 |
+
# If the run was just canceled this will be the final checkpoint.
|
1271 |
+
if cancel_initiated:
|
1272 |
+
save_checkpoints = False
|
1273 |
+
elif (
|
1274 |
+
self.cfg.save_interval_ephemeral is not None
|
1275 |
+
and self.global_step % self.cfg.save_interval_ephemeral == 0
|
1276 |
+
):
|
1277 |
+
log.info("Saving ephemeral checkpoint...")
|
1278 |
+
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
|
1279 |
+
log.info(f"Checkpoint saved to {checkpoint_path}")
|
1280 |
+
|
1281 |
+
# Reset speed monitor so that we don't count the time taken to save checkpoints.
|
1282 |
+
speed_monitor.reset()
|
1283 |
+
|
1284 |
+
# Maybe save unsharded checkpoint.
|
1285 |
+
# This code snippet should always execute when running DDP.
|
1286 |
+
if (
|
1287 |
+
save_checkpoints
|
1288 |
+
and self.cfg.save_interval_unsharded is not None
|
1289 |
+
and self.global_step % self.cfg.save_interval_unsharded == 0
|
1290 |
+
and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
|
1291 |
+
):
|
1292 |
+
log.info("Saving unsharded checkpoint...")
|
1293 |
+
checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
|
1294 |
+
log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
|
1295 |
+
|
1296 |
+
# Reset speed monitor so that we don't count the time taken to save checkpoints.
|
1297 |
+
speed_monitor.reset()
|
1298 |
+
|
1299 |
+
# Maybe run evaluations.
|
1300 |
+
if not cancel_initiated and (
|
1301 |
+
self.global_step % self.cfg.eval_interval == 0 or self.global_step >= stop_at
|
1302 |
+
):
|
1303 |
+
eval_metrics = self.eval()
|
1304 |
+
|
1305 |
+
# Log metrics to W&B.
|
1306 |
+
if wandb.run is not None:
|
1307 |
+
wandb.log(eval_metrics, step=self.global_step)
|
1308 |
+
|
1309 |
+
# Reset speed monitor so that we don't count the time taken to run evaluations.
|
1310 |
+
speed_monitor.reset()
|
1311 |
+
|
1312 |
+
# Reset model to 'train' mode.
|
1313 |
+
self.dist_model.train()
|
1314 |
+
|
1315 |
+
# End of batch.
|
1316 |
+
first_batch = False
|
1317 |
+
if p is not None:
|
1318 |
+
p.step()
|
1319 |
+
|
1320 |
+
if self.global_step >= stop_at:
|
1321 |
+
break
|
1322 |
+
|
1323 |
+
# Run generation 1 garbage collection.
|
1324 |
+
if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
|
1325 |
+
gc.collect(1)
|
1326 |
+
|
1327 |
+
# Python Profiler stuff
|
1328 |
+
# We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
|
1329 |
+
if python_profiler is not None:
|
1330 |
+
if self.global_step == 5:
|
1331 |
+
python_profiler.enable()
|
1332 |
+
elif self.global_step == 8:
|
1333 |
+
python_profiler.disable()
|
1334 |
+
python_profiler.print_stats(sort=SortKey.CUMULATIVE)
|
1335 |
+
python_profiler = None
|
1336 |
+
else:
|
1337 |
+
log.info("Training epoch complete")
|
1338 |
+
self.epoch = epoch + 1
|
1339 |
+
self.global_train_examples_seen_this_epoch = 0
|
1340 |
+
self.dataset.start_index = 0
|
1341 |
+
if self.epoch < self.max_epochs:
|
1342 |
+
log.info(f"Reshuffling data loader for epoch {self.epoch}...")
|
1343 |
+
self.dataset.reshuffle(self.epoch)
|
1344 |
+
continue
|
1345 |
+
|
1346 |
+
break
|
1347 |
+
|
1348 |
+
# Save final checkpoint.
|
1349 |
+
if save_checkpoints:
|
1350 |
+
if (
|
1351 |
+
self.cfg.save_interval_unsharded is not None
|
1352 |
+
and self.last_unsharded_checkpoint_step != self.global_step
|
1353 |
+
):
|
1354 |
+
log.info("Saving final unsharded model checkpoint...")
|
1355 |
+
checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
|
1356 |
+
log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
|
1357 |
+
elif (
|
1358 |
+
self.cfg.save_num_checkpoints_to_keep != 0
|
1359 |
+
and self.last_sharded_checkpoint_step != self.global_step
|
1360 |
+
and self.cfg.distributed_strategy == DistributedStrategy.fsdp
|
1361 |
+
):
|
1362 |
+
log.info("Saving final checkpoint...")
|
1363 |
+
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
|
1364 |
+
log.info(f"Checkpoint saved to {checkpoint_path}")
|
1365 |
+
|
    def close(self, exit_code: int = 0) -> None:
        gc_cuda()

        if self.indices_file is not None:
            self.indices_file.flush()
            self.indices_file.close()
        if self._gc_init_state:
            gc.enable()
        else:
            gc.disable()
        if wandb.run is not None:
            wandb.finish(exit_code=exit_code, quiet=True)

    def __enter__(self) -> Trainer:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        del exc_val, exc_tb
        self.close(0 if exc_type is None else 1)
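# Illustrative sketch (not part of train.py): the __enter__/__exit__ pair above lets the
# trainer be used as a context manager, so W&B is finished and the GC state is restored
# even if training raises. Construction arguments are elided here on purpose.
#
# with Trainer(...) as trainer:
#     trainer.fit()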
util.py
ADDED
@@ -0,0 +1,929 @@
import gzip
import io
import json
import logging
import os
import re
import socket
import sys
import time
import warnings
from datetime import datetime
from enum import Enum
from itertools import cycle, islice
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Union

import boto3
import botocore.exceptions as boto_exceptions
import datasets
import requests
import rich
from botocore.config import Config
from cached_path.schemes import SchemeClient, add_scheme_client
from google.api_core.retry import Retry as GCSRetry
from google.api_core.retry import if_transient_error as gcs_is_transient_error
from rich.console import Console, ConsoleRenderable
from rich.highlighter import NullHighlighter
from rich.progress import Progress
from rich.text import Text
from rich.traceback import Traceback

from olmo_data.data import get_data_path

from .aliases import PathOrStr
from .exceptions import (
    OLMoCliError,
    OLMoEnvironmentError,
    OLMoError,
    OLMoNetworkError,
    OLMoThreadError,
)
from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed

try:
    from functools import cache
except ImportError:
    from functools import lru_cache as cache


class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include this here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"'{str(self)}'"


_log_extra_fields: Dict[str, Any] = {}
log = logging.getLogger(__name__)
67 |
+
|
68 |
+
|
69 |
+
class LogFilterType(StrEnum):
|
70 |
+
rank0_only = "rank0_only"
|
71 |
+
local_rank0_only = "local_rank0_only"
|
72 |
+
all_ranks = "all_ranks"
|
73 |
+
|
74 |
+
|
75 |
+
def log_extra_field(field_name: str, field_value: Any) -> None:
|
76 |
+
global _log_extra_fields
|
77 |
+
if field_value is None:
|
78 |
+
if field_name in _log_extra_fields:
|
79 |
+
del _log_extra_fields[field_name]
|
80 |
+
else:
|
81 |
+
_log_extra_fields[field_name] = field_value
|
82 |
+
|
83 |
+
|
84 |
+
def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
|
85 |
+
"""
|
86 |
+
:param log_filter_type: controls which ranks emit INFO-and-below messages; with the default ``rank0_only`` they are only emitted on the rank 0 process.
|
87 |
+
"""
|
88 |
+
log_extra_field("hostname", socket.gethostname())
|
89 |
+
if is_distributed():
|
90 |
+
log_extra_field("node_rank", get_node_rank())
|
91 |
+
log_extra_field("local_rank", get_local_rank())
|
92 |
+
log_extra_field("global_rank", get_global_rank())
|
93 |
+
else:
|
94 |
+
log_extra_field("node_rank", 0)
|
95 |
+
log_extra_field("local_rank", 0)
|
96 |
+
log_extra_field("global_rank", 0)
|
97 |
+
|
98 |
+
old_log_record_factory = logging.getLogRecordFactory()
|
99 |
+
|
100 |
+
def log_record_factory(*args, **kwargs) -> logging.LogRecord:
|
101 |
+
record = old_log_record_factory(*args, **kwargs)
|
102 |
+
for field_name, field_value in _log_extra_fields.items():
|
103 |
+
setattr(record, field_name, field_value)
|
104 |
+
return record
|
105 |
+
|
106 |
+
logging.setLogRecordFactory(log_record_factory)
|
107 |
+
|
108 |
+
handler: logging.Handler
|
109 |
+
if (
|
110 |
+
os.environ.get("OLMo_NONINTERACTIVE", False)
|
111 |
+
or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
        or not sys.stdout.isatty()
    ):
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
        )
        formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
        formatter.default_msec_format = "%s.%03d"
        handler.setFormatter(formatter)
    else:
        handler = RichHandler()

    def rank0_filter(record: logging.LogRecord) -> int:
        if record.levelno > logging.INFO:
            return 1
        if getattr(record, "global_rank", 0) == 0:
            return 1
        else:
            return 0

    def local_rank0_filter(record: logging.LogRecord) -> int:
        if record.levelno > logging.INFO:
            return 1
        if getattr(record, "local_rank", 0) == 0:
            return 1
        else:
            return 0

    if log_filter_type == LogFilterType.rank0_only:
        filter = rank0_filter
    elif log_filter_type == LogFilterType.local_rank0_only:
        filter = local_rank0_filter  # type: ignore
    elif log_filter_type == LogFilterType.all_ranks:
        filter = None
    else:
        raise ValueError(log_filter_type)

    if filter is not None:
        handler.addFilter(filter)  # type: ignore
    logging.basicConfig(handlers=[handler], level=logging.INFO)

    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)


def excepthook(exctype, value, traceback):
    """
    Used to patch `sys.excepthook` in order to log exceptions.
    """
    if issubclass(exctype, KeyboardInterrupt):
        sys.__excepthook__(exctype, value, traceback)
    elif issubclass(exctype, OLMoCliError):
        rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
    elif issubclass(exctype, OLMoError):
        rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
    else:
        log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))


def install_excepthook():
    sys.excepthook = excepthook


def filter_warnings():
    # Filter internal deprecation warnings from torch
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="torch.distributed.*_base is a private function and will be deprecated.*",
    )
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="TypedStorage is deprecated.*",
    )
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message="Please use DTensor instead.*",
    )
    # Torchvision warnings. We don't actually use torchvision.
    warnings.filterwarnings(
        action="ignore",
        message="failed to load.*",
        module="torchvision.io.image",
    )


def set_env_variables():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"


def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
    if log_filter_type is None:
        log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
    rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
    setup_logging(log_filter_type=log_filter_type)
    install_excepthook()
    filter_warnings()
    set_env_variables()


def clean_opt(arg: str) -> str:
    if "=" not in arg:
        arg = f"{arg}=True"
    name, val = arg.split("=", 1)
    name = name.strip("-").replace("-", "_")
    return f"{name}={val}"
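As a quick illustration of the normalization `clean_opt` performs on CLI overrides, a minimal sketch (the option names below are made up):

```
# Bare flags gain "=True"; leading dashes are stripped and "-" becomes "_".
clean_opt("--save_overwrite")                # -> "save_overwrite=True"
clean_opt("--optimizer.learning-rate=1e-4")  # -> "optimizer.learning_rate=1e-4"
```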
class RichHandler(logging.Handler):
    """
    A simplified version of rich.logging.RichHandler from
    https://github.com/Textualize/rich/blob/master/rich/logging.py
    """

    def __init__(
        self,
        *,
        level: Union[int, str] = logging.NOTSET,
        console: Optional[Console] = None,
        markup: bool = False,
    ) -> None:
        super().__init__(level=level)
        self.console = console or rich.get_console()
        self.highlighter = NullHighlighter()
        self.markup = markup

    def emit(self, record: logging.LogRecord) -> None:
        try:
            if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
                self.console.print(record.msg)
            else:
                msg: Any = record.msg
                if isinstance(record.msg, str):
                    msg = self.render_message(record=record, message=record.getMessage())
                renderables = [
                    self.get_time_text(record),
                    self.get_level_text(record),
                    self.get_location_text(record),
                    msg,
                ]
                if record.exc_info is not None:
                    tb = Traceback.from_exception(*record.exc_info)  # type: ignore
                    renderables.append(tb)
                self.console.print(*renderables)
        except Exception:
            self.handleError(record)

    def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
        use_markup = getattr(record, "markup", self.markup)
        message_text = Text.from_markup(message) if use_markup else Text(message)

        highlighter = getattr(record, "highlighter", self.highlighter)
        if highlighter:
            message_text = highlighter(message_text)

        return message_text

    def get_time_text(self, record: logging.LogRecord) -> Text:
        log_time = datetime.fromtimestamp(record.created)
        time_str = log_time.strftime("[%Y-%m-%d %X]")
        return Text(time_str, style="log.time", end=" ")

    def get_level_text(self, record: logging.LogRecord) -> Text:
        level_name = record.levelname
        level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
        level_text.style = "log.level"
        level_text.end = " "
        return level_text

    def get_location_text(self, record: logging.LogRecord) -> Text:
        name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
        text = f"[{name_and_line}, rank={record.local_rank}]"  # type: ignore
        return Text(text, style="log.path")
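For reference, a minimal sketch of driving this handler directly through the standard `logging` module. It assumes the record carries a `local_rank` attribute (normally injected by the record factory that `setup_logging` configures), so it is passed here via `extra`:

```
import logging

handler = RichHandler(markup=True)
logging.basicConfig(handlers=[handler], level=logging.INFO)
# "markup" and "local_rank" become attributes on the LogRecord.
logging.getLogger("example").info("[bold]hello[/]", extra={"local_rank": 0, "markup": True})
```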
def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
    """Wait for the condition function to return True."""
    start_time = time.monotonic()
    while not condition():
        time.sleep(0.5)
        if time.monotonic() - start_time > timeout:
            raise TimeoutError(f"{description} timed out")
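A minimal usage sketch for `wait_for`; the marker path is hypothetical:

```
import time
from pathlib import Path

# Poll every 0.5s until the marker file shows up, failing after 30 seconds.
wait_for(lambda: Path("/tmp/run-ready.marker").exists(), "waiting for run marker", timeout=30.0)
```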
def is_url(path: PathOrStr) -> bool:
    return re.match(r"[a-z0-9]+://.*", str(path)) is not None


def dir_is_empty(dir: PathOrStr) -> bool:
    dir = Path(dir)
    if not dir.is_dir():
        return True
    try:
        next(dir.glob("*"))
        return False
    except StopIteration:
        return True
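Illustrative behaviour of the two helpers above (paths and URLs are placeholders):

```
is_url("s3://my-bucket/checkpoints/step1000")  # True: matches "<scheme>://..."
is_url("/mnt/data/checkpoints")                # False: plain local path
dir_is_empty("/mnt/data/new-run")              # True if the dir is missing or has no entries
```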
def get_progress_bar() -> Progress:
    from cached_path import get_download_progress

    return get_download_progress()


def resource_path(
    folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
) -> Path:
    if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
        log.info(f"Found local cache of {fname} at {local_path}")
        return local_path
    else:
        from cached_path import cached_path

        return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)


def file_size(path: PathOrStr) -> int:
    """
    Get the size of a local or remote file in bytes.
    """
    if is_url(path):
        from urllib.parse import urlparse

        parsed = urlparse(str(path))
        if parsed.scheme == "gs":
            return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme in ("s3", "r2", "weka"):
            return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme in ("http", "https"):
            return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme == "file":
            return file_size(str(path).replace("file://", "", 1))
        else:
            raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
    else:
        return os.stat(path).st_size
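A sketch of how the scheme dispatch above plays out; the bucket names and paths are placeholders:

```
file_size("gs://my-bucket/model.pt")        # -> _gcs_file_size
file_size("r2://my-bucket/model.pt")        # -> _s3_file_size (R2 goes through the S3 client)
file_size("https://example.com/model.pt")   # -> _http_file_size
file_size("/local/path/model.pt")           # -> os.stat(...).st_size
```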
def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
    """Upload source file to a target location on GCS or S3."""
    from urllib.parse import urlparse

    source = Path(source)
    assert source.is_file()
    parsed = urlparse(target)
    if parsed.scheme == "gs":
        _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    elif parsed.scheme in ("s3", "r2", "weka"):
        _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    else:
        raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")


def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
    if is_url(source):
        from urllib.parse import urlparse

        parsed = urlparse(str(source))
        if parsed.scheme == "gs":
            return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
        elif parsed.scheme in ("s3", "r2", "weka"):
            return _s3_get_bytes_range(
                parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme in ("http", "https"):
            return _http_get_bytes_range(
                parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
            )
        elif parsed.scheme == "file":
            return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
        else:
            raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files")
    else:
        with open(source, "rb") as f:
            f.seek(bytes_start)
            return f.read(num_bytes)
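The local branch needs no credentials, so a minimal sketch with a hypothetical file:

```
# Read 16 bytes starting at offset 1024 from a local data file.
chunk = get_bytes_range("/data/part-000.npy", bytes_start=1024, num_bytes=16)
assert len(chunk) == 16
```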
def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
    if is_url(dir):
        from urllib.parse import urlparse

        parsed = urlparse(str(dir))
        if parsed.scheme == "gs":
            return _gcs_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme in ("s3", "r2", "weka"):
            return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
        elif parsed.scheme == "file":
            return find_latest_checkpoint(str(dir).replace("file://", "", 1))
        else:
            raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
    else:
        latest_step = 0
        latest_checkpoint: Optional[Path] = None
        for path in Path(dir).glob("step*"):
            if path.is_dir():
                try:
                    step = int(path.name.replace("step", "").replace("-unsharded", ""))
                except ValueError:
                    continue
                # We prioritize sharded checkpoints over unsharded checkpoints.
                if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
                    latest_step = step
                    latest_checkpoint = path
        return latest_checkpoint
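To make the selection rule concrete, a sketch with a hypothetical local run directory:

```
# runs/my-run/
# ├── step500/
# ├── step1000/
# └── step1000-unsharded/
#
# step1000 and step1000-unsharded tie on step count, so the sharded directory wins.
find_latest_checkpoint("runs/my-run")  # -> Path("runs/my-run/step1000")
```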
# Google Storage API is unhinged and requires you to specify the retry policy on every single call you make.
def _gcs_is_retriable(exception: Exception) -> bool:
    if gcs_is_transient_error(exception):
        return True
    if isinstance(exception, requests.exceptions.ReadTimeout):
        return True
    return False


_gcs_retry = GCSRetry(predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0)


def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    if not save_overwrite and blob.exists():
        raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
    blob.upload_from_filename(source, retry=_gcs_retry)


def _gcs_file_size(bucket_name: str, key: str) -> int:
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        blob.reload(retry=_gcs_retry)
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")
    assert blob.size is not None
    return blob.size


def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
    from google.api_core.exceptions import NotFound

    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(key)
    try:
        return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry)
    except NotFound:
        raise FileNotFoundError(f"gs://{bucket_name}/{key}")


@cache
def _get_gcs_client():
    from google.cloud import storage as gcs

    return gcs.Client()


def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
    if not prefix.endswith("/"):
        prefix = f"{prefix}/"

    storage_client = _get_gcs_client()
    bucket = storage_client.bucket(bucket_name)
    suffix = "/config.yaml"
    latest_step: Optional[int] = None
    latest_checkpoint: Optional[str] = None
    for blob in bucket.list_blobs(prefix=prefix, match_glob=f"**{suffix}"):
        # Disregard checkpoints that have an empty config file.
        if blob.size <= 0:
            continue

        name = blob.name[len(prefix) : -len(suffix)]

        if "/" in name:
            # We're not considering checkpoints in subdirectories.
            continue

        if not name.startswith("step"):
            continue
        name = name[4:]

        if name.endswith("-unsharded"):
            name = name[: -len("-unsharded")]

        try:
            step = int(name)
        except ValueError:
            continue

        # we prefer sharded checkpoints to unsharded ones
        if (
            latest_step is None
            or step > latest_step
            or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded"))
        ):
            latest_step = step
            latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}"

    return latest_checkpoint


def _get_s3_profile_name(scheme: str) -> Optional[str]:
    if scheme == "s3":
        # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
        return os.environ.get("S3_PROFILE")
    if scheme == "r2":
        profile_name = os.environ.get("R2_PROFILE")
        if profile_name is None:
            raise OLMoEnvironmentError(
                "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
            )

        return profile_name
    if scheme == "weka":
        profile_name = os.environ.get("WEKA_PROFILE")
        if profile_name is None:
            raise OLMoEnvironmentError(
                "Weka profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?"
            )

        return profile_name

    raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")


def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
    if scheme == "s3":
        return None
    if scheme == "r2":
        r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
        if r2_endpoint_url is None:
            raise OLMoEnvironmentError(
                "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
            )

        return r2_endpoint_url
    if scheme == "weka":
        weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL")
        if weka_endpoint_url is None:
            raise OLMoEnvironmentError(
                "Weka endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?"
            )

        return weka_endpoint_url

    raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
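Putting the two lookups together, a sketch of the environment a process would need before touching `r2://` or `weka://` URLs. All values below are placeholders, not real endpoints:

```
import os

os.environ["R2_PROFILE"] = "r2"  # profile name in ~/.aws/credentials
os.environ["R2_ENDPOINT_URL"] = "https://<account-id>.r2.cloudflarestorage.com"
os.environ["WEKA_PROFILE"] = "weka"
os.environ["WEKA_ENDPOINT_URL"] = "https://weka.example.internal:9000"
# Plain "s3://" URLs fall back to the default boto3 profile unless S3_PROFILE is set.
```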
@cache
def _get_s3_client(scheme: str):
    session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
    return session.client(
        "s3",
        endpoint_url=_get_s3_endpoint_url(scheme),
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
        use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
    )


def _wait_before_retry(attempt: int):
    time.sleep(min(0.5 * 2**attempt, 3.0))
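The backoff above caps quickly; evaluated for the three attempts the S3 helpers make by default:

```
for attempt in range(1, 4):
    print(attempt, min(0.5 * 2**attempt, 3.0))  # 1.0, 2.0, then capped at 3.0 seconds
```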
def _s3_upload(
    source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
):
    err: Optional[Exception] = None
    if not save_overwrite:
        for attempt in range(1, max_attempts + 1):
            try:
                _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
                raise FileExistsError(
                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                )
            except boto_exceptions.ClientError as e:
                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                    err = None
                    break
                err = e

            if attempt < max_attempts:
                log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
                _wait_before_retry(attempt)

        if err is not None:
            raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err

    try:
        _get_s3_client(scheme).upload_file(source, bucket_name, key)
    except boto_exceptions.ClientError as e:
        raise OLMoNetworkError(f"Failed to upload to {scheme}") from e


def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
        except boto_exceptions.ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e

        if attempt < max_attempts:
            log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
            _wait_before_retry(attempt)

    raise OLMoNetworkError(f"Failed to get {scheme} file size") from err


def _s3_get_bytes_range(
    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return (
                _get_s3_client(scheme)
                .get_object(
                    Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
                )["Body"]
                .read()
            )
        except boto_exceptions.ClientError as e:
            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
            err = e
        except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
            # ResponseStreamingError (subclass of HTTPClientError) can happen as
            # a result of a failed read from the stream (http.client.IncompleteRead).
            # Retrying can help in this case.
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
    # by recalling their constructor with a single message arg. Torch has some
    # logic to deal with the absence of a single-parameter constructor, but it
    # doesn't gracefully handle other possible failures in calling such a constructor.
    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
    # in us losing the true exception info. To avoid this, we change the exception
    # to a type that has a single-parameter constructor.
    raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
    if not prefix.endswith("/"):
        prefix = f"{prefix}/"
    response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
    assert not response["IsTruncated"]  # need to handle this if it happens
    latest_step = 0
    latest_checkpoint: Optional[str] = None
    for item in response.get("CommonPrefixes", []):
        prefix = item["Prefix"].strip("/")
        checkpoint_name = os.path.split(prefix)[-1]
        if not checkpoint_name.startswith("step"):
            continue
        try:
            step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
        except ValueError:
            continue
        # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
        # (upload might have failed part way through).
        try:
            _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
        except FileNotFoundError:
            continue
        # We prioritize sharded checkpoints over unsharded ones.
        if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
            latest_step = step
            latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
    return latest_checkpoint
def _http_file_size(scheme: str, host_name: str, path: str) -> int:
    import requests

    response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True)
    return int(response.headers.get("content-length"))


def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes:
    import requests

    response = requests.get(
        f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"}
    )
    result = response.content
    assert (
        len(result) == num_bytes
    ), f"expected {num_bytes} bytes, got {len(result)}"  # Some web servers silently ignore range requests and send everything
    return result
def save_hf_dataset_to_disk(
    dataset: datasets.DatasetDict | datasets.Dataset,
    hf_path: str,
    name: Optional[str],
    split: str,
    datasets_dir: PathOrStr,
):
    """
    Saves a HF dataset to disk under the `datasets_dir`. It can be used to add a HF dataset
    to `olmo_data` as follows:

    ```
    import datasets

    from olmo.util import save_hf_dataset_to_disk

    path, name, split = ...

    dataset = datasets.load_dataset(path, name=name, split=split)
    save_hf_dataset_to_disk(dataset, path, name, split, "olmo_data/hf_datasets")
    ```
    """
    dataset_path = Path(datasets_dir) / hf_path / (name or "none") / split
    return dataset.save_to_disk(str(dataset_path))


def load_hf_dataset(path: str, name: Optional[str], split: str):
    """
    Loads a HuggingFace dataset. The dataset is assumed to be saved using
    `save_hf_dataset_to_disk` and located in `olmo_data/hf_datasets`.
    """
    dataset_rel_path = os.path.join("hf_datasets", path, name or "none", split)
    with get_data_path(dataset_rel_path) as dataset_path:
        if not dataset_path.is_dir():
            raise NotADirectoryError(
                f"HF dataset {path} name {name} split {split} not found in directory {dataset_rel_path}"
            )
        return datasets.load_from_disk(str(dataset_path))
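A round-trip sketch matching the docstrings above; the dataset coordinates are placeholders:

```
# Assumes the dataset was previously written with save_hf_dataset_to_disk(...)
# into olmo_data/hf_datasets.
ds = load_hf_dataset("some/eval-dataset", name=None, split="validation")
print(len(ds))
```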
def load_oe_eval_requests(path: str, name: Optional[str] = None, split: Optional[str] = None):
    """
    Loads an oe-eval request file from `olmo_data/oe_eval_tasks`.
    TODO: Add support for loading from S3 instead?
    """
    dataset_rel_path = os.path.join("oe_eval_tasks", path)
    if name is not None:
        dataset_rel_path = os.path.join(dataset_rel_path, name)
    with get_data_path(dataset_rel_path) as dataset_path:
        if not dataset_path.is_dir():
            raise NotADirectoryError(f"OE Eval dataset not found in directory {dataset_rel_path}")
        data_file = dataset_path / "requests.jsonl.gz"
        if not data_file.is_file():
            data_file = dataset_path / "requests.jsonl"
        if not data_file.is_file():
            raise FileNotFoundError(
                f"OE Eval dataset file requests-{split}.jsonl(.gz) missing in directory {dataset_rel_path}"
            )
        requests = []
        if data_file.suffix == ".gz":
            with gzip.open(data_file, "r") as file:
                for line in file:
                    requests.append(json.loads(line.decode("utf-8").strip()))
        else:
            with open(data_file, "r") as file:
                for line2 in file:
                    requests.append(json.loads(line2.strip()))
        config = None
        config_file = dataset_path / "config.json"
        if config_file.is_file():
            with open(config_file, "r") as file:
                config = json.load(file)
        return config, requests
def default_thread_count() -> int:
    return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))


def pass_through_fn(fn, *args, **kwargs):
    return fn(*args, **kwargs)


def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
    q: Queue = Queue(maxsize=maxsize)

    sentinel = object()

    def fill_queue():
        try:
            for value in g:
                q.put(value)
        except Exception as e:
            q.put(e)
        finally:
            q.put(sentinel)

    thread_name = thread_name or repr(g)
    thread = Thread(name=thread_name, target=fill_queue, daemon=True)
    thread.start()

    for x in iter(q.get, sentinel):
        if isinstance(x, Exception):
            raise OLMoThreadError(f"generator thread {thread_name} failed") from x
        else:
            yield x
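A minimal sketch of the intended use: overlapping a slow producer with downstream consumption through the bounded queue. The producer below is made up:

```
import time

def slow_numbers():
    for i in range(5):
        time.sleep(0.1)  # stand-in for slow I/O
        yield i

for value in threaded_generator(slow_numbers(), maxsize=4, thread_name="slow_numbers"):
    print(value)
```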
def roundrobin(*iterables):
    """
    Call the given iterables in a round-robin fashion. For example:
    ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
    """
    # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
    num_active = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = cycle(islice(nexts, num_active))
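The docstring example, spelled out:

```
list(roundrobin("ABC", "D", "EF"))  # -> ['A', 'D', 'E', 'B', 'F', 'C']
```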
def add_cached_path_clients():
    add_scheme_client(WekaClient)


class WekaClient(SchemeClient):
    recoverable_errors = SchemeClient.recoverable_errors + (
        boto_exceptions.HTTPClientError,
        boto_exceptions.ConnectionError,
    )

    scheme = "weka"

    def __init__(self, resource: str) -> None:
        SchemeClient.__init__(self, resource)
        self.bucket_name, self.path = WekaClient._split_cloud_path(resource, "weka")
        self.s3 = _get_s3_client("weka")
        self.object_info = None

    @staticmethod
    def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
        """Split a full s3 path into the bucket name and path."""
        from urllib.parse import urlparse

        parsed = urlparse(url)
        if not parsed.netloc or not parsed.path:
            raise ValueError("bad {} path {}".format(provider, url))
        bucket_name = parsed.netloc
        provider_path = parsed.path
        # Remove '/' at beginning of path.
        if provider_path.startswith("/"):
            provider_path = provider_path[1:]
        return bucket_name, provider_path

    def _ensure_object_info(self):
        if self.object_info is None:
            try:
                self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
            except boto_exceptions.ClientError as e:
                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                    raise FileNotFoundError(f"weka://{self.bucket_name}/{self.path}") from e
                raise e

    def get_etag(self) -> Optional[str]:
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ETag")

    def get_size(self) -> Optional[int]:
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ContentLength")

    def get_resource(self, temp_file: io.BufferedWriter) -> None:
        self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)

    def get_bytes_range(self, index: int, length: int) -> bytes:
        response = self.s3.get_object(
            Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
        )
        return response["Body"].read()
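A sketch of how the client is meant to be used once registered; the URL is a placeholder, and it assumes the Weka profile/endpoint env vars described earlier are set:

```
from cached_path import cached_path

add_cached_path_clients()  # registers WekaClient for the "weka://" scheme
local_copy = cached_path("weka://my-bucket/checkpoints/step1000/config.yaml")
```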
def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False):
    """
    Flatten a nested dictionary into a single-level dictionary.

    Args:
        dictionary (dict): The nested dictionary to be flattened.
        parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "".
        separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".".
        include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False.

    Returns:
        dict: The flattened dictionary.

    """
    d: Dict[str, Any] = {}
    for key, value in dictionary.items():
        new_key = parent_key + separator + key if parent_key else key
        # convert lists to dict with key <int>
        if isinstance(value, list) and include_lists:
            value = {f"{i}": v for i, v in enumerate(value)}
        if isinstance(value, MutableMapping):
            d.update(**flatten_dict(value, new_key, separator=separator, include_lists=include_lists))
        else:
            d[new_key] = value
    return d
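A behaviour sketch for the flattening rules above, with a made-up nested config:

```
nested = {"model": {"d_model": 1024, "layers": [1, 2]}, "seed": 0}
flatten_dict(nested)
# {'model.d_model': 1024, 'model.layers': [1, 2], 'seed': 0}
flatten_dict(nested, include_lists=True)
# {'model.d_model': 1024, 'model.layers.0': 1, 'model.layers.1': 2, 'seed': 0}
```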
version.py ADDED
@@ -0,0 +1,11 @@
_MAJOR = "0"
_MINOR = "6"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "0"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
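With the values above, the exported strings work out to:

```
VERSION_SHORT  # "0.6"
VERSION        # "0.6.0"
```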