Spaces:
Running
Running
feat: support pod (#139)
Browse files- src/dalle_mini/data.py +37 -2
- src/dalle_mini/model/modeling.py +25 -15
- src/dalle_mini/model/utils.py +0 -6
- tools/inference/inference_pipeline.ipynb +15 -9
- tools/train/config/medium/config.json +0 -1
- tools/train/config/mega/config.json +8 -10
- tools/train/config/micro/config.json +6 -8
- tools/train/config/mini/config.json +0 -1
- tools/train/scalable_shampoo/README.md +7 -0
- tools/train/{distributed_shampoo.py → scalable_shampoo/distributed_shampoo.py} +67 -170
- tools/train/scalable_shampoo/quantization_utils.py +124 -0
- tools/train/scalable_shampoo/sm3.py +176 -0
- tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py +211 -0
- tools/train/train.py +197 -128
src/dalle_mini/data.py
CHANGED
|
@@ -27,6 +27,7 @@ class Dataset:
|
|
| 27 |
do_eval: bool = True
|
| 28 |
seed_dataset: int = None
|
| 29 |
shard_by_host: bool = False
|
|
|
|
| 30 |
train_dataset: Dataset = field(init=False)
|
| 31 |
eval_dataset: Dataset = field(init=False)
|
| 32 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
@@ -34,6 +35,11 @@ class Dataset:
|
|
| 34 |
|
| 35 |
def __post_init__(self):
|
| 36 |
self.multi_hosts = jax.process_count() > 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# define data_files
|
| 38 |
if self.train_file is not None or self.validation_file is not None:
|
| 39 |
# accept braceexpand notation
|
|
@@ -101,6 +107,25 @@ class Dataset:
|
|
| 101 |
self.seed_dataset = np.random.get_state()[1][0]
|
| 102 |
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
# normalize text
|
| 105 |
if normalize_text:
|
| 106 |
text_normalizer = TextNormalizer()
|
|
@@ -144,6 +169,10 @@ class Dataset:
|
|
| 144 |
getattr(self, ds).map(
|
| 145 |
partial_preprocess_function,
|
| 146 |
batched=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
| 148 |
if self.streaming
|
| 149 |
else getattr(self, ds).map(
|
|
@@ -193,8 +222,8 @@ class Dataset:
|
|
| 193 |
while (self.multi_hosts and split == "train") or first_loop:
|
| 194 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
| 195 |
# at the same time and training data may not be split equally
|
| 196 |
-
# For validation data we put the entire
|
| 197 |
-
#
|
| 198 |
if epoch is not None:
|
| 199 |
assert split == "train"
|
| 200 |
# reshuffle training data at each epoch
|
|
@@ -252,6 +281,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
|
|
| 252 |
return shifted_input_ids
|
| 253 |
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
def normalize_function(example, text_column, text_normalizer):
|
| 256 |
example[text_column] = text_normalizer(example[text_column])
|
| 257 |
return example
|
|
|
|
| 27 |
do_eval: bool = True
|
| 28 |
seed_dataset: int = None
|
| 29 |
shard_by_host: bool = False
|
| 30 |
+
blank_caption_prob: float = 0.0
|
| 31 |
train_dataset: Dataset = field(init=False)
|
| 32 |
eval_dataset: Dataset = field(init=False)
|
| 33 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
|
|
| 35 |
|
| 36 |
def __post_init__(self):
|
| 37 |
self.multi_hosts = jax.process_count() > 1
|
| 38 |
+
# feed blank captions only in streaming mode for now
|
| 39 |
+
if self.blank_caption_prob:
|
| 40 |
+
assert (
|
| 41 |
+
self.streaming is True
|
| 42 |
+
), "blank_caption_prob can only be used in streaming mode"
|
| 43 |
# define data_files
|
| 44 |
if self.train_file is not None or self.validation_file is not None:
|
| 45 |
# accept braceexpand notation
|
|
|
|
| 107 |
self.seed_dataset = np.random.get_state()[1][0]
|
| 108 |
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
|
| 109 |
|
| 110 |
+
# blank captions
|
| 111 |
+
if self.blank_caption_prob:
|
| 112 |
+
partial_blank_caption_function = partial(
|
| 113 |
+
blank_caption_function,
|
| 114 |
+
text_column=self.text_column,
|
| 115 |
+
blank_caption_prob=self.blank_caption_prob,
|
| 116 |
+
)
|
| 117 |
+
if hasattr(self, "train_dataset"):
|
| 118 |
+
self.train_dataset = (
|
| 119 |
+
self.train_dataset.map(partial_blank_caption_function)
|
| 120 |
+
if self.streaming
|
| 121 |
+
else self.train_dataset.map(
|
| 122 |
+
partial_blank_caption_function,
|
| 123 |
+
num_proc=self.preprocessing_num_workers,
|
| 124 |
+
load_from_cache_file=False,
|
| 125 |
+
desc="Blanking some captions",
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
# normalize text
|
| 130 |
if normalize_text:
|
| 131 |
text_normalizer = TextNormalizer()
|
|
|
|
| 169 |
getattr(self, ds).map(
|
| 170 |
partial_preprocess_function,
|
| 171 |
batched=True,
|
| 172 |
+
remove_columns=[
|
| 173 |
+
self.text_column,
|
| 174 |
+
self.encoding_column,
|
| 175 |
+
],
|
| 176 |
)
|
| 177 |
if self.streaming
|
| 178 |
else getattr(self, ds).map(
|
|
|
|
| 222 |
while (self.multi_hosts and split == "train") or first_loop:
|
| 223 |
# in multi-host, we run forever (no epoch) as hosts need to stop
|
| 224 |
# at the same time and training data may not be split equally
|
| 225 |
+
# For validation data we put the entire batch on each host and then
|
| 226 |
+
# keep only the one specific to each host (could be improved but not necessary)
|
| 227 |
if epoch is not None:
|
| 228 |
assert split == "train"
|
| 229 |
# reshuffle training data at each epoch
|
|
|
|
| 281 |
return shifted_input_ids
|
| 282 |
|
| 283 |
|
| 284 |
+
def blank_caption_function(example, text_column, blank_caption_prob):
|
| 285 |
+
if blank_caption_prob and np.random.rand() < blank_caption_prob:
|
| 286 |
+
example[text_column] = ""
|
| 287 |
+
return example
|
| 288 |
+
|
| 289 |
+
|
| 290 |
def normalize_function(example, text_column, text_normalizer):
|
| 291 |
example[text_column] = text_normalizer(example[text_column])
|
| 292 |
return example
|
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
-
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
@@ -328,6 +328,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 328 |
dtype: jnp.dtype = jnp.float32,
|
| 329 |
abstract_init: bool = False,
|
| 330 |
load_on_cpu: bool = False,
|
|
|
|
| 331 |
**kwargs,
|
| 332 |
):
|
| 333 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
|
@@ -347,25 +348,34 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 347 |
self.key = PRNGKey(seed)
|
| 348 |
self.dtype = dtype
|
| 349 |
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
if abstract_init:
|
| 360 |
# only set shape and dtype, load parameters separately
|
| 361 |
init_fn = partial(init_fn, input_shape=input_shape)
|
| 362 |
-
|
| 363 |
else:
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
# save required_params as set
|
| 367 |
-
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
| 368 |
-
self.params = random_params
|
| 369 |
|
| 370 |
@property
|
| 371 |
def num_params(self):
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
+
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 328 |
dtype: jnp.dtype = jnp.float32,
|
| 329 |
abstract_init: bool = False,
|
| 330 |
load_on_cpu: bool = False,
|
| 331 |
+
init_weights: bool = True,
|
| 332 |
**kwargs,
|
| 333 |
):
|
| 334 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
|
|
|
| 348 |
self.key = PRNGKey(seed)
|
| 349 |
self.dtype = dtype
|
| 350 |
|
| 351 |
+
if init_weights:
|
| 352 |
+
# get shape of params only
|
| 353 |
+
random_params = self.init_weights(
|
| 354 |
+
self.key,
|
| 355 |
+
input_shape,
|
| 356 |
+
abstract_init=abstract_init,
|
| 357 |
+
load_on_cpu=load_on_cpu,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# save required_params as set
|
| 361 |
+
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
| 362 |
+
self.params = random_params
|
| 363 |
|
| 364 |
+
def init_weights(
|
| 365 |
+
self, rng=None, input_shape=(1, 1), abstract_init=False, load_on_cpu=False
|
| 366 |
+
):
|
| 367 |
+
if rng is None:
|
| 368 |
+
rng = self.key
|
| 369 |
+
init_fn = super().init_weights
|
| 370 |
+
if load_on_cpu:
|
| 371 |
+
init_fn = jax.jit(init_fn, static_argnums=(1,), backend="cpu")
|
| 372 |
if abstract_init:
|
| 373 |
# only set shape and dtype, load parameters separately
|
| 374 |
init_fn = partial(init_fn, input_shape=input_shape)
|
| 375 |
+
params = jax.eval_shape(init_fn, rng)
|
| 376 |
else:
|
| 377 |
+
params = init_fn(rng, input_shape)
|
| 378 |
+
return params
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
@property
|
| 381 |
def num_params(self):
|
src/dalle_mini/model/utils.py
CHANGED
|
@@ -23,12 +23,6 @@ class PretrainedFromWandbMixin:
|
|
| 23 |
else:
|
| 24 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 25 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
| 26 |
-
if artifact.metadata.get("bucket_path"):
|
| 27 |
-
pretrained_model_name_or_path = artifact.metadata["bucket_path"]
|
| 28 |
-
|
| 29 |
-
if pretrained_model_name_or_path.startswith("gs://"):
|
| 30 |
-
copy_blobs(pretrained_model_name_or_path, tmp_dir)
|
| 31 |
-
pretrained_model_name_or_path = tmp_dir
|
| 32 |
|
| 33 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
| 34 |
pretrained_model_name_or_path, *model_args, **kwargs
|
|
|
|
| 23 |
else:
|
| 24 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 25 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
| 28 |
pretrained_model_name_or_path, *model_args, **kwargs
|
tools/inference/inference_pipeline.ipynb
CHANGED
|
@@ -83,7 +83,7 @@
|
|
| 83 |
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
|
| 84 |
"\n",
|
| 85 |
"# CLIP model\n",
|
| 86 |
-
"CLIP_REPO = \"openai/clip-vit-
|
| 87 |
"CLIP_COMMIT_ID = None"
|
| 88 |
]
|
| 89 |
},
|
|
@@ -129,7 +129,6 @@
|
|
| 129 |
"from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
|
| 130 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
| 131 |
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
| 132 |
-
"import wandb\n",
|
| 133 |
"\n",
|
| 134 |
"# Load dalle-mini\n",
|
| 135 |
"model = DalleBart.from_pretrained(\n",
|
|
@@ -168,9 +167,9 @@
|
|
| 168 |
"if dtype == jnp.bfloat16:\n",
|
| 169 |
" model.params = model.to_bf16(model.params)\n",
|
| 170 |
"\n",
|
| 171 |
-
"
|
| 172 |
-
"
|
| 173 |
-
"
|
| 174 |
]
|
| 175 |
},
|
| 176 |
{
|
|
@@ -292,7 +291,7 @@
|
|
| 292 |
},
|
| 293 |
"outputs": [],
|
| 294 |
"source": [
|
| 295 |
-
"prompt = \"
|
| 296 |
]
|
| 297 |
},
|
| 298 |
{
|
|
@@ -414,12 +413,12 @@
|
|
| 414 |
" key, subkey = jax.random.split(key)\n",
|
| 415 |
" # generate images\n",
|
| 416 |
" encoded_images = p_generate(\n",
|
| 417 |
-
" tokenized_prompt, shard_prng_key(subkey),
|
| 418 |
" )\n",
|
| 419 |
" # remove BOS\n",
|
| 420 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 421 |
" # decode images\n",
|
| 422 |
-
" decoded_images = p_decode(encoded_images,
|
| 423 |
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
| 424 |
" for img in decoded_images:\n",
|
| 425 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
|
@@ -453,7 +452,7 @@
|
|
| 453 |
" max_length=77,\n",
|
| 454 |
" truncation=True,\n",
|
| 455 |
").data\n",
|
| 456 |
-
"logits = p_clip(shard(clip_inputs),
|
| 457 |
"logits = logits.squeeze().flatten()"
|
| 458 |
]
|
| 459 |
},
|
|
@@ -479,6 +478,13 @@
|
|
| 479 |
" display(images[idx])\n",
|
| 480 |
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
| 481 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
}
|
| 483 |
],
|
| 484 |
"metadata": {
|
|
|
|
| 83 |
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
|
| 84 |
"\n",
|
| 85 |
"# CLIP model\n",
|
| 86 |
+
"CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
|
| 87 |
"CLIP_COMMIT_ID = None"
|
| 88 |
]
|
| 89 |
},
|
|
|
|
| 129 |
"from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
|
| 130 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
| 131 |
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
|
|
|
| 132 |
"\n",
|
| 133 |
"# Load dalle-mini\n",
|
| 134 |
"model = DalleBart.from_pretrained(\n",
|
|
|
|
| 167 |
"if dtype == jnp.bfloat16:\n",
|
| 168 |
" model.params = model.to_bf16(model.params)\n",
|
| 169 |
"\n",
|
| 170 |
+
"model._params = replicate(model.params)\n",
|
| 171 |
+
"vqgan._params = replicate(vqgan.params)\n",
|
| 172 |
+
"clip._params = replicate(clip.params)"
|
| 173 |
]
|
| 174 |
},
|
| 175 |
{
|
|
|
|
| 291 |
},
|
| 292 |
"outputs": [],
|
| 293 |
"source": [
|
| 294 |
+
"prompt = \"view of the beach during sunset\""
|
| 295 |
]
|
| 296 |
},
|
| 297 |
{
|
|
|
|
| 413 |
" key, subkey = jax.random.split(key)\n",
|
| 414 |
" # generate images\n",
|
| 415 |
" encoded_images = p_generate(\n",
|
| 416 |
+
" tokenized_prompt, shard_prng_key(subkey), model.params, gen_top_k, gen_top_p\n",
|
| 417 |
" )\n",
|
| 418 |
" # remove BOS\n",
|
| 419 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 420 |
" # decode images\n",
|
| 421 |
+
" decoded_images = p_decode(encoded_images, vqgan.params)\n",
|
| 422 |
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
| 423 |
" for img in decoded_images:\n",
|
| 424 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
|
|
|
| 452 |
" max_length=77,\n",
|
| 453 |
" truncation=True,\n",
|
| 454 |
").data\n",
|
| 455 |
+
"logits = p_clip(shard(clip_inputs), clip.params)\n",
|
| 456 |
"logits = logits.squeeze().flatten()"
|
| 457 |
]
|
| 458 |
},
|
|
|
|
| 478 |
" display(images[idx])\n",
|
| 479 |
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
| 480 |
]
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"cell_type": "code",
|
| 484 |
+
"execution_count": null,
|
| 485 |
+
"metadata": {},
|
| 486 |
+
"outputs": [],
|
| 487 |
+
"source": []
|
| 488 |
}
|
| 489 |
],
|
| 490 |
"metadata": {
|
tools/train/config/medium/config.json
CHANGED
|
@@ -28,6 +28,5 @@
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
| 31 |
-
"transformers_version": "4.13.0.dev0",
|
| 32 |
"use_cache": true
|
| 33 |
}
|
|
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
|
|
|
| 31 |
"use_cache": true
|
| 32 |
}
|
tools/train/config/mega/config.json
CHANGED
|
@@ -5,21 +5,20 @@
|
|
| 5 |
"bos_token_id": 16385,
|
| 6 |
"classifier_dropout": 0.0,
|
| 7 |
"d_model": 2048,
|
| 8 |
-
"decoder_attention_heads":
|
| 9 |
-
"decoder_ffn_dim":
|
| 10 |
"decoder_layerdrop": 0.0,
|
| 11 |
-
"decoder_layers":
|
| 12 |
"decoder_start_token_id": 16384,
|
| 13 |
-
"dropout": 0.
|
| 14 |
-
"encoder_attention_heads":
|
| 15 |
-
"encoder_ffn_dim":
|
| 16 |
"encoder_layerdrop": 0.0,
|
| 17 |
-
"encoder_layers":
|
| 18 |
"encoder_vocab_size": 50264,
|
| 19 |
"eos_token_id": 16385,
|
| 20 |
-
"gradient_checkpointing": false,
|
| 21 |
"image_length": 256,
|
| 22 |
-
"image_vocab_size":
|
| 23 |
"init_std": 0.01,
|
| 24 |
"is_encoder_decoder": true,
|
| 25 |
"max_text_length": 64,
|
|
@@ -28,6 +27,5 @@
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
| 31 |
-
"transformers_version": "4.13.0.dev0",
|
| 32 |
"use_cache": true
|
| 33 |
}
|
|
|
|
| 5 |
"bos_token_id": 16385,
|
| 6 |
"classifier_dropout": 0.0,
|
| 7 |
"d_model": 2048,
|
| 8 |
+
"decoder_attention_heads": 32,
|
| 9 |
+
"decoder_ffn_dim": 8192,
|
| 10 |
"decoder_layerdrop": 0.0,
|
| 11 |
+
"decoder_layers": 24,
|
| 12 |
"decoder_start_token_id": 16384,
|
| 13 |
+
"dropout": 0.0,
|
| 14 |
+
"encoder_attention_heads": 32,
|
| 15 |
+
"encoder_ffn_dim": 8192,
|
| 16 |
"encoder_layerdrop": 0.0,
|
| 17 |
+
"encoder_layers": 24,
|
| 18 |
"encoder_vocab_size": 50264,
|
| 19 |
"eos_token_id": 16385,
|
|
|
|
| 20 |
"image_length": 256,
|
| 21 |
+
"image_vocab_size": 16391,
|
| 22 |
"init_std": 0.01,
|
| 23 |
"is_encoder_decoder": true,
|
| 24 |
"max_text_length": 64,
|
|
|
|
| 27 |
"pad_token_id": 16385,
|
| 28 |
"scale_embedding": false,
|
| 29 |
"tie_word_embeddings": false,
|
|
|
|
| 30 |
"use_cache": true
|
| 31 |
}
|
tools/train/config/micro/config.json
CHANGED
|
@@ -4,22 +4,21 @@
|
|
| 4 |
"attention_dropout": 0.0,
|
| 5 |
"bos_token_id": 16385,
|
| 6 |
"classifier_dropout": 0.0,
|
| 7 |
-
"d_model":
|
| 8 |
-
"decoder_attention_heads":
|
| 9 |
-
"decoder_ffn_dim":
|
| 10 |
"decoder_layerdrop": 0.0,
|
| 11 |
"decoder_layers": 2,
|
| 12 |
"decoder_start_token_id": 16384,
|
| 13 |
"dropout": 0.0,
|
| 14 |
-
"encoder_attention_heads":
|
| 15 |
-
"encoder_ffn_dim":
|
| 16 |
"encoder_layerdrop": 0.0,
|
| 17 |
"encoder_layers": 2,
|
| 18 |
"encoder_vocab_size": 50264,
|
| 19 |
"eos_token_id": 16385,
|
| 20 |
-
"gradient_checkpointing": false,
|
| 21 |
"image_length": 256,
|
| 22 |
-
"image_vocab_size":
|
| 23 |
"init_std": 0.02,
|
| 24 |
"is_encoder_decoder": true,
|
| 25 |
"max_text_length": 64,
|
|
@@ -28,6 +27,5 @@
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
| 31 |
-
"transformers_version": "4.13.0.dev0",
|
| 32 |
"use_cache": true
|
| 33 |
}
|
|
|
|
| 4 |
"attention_dropout": 0.0,
|
| 5 |
"bos_token_id": 16385,
|
| 6 |
"classifier_dropout": 0.0,
|
| 7 |
+
"d_model": 256,
|
| 8 |
+
"decoder_attention_heads": 2,
|
| 9 |
+
"decoder_ffn_dim": 256,
|
| 10 |
"decoder_layerdrop": 0.0,
|
| 11 |
"decoder_layers": 2,
|
| 12 |
"decoder_start_token_id": 16384,
|
| 13 |
"dropout": 0.0,
|
| 14 |
+
"encoder_attention_heads": 2,
|
| 15 |
+
"encoder_ffn_dim": 256,
|
| 16 |
"encoder_layerdrop": 0.0,
|
| 17 |
"encoder_layers": 2,
|
| 18 |
"encoder_vocab_size": 50264,
|
| 19 |
"eos_token_id": 16385,
|
|
|
|
| 20 |
"image_length": 256,
|
| 21 |
+
"image_vocab_size": 16391,
|
| 22 |
"init_std": 0.02,
|
| 23 |
"is_encoder_decoder": true,
|
| 24 |
"max_text_length": 64,
|
|
|
|
| 27 |
"pad_token_id": 16385,
|
| 28 |
"scale_embedding": false,
|
| 29 |
"tie_word_embeddings": false,
|
|
|
|
| 30 |
"use_cache": true
|
| 31 |
}
|
tools/train/config/mini/config.json
CHANGED
|
@@ -28,6 +28,5 @@
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
| 31 |
-
"transformers_version": "4.13.0.dev0",
|
| 32 |
"use_cache": true
|
| 33 |
}
|
|
|
|
| 28 |
"pad_token_id": 16385,
|
| 29 |
"scale_embedding": false,
|
| 30 |
"tie_word_embeddings": false,
|
|
|
|
| 31 |
"use_cache": true
|
| 32 |
}
|
tools/train/scalable_shampoo/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Notes
|
| 2 |
+
|
| 3 |
+
Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
|
| 4 |
+
|
| 5 |
+
Imports have been modified to be relative.
|
| 6 |
+
|
| 7 |
+
This will be replaced with `optax-shampoo` package eventually.
|
tools/train/{distributed_shampoo.py → scalable_shampoo/distributed_shampoo.py}
RENAMED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
# file from: https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
|
| 2 |
-
|
| 3 |
# coding=utf-8
|
| 4 |
# Copyright 2022 The Google Research Authors.
|
| 5 |
#
|
|
@@ -44,107 +42,12 @@ import optax
|
|
| 44 |
from flax import struct
|
| 45 |
from jax import lax
|
| 46 |
|
|
|
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
quantized: chex.Array
|
| 54 |
-
diagonal: chex.Array # Diagonal (if extract_diagonal is set)
|
| 55 |
-
bucket_size: chex.Array
|
| 56 |
-
quantized_dtype: jnp.dtype = struct.field(
|
| 57 |
-
pytree_node=False
|
| 58 |
-
) # Dtype for the quantized value.
|
| 59 |
-
extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
|
| 60 |
-
shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
|
| 61 |
-
|
| 62 |
-
@classmethod
|
| 63 |
-
def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
| 64 |
-
if isinstance(fvalue, list) and not fvalue:
|
| 65 |
-
return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
|
| 66 |
-
quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
|
| 67 |
-
fvalue, quantized_dtype, extract_diagonal
|
| 68 |
-
)
|
| 69 |
-
return QuantizedValue(
|
| 70 |
-
quantized,
|
| 71 |
-
diagonal_fvalue,
|
| 72 |
-
bucket_size,
|
| 73 |
-
quantized_dtype,
|
| 74 |
-
extract_diagonal,
|
| 75 |
-
list(quantized.shape),
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
# Quantization is from Lingvo JAX optimizers.
|
| 79 |
-
# We extend it for int16 quantization of PSD matrices.
|
| 80 |
-
@classmethod
|
| 81 |
-
def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
| 82 |
-
"""Returns quantized value and the bucket."""
|
| 83 |
-
if quantized_dtype == jnp.float32:
|
| 84 |
-
return fvalue, [], []
|
| 85 |
-
elif quantized_dtype == jnp.bfloat16:
|
| 86 |
-
return fvalue.astype(jnp.bfloat16), [], []
|
| 87 |
-
|
| 88 |
-
float_dtype = fvalue.dtype
|
| 89 |
-
if quantized_dtype == jnp.int8:
|
| 90 |
-
# value -128 is not used.
|
| 91 |
-
num_buckets = jnp.array(127.0, dtype=float_dtype)
|
| 92 |
-
elif quantized_dtype == jnp.int16:
|
| 93 |
-
# value -32768 is not used.
|
| 94 |
-
num_buckets = jnp.array(32767.0, dtype=float_dtype)
|
| 95 |
-
else:
|
| 96 |
-
raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
|
| 97 |
-
# max value is mapped to num_buckets
|
| 98 |
-
|
| 99 |
-
if extract_diagonal and fvalue.ndim != 2:
|
| 100 |
-
raise ValueError(
|
| 101 |
-
f"Input array {fvalue} must be 2D to work with extract_diagonal."
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
diagonal_fvalue = []
|
| 105 |
-
if extract_diagonal:
|
| 106 |
-
diagonal_fvalue = jnp.diag(fvalue)
|
| 107 |
-
# Remove the diagonal entries.
|
| 108 |
-
fvalue = fvalue - jnp.diag(diagonal_fvalue)
|
| 109 |
-
|
| 110 |
-
# TODO(rohananil): Extend this by making use of information about the blocks
|
| 111 |
-
# SM3 style which will be useful for diagonal statistics
|
| 112 |
-
# We first decide the scale.
|
| 113 |
-
if fvalue.ndim < 1:
|
| 114 |
-
raise ValueError(
|
| 115 |
-
f"Input array {fvalue} must have a strictly positive number of "
|
| 116 |
-
"dimensions."
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
|
| 120 |
-
bucket_size = max_abs / num_buckets
|
| 121 |
-
bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
|
| 122 |
-
# To avoid divide by 0.0
|
| 123 |
-
bs_nonzero = jnp.where(
|
| 124 |
-
bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
|
| 125 |
-
)
|
| 126 |
-
ratio = fvalue / bs_nonzero
|
| 127 |
-
# We use rounding to remove bias.
|
| 128 |
-
quantized = jnp.round(ratio)
|
| 129 |
-
return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
|
| 130 |
-
|
| 131 |
-
def to_float(self):
|
| 132 |
-
"""Returns the float value."""
|
| 133 |
-
if isinstance(self.quantized, list) and not self.quantized:
|
| 134 |
-
return self.quantized
|
| 135 |
-
|
| 136 |
-
if self.quantized_dtype == jnp.float32:
|
| 137 |
-
return self.quantized
|
| 138 |
-
|
| 139 |
-
if self.quantized_dtype == jnp.bfloat16:
|
| 140 |
-
return self.quantized.astype(jnp.float32)
|
| 141 |
-
|
| 142 |
-
float_dtype = self.bucket_size.dtype
|
| 143 |
-
bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
|
| 144 |
-
val = self.quantized.astype(float_dtype) * bucket_size
|
| 145 |
-
if self.extract_diagonal:
|
| 146 |
-
val += jnp.diag(self.diagonal)
|
| 147 |
-
return val
|
| 148 |
|
| 149 |
|
| 150 |
@struct.dataclass
|
|
@@ -193,24 +96,21 @@ class LocalShardedParameterStats:
|
|
| 193 |
|
| 194 |
|
| 195 |
def init_training_metrics(num_statistics):
|
| 196 |
-
if
|
| 197 |
-
|
| 198 |
-
else
|
| 199 |
-
|
| 200 |
|
| 201 |
|
| 202 |
def init_training_metrics_shapes(num_statistics):
|
| 203 |
-
if
|
| 204 |
-
|
| 205 |
-
else
|
| 206 |
-
|
| 207 |
|
| 208 |
|
| 209 |
-
def init_training_metrics_pspec(
|
| 210 |
-
|
| 211 |
-
return TrainingMetrics(pjit.PartitionSpec())
|
| 212 |
-
else:
|
| 213 |
-
return TrainingMetrics(None)
|
| 214 |
|
| 215 |
|
| 216 |
class ShardedShampooStats(NamedTuple):
|
|
@@ -296,6 +196,30 @@ def power_iteration(
|
|
| 296 |
return v_out, s_out
|
| 297 |
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
def matrix_inverse_pth_root(
|
| 300 |
matrix,
|
| 301 |
p,
|
|
@@ -332,57 +256,19 @@ def matrix_inverse_pth_root(
|
|
| 332 |
|
| 333 |
assert matrix.shape[0] == matrix.shape[1]
|
| 334 |
|
| 335 |
-
# We use
|
| 336 |
-
# Switch to f64 if you have hardware that supports it.
|
|
|
|
| 337 |
matrix_size = matrix.shape[0]
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
| 340 |
_, max_ev = power_iteration(
|
| 341 |
matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
|
| 342 |
)
|
| 343 |
ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
|
| 344 |
|
| 345 |
-
def _unrolled_mat_pow_1(mat_m):
|
| 346 |
-
"""Computes mat_m^1."""
|
| 347 |
-
return mat_m
|
| 348 |
-
|
| 349 |
-
def _unrolled_mat_pow_2(mat_m):
|
| 350 |
-
"""Computes mat_m^2."""
|
| 351 |
-
return jnp.matmul(mat_m, mat_m, precision=precision)
|
| 352 |
-
|
| 353 |
-
def _unrolled_mat_pow_4(mat_m):
|
| 354 |
-
"""Computes mat_m^4."""
|
| 355 |
-
mat_pow_2 = _unrolled_mat_pow_2(mat_m)
|
| 356 |
-
return jnp.matmul(mat_pow_2, mat_pow_2, precision=precision)
|
| 357 |
-
|
| 358 |
-
def _unrolled_mat_pow_8(mat_m):
|
| 359 |
-
"""Computes mat_m^4."""
|
| 360 |
-
mat_pow_4 = _unrolled_mat_pow_4(mat_m)
|
| 361 |
-
return jnp.matmul(mat_pow_4, mat_pow_4, precision=precision)
|
| 362 |
-
|
| 363 |
-
def mat_power(mat_m, p):
|
| 364 |
-
"""Computes mat_m^p, for p == 1, 2, 4 or 8.
|
| 365 |
-
|
| 366 |
-
Args:
|
| 367 |
-
mat_m: a square matrix
|
| 368 |
-
p: a positive integer
|
| 369 |
-
|
| 370 |
-
Returns:
|
| 371 |
-
mat_m^p
|
| 372 |
-
"""
|
| 373 |
-
# We unrolled the loop for performance reasons.
|
| 374 |
-
exponent = jnp.round(jnp.log2(p))
|
| 375 |
-
return lax.switch(
|
| 376 |
-
jnp.asarray(exponent, jnp.int32),
|
| 377 |
-
[
|
| 378 |
-
_unrolled_mat_pow_1,
|
| 379 |
-
_unrolled_mat_pow_2,
|
| 380 |
-
_unrolled_mat_pow_4,
|
| 381 |
-
_unrolled_mat_pow_8,
|
| 382 |
-
],
|
| 383 |
-
(mat_m),
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
def _iter_condition(state):
|
| 387 |
(i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
|
| 388 |
error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
|
|
@@ -412,10 +298,10 @@ def matrix_inverse_pth_root(
|
|
| 412 |
_, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
|
| 413 |
_iter_condition, _iter_body, init_state
|
| 414 |
)
|
| 415 |
-
error = jnp.max(jnp.abs(mat_m - identity))
|
| 416 |
is_converged = jnp.asarray(convergence, old_mat_h.dtype)
|
| 417 |
resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
|
| 418 |
-
resultant_mat_h = jnp.asarray(resultant_mat_h,
|
| 419 |
return resultant_mat_h, error
|
| 420 |
|
| 421 |
|
|
@@ -433,6 +319,9 @@ def merge_small_dims(shape_to_merge, max_dim):
|
|
| 433 |
Returns:
|
| 434 |
Merged shape.
|
| 435 |
"""
|
|
|
|
|
|
|
|
|
|
| 436 |
resulting_shape = []
|
| 437 |
product = 1
|
| 438 |
for d in shape_to_merge:
|
|
@@ -975,16 +864,22 @@ def distributed_shampoo(
|
|
| 975 |
)
|
| 976 |
|
| 977 |
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
# Pad the statistics and preconditioner matrices to be a multiple of
|
| 979 |
# num devices.
|
| 980 |
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
|
| 981 |
# is split on.
|
| 982 |
-
to_pad = -len(padded_statistics) % num_devices_for_pjit
|
| 983 |
padded_statistics.extend(
|
| 984 |
-
[jnp.eye(max_size, dtype=
|
| 985 |
)
|
| 986 |
padded_preconditioners.extend(
|
| 987 |
-
[jnp.eye(max_size, dtype=
|
| 988 |
)
|
| 989 |
exponents.extend([1 for _ in range(to_pad)])
|
| 990 |
global_stats = GlobalShardedParameterStats(
|
|
@@ -1016,7 +911,7 @@ def distributed_shampoo(
|
|
| 1016 |
if pspec and len(pspec) > 1:
|
| 1017 |
return pjit.PartitionSpec(*pspec[1:])
|
| 1018 |
else:
|
| 1019 |
-
return
|
| 1020 |
|
| 1021 |
def sharded_init_partition_spec_fn(
|
| 1022 |
params, params_partition_spec, partition_spec_for_statistics
|
|
@@ -1102,7 +997,7 @@ def distributed_shampoo(
|
|
| 1102 |
False,
|
| 1103 |
list(param.shape),
|
| 1104 |
),
|
| 1105 |
-
init_training_metrics_pspec(
|
| 1106 |
index_start,
|
| 1107 |
sizes,
|
| 1108 |
)
|
|
@@ -1209,6 +1104,9 @@ def distributed_shampoo(
|
|
| 1209 |
max_statistics_size = _max_statistics_size_from_params(params_flat)
|
| 1210 |
to_pad = -num_statistics % num_devices_for_pjit
|
| 1211 |
num_statistics += to_pad
|
|
|
|
|
|
|
|
|
|
| 1212 |
statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
|
| 1213 |
global_stats = GlobalShardedParameterStats(
|
| 1214 |
[statistics_shape, jnp.float32],
|
|
@@ -2069,7 +1967,7 @@ def distributed_shampoo(
|
|
| 2069 |
|
| 2070 |
scaled_grad = grad
|
| 2071 |
if graft_type == GraftingType.ADAGRAD_NORMALIZED:
|
| 2072 |
-
scaled_grad = grad / jnp.linalg.norm(grad)
|
| 2073 |
|
| 2074 |
new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
|
| 2075 |
scaled_grad
|
|
@@ -2085,7 +1983,7 @@ def distributed_shampoo(
|
|
| 2085 |
|
| 2086 |
scaled_grad = grad
|
| 2087 |
if graft_type == GraftingType.RMSPROP_NORMALIZED:
|
| 2088 |
-
scaled_grad = grad / jnp.linalg.norm(grad)
|
| 2089 |
|
| 2090 |
w1 = beta2
|
| 2091 |
w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
|
|
@@ -2212,7 +2110,6 @@ def distributed_shampoo(
|
|
| 2212 |
new_stats_flat = _compute_preconditioners(
|
| 2213 |
new_stats_flat, params_flat, state.count
|
| 2214 |
)
|
| 2215 |
-
|
| 2216 |
outputs = jax.tree_multimap(
|
| 2217 |
lambda g, s, p: _transform_grad(g, s, p, state.count),
|
| 2218 |
grads_flat,
|
|
|
|
|
|
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
# Copyright 2022 The Google Research Authors.
|
| 3 |
#
|
|
|
|
| 42 |
from flax import struct
|
| 43 |
from jax import lax
|
| 44 |
|
| 45 |
+
from .quantization_utils import QuantizedValue
|
| 46 |
|
| 47 |
+
# Dtype for inverse-pth root routine
|
| 48 |
+
# Switch to f64 if you have hardware that supports it. Enable the jax flag
|
| 49 |
+
# jax_enable_x64 for this to work, otherwise it will default to float32.
|
| 50 |
+
_MAT_INV_PTH_ROOT_DTYPE = jnp.float64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
@struct.dataclass
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
def init_training_metrics(num_statistics):
|
| 99 |
+
# Since the downstream apis expect a jnp.array - we create a dummy one if
|
| 100 |
+
# num_statistics=0.
|
| 101 |
+
n = 1 if not num_statistics else num_statistics
|
| 102 |
+
return TrainingMetrics(jnp.zeros([n], jnp.float32))
|
| 103 |
|
| 104 |
|
| 105 |
def init_training_metrics_shapes(num_statistics):
|
| 106 |
+
# Since the downstream apis expect a jnp.array - we create a dummy one if
|
| 107 |
+
# num_statistics=0.
|
| 108 |
+
n = 1 if not num_statistics else num_statistics
|
| 109 |
+
return TrainingMetrics([[n], jnp.float32])
|
| 110 |
|
| 111 |
|
| 112 |
+
def init_training_metrics_pspec():
|
| 113 |
+
return TrainingMetrics(pjit.PartitionSpec())
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
class ShardedShampooStats(NamedTuple):
|
|
|
|
| 196 |
return v_out, s_out
|
| 197 |
|
| 198 |
|
| 199 |
+
def mat_power(mat_m, p, precision=lax.Precision.HIGHEST):
|
| 200 |
+
"""A simple matrix power method. M^p where p can be TracedValue."""
|
| 201 |
+
power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
|
| 202 |
+
|
| 203 |
+
def _iter_condition(state):
|
| 204 |
+
i, _, _ = state
|
| 205 |
+
return i > 0
|
| 206 |
+
|
| 207 |
+
def _iter_body(state):
|
| 208 |
+
i, power, mat = state
|
| 209 |
+
|
| 210 |
+
power = jax.lax.cond(
|
| 211 |
+
i % 2 == 1,
|
| 212 |
+
lambda: jnp.matmul(mat, power, precision=precision),
|
| 213 |
+
lambda: power,
|
| 214 |
+
)
|
| 215 |
+
i //= 2
|
| 216 |
+
mat = jnp.matmul(mat, mat, precision=precision)
|
| 217 |
+
return i, power, mat
|
| 218 |
+
|
| 219 |
+
_, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
|
| 220 |
+
return result
|
| 221 |
+
|
| 222 |
+
|
| 223 |
def matrix_inverse_pth_root(
|
| 224 |
matrix,
|
| 225 |
p,
|
|
|
|
| 256 |
|
| 257 |
assert matrix.shape[0] == matrix.shape[1]
|
| 258 |
|
| 259 |
+
# We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
|
| 260 |
+
# Switch to f64 if you have hardware that supports it. Enable the jax flag
|
| 261 |
+
# jax_enable_x64 for this to work.
|
| 262 |
matrix_size = matrix.shape[0]
|
| 263 |
+
orig_dtype = matrix.dtype
|
| 264 |
+
matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
|
| 265 |
+
alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
|
| 266 |
+
identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
|
| 267 |
_, max_ev = power_iteration(
|
| 268 |
matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
|
| 269 |
)
|
| 270 |
ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
def _iter_condition(state):
|
| 273 |
(i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
|
| 274 |
error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
|
|
|
|
| 298 |
_, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
|
| 299 |
_iter_condition, _iter_body, init_state
|
| 300 |
)
|
| 301 |
+
error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
|
| 302 |
is_converged = jnp.asarray(convergence, old_mat_h.dtype)
|
| 303 |
resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
|
| 304 |
+
resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
|
| 305 |
return resultant_mat_h, error
|
| 306 |
|
| 307 |
|
|
|
|
| 319 |
Returns:
|
| 320 |
Merged shape.
|
| 321 |
"""
|
| 322 |
+
if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
|
| 323 |
+
return [1]
|
| 324 |
+
|
| 325 |
resulting_shape = []
|
| 326 |
product = 1
|
| 327 |
for d in shape_to_merge:
|
|
|
|
| 864 |
)
|
| 865 |
|
| 866 |
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
| 867 |
+
to_pad = -len(padded_statistics) % num_devices_for_pjit
|
| 868 |
+
if max_size == 0:
|
| 869 |
+
to_pad = num_devices_for_pjit
|
| 870 |
+
max_size = block_size
|
| 871 |
+
stat_dtype = jnp.float32
|
| 872 |
+
else:
|
| 873 |
+
stat_dtype = padded_statistics[0].dtype
|
| 874 |
# Pad the statistics and preconditioner matrices to be a multiple of
|
| 875 |
# num devices.
|
| 876 |
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
|
| 877 |
# is split on.
|
|
|
|
| 878 |
padded_statistics.extend(
|
| 879 |
+
[jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
|
| 880 |
)
|
| 881 |
padded_preconditioners.extend(
|
| 882 |
+
[jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
|
| 883 |
)
|
| 884 |
exponents.extend([1 for _ in range(to_pad)])
|
| 885 |
global_stats = GlobalShardedParameterStats(
|
|
|
|
| 911 |
if pspec and len(pspec) > 1:
|
| 912 |
return pjit.PartitionSpec(*pspec[1:])
|
| 913 |
else:
|
| 914 |
+
return []
|
| 915 |
|
| 916 |
def sharded_init_partition_spec_fn(
|
| 917 |
params, params_partition_spec, partition_spec_for_statistics
|
|
|
|
| 997 |
False,
|
| 998 |
list(param.shape),
|
| 999 |
),
|
| 1000 |
+
init_training_metrics_pspec(),
|
| 1001 |
index_start,
|
| 1002 |
sizes,
|
| 1003 |
)
|
|
|
|
| 1104 |
max_statistics_size = _max_statistics_size_from_params(params_flat)
|
| 1105 |
to_pad = -num_statistics % num_devices_for_pjit
|
| 1106 |
num_statistics += to_pad
|
| 1107 |
+
if num_statistics == 0:
|
| 1108 |
+
num_statistics = num_devices_for_pjit
|
| 1109 |
+
max_statistics_size = block_size
|
| 1110 |
statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
|
| 1111 |
global_stats = GlobalShardedParameterStats(
|
| 1112 |
[statistics_shape, jnp.float32],
|
|
|
|
| 1967 |
|
| 1968 |
scaled_grad = grad
|
| 1969 |
if graft_type == GraftingType.ADAGRAD_NORMALIZED:
|
| 1970 |
+
scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
|
| 1971 |
|
| 1972 |
new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
|
| 1973 |
scaled_grad
|
|
|
|
| 1983 |
|
| 1984 |
scaled_grad = grad
|
| 1985 |
if graft_type == GraftingType.RMSPROP_NORMALIZED:
|
| 1986 |
+
scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
|
| 1987 |
|
| 1988 |
w1 = beta2
|
| 1989 |
w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
|
|
|
|
| 2110 |
new_stats_flat = _compute_preconditioners(
|
| 2111 |
new_stats_flat, params_flat, state.count
|
| 2112 |
)
|
|
|
|
| 2113 |
outputs = jax.tree_multimap(
|
| 2114 |
lambda g, s, p: _transform_grad(g, s, p, state.count),
|
| 2115 |
grads_flat,
|
tools/train/scalable_shampoo/quantization_utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Google Research Authors.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Helper routines for quantization."""
|
| 17 |
+
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import chex
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
from flax import struct
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# pylint:disable=no-value-for-parameter
|
| 26 |
+
@struct.dataclass
|
| 27 |
+
class QuantizedValue:
|
| 28 |
+
"""State associated with quantized value."""
|
| 29 |
+
|
| 30 |
+
quantized: chex.Array
|
| 31 |
+
diagonal: chex.Array # Diagonal (if extract_diagonal is set)
|
| 32 |
+
bucket_size: chex.Array
|
| 33 |
+
quantized_dtype: jnp.dtype = struct.field(
|
| 34 |
+
pytree_node=False
|
| 35 |
+
) # Dtype for the quantized value.
|
| 36 |
+
extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
|
| 37 |
+
shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
| 41 |
+
if isinstance(fvalue, list) and not fvalue:
|
| 42 |
+
return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
|
| 43 |
+
quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
|
| 44 |
+
fvalue, quantized_dtype, extract_diagonal
|
| 45 |
+
)
|
| 46 |
+
return QuantizedValue(
|
| 47 |
+
quantized,
|
| 48 |
+
diagonal_fvalue,
|
| 49 |
+
bucket_size,
|
| 50 |
+
quantized_dtype,
|
| 51 |
+
extract_diagonal,
|
| 52 |
+
list(quantized.shape),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Quantization is from Lingvo JAX optimizers.
|
| 56 |
+
# We extend it for int16 quantization of PSD matrices.
|
| 57 |
+
@classmethod
|
| 58 |
+
def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
| 59 |
+
"""Returns quantized value and the bucket."""
|
| 60 |
+
if quantized_dtype == jnp.float32:
|
| 61 |
+
return fvalue, [], []
|
| 62 |
+
elif quantized_dtype == jnp.bfloat16:
|
| 63 |
+
return fvalue.astype(jnp.bfloat16), [], []
|
| 64 |
+
|
| 65 |
+
float_dtype = fvalue.dtype
|
| 66 |
+
if quantized_dtype == jnp.int8:
|
| 67 |
+
# value -128 is not used.
|
| 68 |
+
num_buckets = jnp.array(127.0, dtype=float_dtype)
|
| 69 |
+
elif quantized_dtype == jnp.int16:
|
| 70 |
+
# value -32768 is not used.
|
| 71 |
+
num_buckets = jnp.array(32767.0, dtype=float_dtype)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
|
| 74 |
+
# max value is mapped to num_buckets
|
| 75 |
+
|
| 76 |
+
if extract_diagonal and fvalue.ndim != 2:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"Input array {fvalue} must be 2D to work with extract_diagonal."
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
diagonal_fvalue = []
|
| 82 |
+
if extract_diagonal:
|
| 83 |
+
diagonal_fvalue = jnp.diag(fvalue)
|
| 84 |
+
# Remove the diagonal entries.
|
| 85 |
+
fvalue = fvalue - jnp.diag(diagonal_fvalue)
|
| 86 |
+
|
| 87 |
+
# TODO(rohananil): Extend this by making use of information about the blocks
|
| 88 |
+
# SM3 style which will be useful for diagonal statistics
|
| 89 |
+
# We first decide the scale.
|
| 90 |
+
if fvalue.ndim < 1:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"Input array {fvalue} must have a strictly positive number of "
|
| 93 |
+
"dimensions."
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
|
| 97 |
+
bucket_size = max_abs / num_buckets
|
| 98 |
+
bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
|
| 99 |
+
# To avoid divide by 0.0
|
| 100 |
+
bs_nonzero = jnp.where(
|
| 101 |
+
bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
|
| 102 |
+
)
|
| 103 |
+
ratio = fvalue / bs_nonzero
|
| 104 |
+
# We use rounding to remove bias.
|
| 105 |
+
quantized = jnp.round(ratio)
|
| 106 |
+
return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
|
| 107 |
+
|
| 108 |
+
def to_float(self):
|
| 109 |
+
"""Returns the float value."""
|
| 110 |
+
if isinstance(self.quantized, list) and not self.quantized:
|
| 111 |
+
return self.quantized
|
| 112 |
+
|
| 113 |
+
if self.quantized_dtype == jnp.float32:
|
| 114 |
+
return self.quantized
|
| 115 |
+
|
| 116 |
+
if self.quantized_dtype == jnp.bfloat16:
|
| 117 |
+
return self.quantized.astype(jnp.float32)
|
| 118 |
+
|
| 119 |
+
float_dtype = self.bucket_size.dtype
|
| 120 |
+
bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
|
| 121 |
+
val = self.quantized.astype(float_dtype) * bucket_size
|
| 122 |
+
if self.extract_diagonal:
|
| 123 |
+
val += jnp.diag(self.diagonal)
|
| 124 |
+
return val
|
tools/train/scalable_shampoo/sm3.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Google Research Authors.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# An implementation of SM3 from:
|
| 17 |
+
#
|
| 18 |
+
# Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
|
| 19 |
+
# Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
|
| 20 |
+
#
|
| 21 |
+
# Author: Rohan Anil (rohananil at google dot com)
|
| 22 |
+
#
|
| 23 |
+
|
| 24 |
+
"""SM3 Implementation."""
|
| 25 |
+
|
| 26 |
+
import functools
|
| 27 |
+
from typing import Any, NamedTuple
|
| 28 |
+
|
| 29 |
+
import chex
|
| 30 |
+
import jax
|
| 31 |
+
import jax.numpy as jnp
|
| 32 |
+
import optax
|
| 33 |
+
|
| 34 |
+
from .quantization_utils import QuantizedValue
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SM3State(NamedTuple):
|
| 38 |
+
count: chex.Array
|
| 39 |
+
stats: Any
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Per parameter optimizer state used in data-parallel training.
|
| 43 |
+
class ParameterStats(NamedTuple):
|
| 44 |
+
"""State associated to each parameter of the model being trained."""
|
| 45 |
+
|
| 46 |
+
diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
|
| 47 |
+
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def sm3(
|
| 51 |
+
learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
|
| 52 |
+
):
|
| 53 |
+
"""SM3 optimizer.
|
| 54 |
+
|
| 55 |
+
Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
|
| 56 |
+
Yoram Singer
|
| 57 |
+
|
| 58 |
+
https://arxiv.org/abs/1901.11150
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
learning_rate: the step size used to update the parameters.
|
| 62 |
+
beta1: momentum parameter.
|
| 63 |
+
beta2: second moment averaging parameter.
|
| 64 |
+
diagonal_epsilon: epsilon for sm3
|
| 65 |
+
normalize_grads: Whether to normalize grads. Author finds it useful when
|
| 66 |
+
grads are high variance.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
a GradientTransformation.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def _quantize_momentum(momentum_statistics):
|
| 73 |
+
return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)
|
| 74 |
+
|
| 75 |
+
def init_fn(params):
|
| 76 |
+
"""Initialise the optimiser's state."""
|
| 77 |
+
|
| 78 |
+
def _init(param):
|
| 79 |
+
accumulators = [jnp.zeros([s]) for s in param.shape]
|
| 80 |
+
momentum = _quantize_momentum(jnp.zeros_like(param))
|
| 81 |
+
return ParameterStats(accumulators, momentum)
|
| 82 |
+
|
| 83 |
+
return SM3State(
|
| 84 |
+
count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def _get_expanded_shape(shape, i):
|
| 88 |
+
rank = len(shape)
|
| 89 |
+
# Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
|
| 90 |
+
# For eg: i = 1 returns [1, N, 1].
|
| 91 |
+
return [1] * i + [shape[i]] + [1] * (rank - i - 1)
|
| 92 |
+
|
| 93 |
+
def _moving_averages(grad, accumulators):
|
| 94 |
+
w = (1.0 - beta2) if beta2 != 1.0 else 1.0
|
| 95 |
+
if grad.ndim < 2:
|
| 96 |
+
return beta2 * accumulators[0] + w * grad**2
|
| 97 |
+
else:
|
| 98 |
+
min_accumulator = functools.reduce(jnp.minimum, accumulators)
|
| 99 |
+
return beta2 * min_accumulator + w * grad**2
|
| 100 |
+
|
| 101 |
+
def _moving_averages_momentum(grad, momentum):
|
| 102 |
+
w = (1.0 - beta1) if beta1 != 1.0 else 1.0
|
| 103 |
+
return beta1 * momentum.to_float() + w * grad
|
| 104 |
+
|
| 105 |
+
def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
|
| 106 |
+
all_diagonal_statistics = []
|
| 107 |
+
for i in range(grad.ndim):
|
| 108 |
+
axes = list(range(i)) + list(range(i + 1, grad.ndim))
|
| 109 |
+
dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
|
| 110 |
+
all_diagonal_statistics.append(dim_diagonal_statistics)
|
| 111 |
+
if grad.ndim == 1:
|
| 112 |
+
all_diagonal_statistics[0] = updated_diagonal_statistics
|
| 113 |
+
return all_diagonal_statistics
|
| 114 |
+
|
| 115 |
+
def update_fn(updates, state, params=None):
|
| 116 |
+
del params
|
| 117 |
+
stats = state.stats
|
| 118 |
+
if normalize_grads:
|
| 119 |
+
updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
|
| 120 |
+
# Reshape all vectors into N-d tensors to compute min over them.
|
| 121 |
+
# [n], [m] -> [n, 1], [1, m]
|
| 122 |
+
expanded_diagonal_statistics = jax.tree_multimap(
|
| 123 |
+
lambda grad, state: [ # pylint:disable=g-long-lambda
|
| 124 |
+
jnp.reshape(
|
| 125 |
+
state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
|
| 126 |
+
)
|
| 127 |
+
for i in range(grad.ndim)
|
| 128 |
+
],
|
| 129 |
+
updates,
|
| 130 |
+
stats,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Compute new diagonal statistics
|
| 134 |
+
new_diagonal_statistics = jax.tree_multimap(
|
| 135 |
+
_moving_averages, updates, expanded_diagonal_statistics
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Compute preconditioners (1/sqrt(s)) where s is the statistics.
|
| 139 |
+
new_preconditioners = jax.tree_map(
|
| 140 |
+
lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
|
| 141 |
+
)
|
| 142 |
+
preconditioned_grads = jax.tree_multimap(
|
| 143 |
+
lambda g, p: g * p, updates, new_preconditioners
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Compute updated momentum (also handle quantization)
|
| 147 |
+
updated_momentum = jax.tree_multimap(
|
| 148 |
+
lambda preconditioned_grad, state: _moving_averages_momentum( # pylint:disable=g-long-lambda
|
| 149 |
+
preconditioned_grad, state.diagonal_momentum
|
| 150 |
+
),
|
| 151 |
+
preconditioned_grads,
|
| 152 |
+
stats,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Update diagonal statistics.
|
| 156 |
+
updated_diagonal_statistics = jax.tree_multimap(
|
| 157 |
+
_sketch_diagonal_statistics, updates, new_diagonal_statistics
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Update momentum.
|
| 161 |
+
new_sm3_stats = jax.tree_multimap(
|
| 162 |
+
lambda momentum, diagonal_stats: ParameterStats( # pylint:disable=g-long-lambda
|
| 163 |
+
diagonal_stats, _quantize_momentum(momentum)
|
| 164 |
+
),
|
| 165 |
+
updated_momentum,
|
| 166 |
+
updated_diagonal_statistics,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
lr = learning_rate
|
| 170 |
+
if callable(learning_rate):
|
| 171 |
+
lr = learning_rate(state.count)
|
| 172 |
+
|
| 173 |
+
new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
|
| 174 |
+
return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)
|
| 175 |
+
|
| 176 |
+
return optax.GradientTransformation(init_fn, update_fn)
|
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Google Research Authors.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
|
| 17 |
+
|
| 18 |
+
import functools
|
| 19 |
+
from typing import List, Union
|
| 20 |
+
|
| 21 |
+
import jax
|
| 22 |
+
import jax.numpy as jnp
|
| 23 |
+
from flax import struct
|
| 24 |
+
from jax import lax
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@struct.dataclass
|
| 28 |
+
class SlicedSymmetricMatrix:
|
| 29 |
+
"""A symmetric matrix represented by lower-triangular block row slices.
|
| 30 |
+
|
| 31 |
+
For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
|
| 32 |
+
by the block rows a and [b, c].
|
| 33 |
+
|
| 34 |
+
The matrix may be batched, in which case each entry of block_rows may have
|
| 35 |
+
dimension greater than 2. The last two dimensions represent the rows and cols.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
block_rows: List[jnp.ndarray]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def product_with_transpose(
|
| 42 |
+
mat1,
|
| 43 |
+
mat2,
|
| 44 |
+
precision=lax.Precision.DEFAULT,
|
| 45 |
+
):
|
| 46 |
+
"""Returns mat1 * mat2^T for two matrices (possibly batched).
|
| 47 |
+
|
| 48 |
+
The rows and columns are the last two dimensions for each matrix.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
mat1: First matrix.
|
| 52 |
+
mat2: Second matrix.
|
| 53 |
+
precision: JAX precision to use for the multiplication.
|
| 54 |
+
"""
|
| 55 |
+
return jnp.einsum("...ij,...kj->...ik", mat1, mat2, precision=precision)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
|
| 59 |
+
def sliced_transposed_product(
|
| 60 |
+
mat,
|
| 61 |
+
block_size,
|
| 62 |
+
precision=lax.Precision.DEFAULT,
|
| 63 |
+
):
|
| 64 |
+
"""Returns the blocked slices representing a symmetric matrix mat*mat^T.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
mat: The matrix for which we will compute mat*mat^T. It does not need to be
|
| 68 |
+
square, and may be batched.
|
| 69 |
+
block_size: The size of row blocks to compute.
|
| 70 |
+
precision: The precision to use in each computation.
|
| 71 |
+
|
| 72 |
+
Raises:
|
| 73 |
+
ValueError: Raised when the specified block size does not evenly divide
|
| 74 |
+
the number of rows of the input mat.
|
| 75 |
+
"""
|
| 76 |
+
num_rows = mat.shape[-2]
|
| 77 |
+
if num_rows % block_size != 0:
|
| 78 |
+
raise ValueError(
|
| 79 |
+
"The row dimension must be divisible by block_size. "
|
| 80 |
+
f"Instead got row dimension={num_rows} and block_size={block_size}."
|
| 81 |
+
)
|
| 82 |
+
block_rows = [
|
| 83 |
+
product_with_transpose(
|
| 84 |
+
mat[Ellipsis, i * block_size : (i + 1) * block_size, :],
|
| 85 |
+
mat[Ellipsis, 0 : (i + 1) * block_size, :],
|
| 86 |
+
precision,
|
| 87 |
+
)
|
| 88 |
+
for i in range(num_rows // block_size)
|
| 89 |
+
]
|
| 90 |
+
return SlicedSymmetricMatrix(block_rows=block_rows)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
|
| 94 |
+
def sliced_transposed_product_concat(
|
| 95 |
+
mat,
|
| 96 |
+
block_size,
|
| 97 |
+
precision=lax.Precision.DEFAULT,
|
| 98 |
+
):
|
| 99 |
+
"""Returns the concatenated slices representing mat*mat^T.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
mat: The matrix for which we will compute mat*mat^T. It does not need to be
|
| 103 |
+
square, and may be batched.
|
| 104 |
+
block_size: The size of row blocks to compute.
|
| 105 |
+
precision: The precision to use in each computation.
|
| 106 |
+
|
| 107 |
+
Raises:
|
| 108 |
+
ValueError: Raised when the specified block size does not evenly divide
|
| 109 |
+
the number of rows of the input mat.
|
| 110 |
+
"""
|
| 111 |
+
sliced_symmetric_matrix = sliced_transposed_product(
|
| 112 |
+
mat=mat, block_size=block_size, precision=precision
|
| 113 |
+
)
|
| 114 |
+
return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@jax.jit
|
| 118 |
+
def materialize_matrix(symmetric_matrix):
|
| 119 |
+
"""Returns a materialized symmetric matrix.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
symmetric_matrix: the matrix represented by lower-triangular block slices.
|
| 123 |
+
"""
|
| 124 |
+
block_rows = symmetric_matrix.block_rows
|
| 125 |
+
block_size = block_rows[0].shape[-2]
|
| 126 |
+
num_blocks = len(block_rows)
|
| 127 |
+
|
| 128 |
+
# Slice the lower-triangular and diagonal blocks into blocks.
|
| 129 |
+
blocks = [
|
| 130 |
+
[
|
| 131 |
+
block_row[Ellipsis, i * block_size : (i + 1) * block_size]
|
| 132 |
+
for i in range(k + 1)
|
| 133 |
+
]
|
| 134 |
+
for k, block_row in enumerate(block_rows)
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
# Generate the (off-diagonal) upper-triangular blocks.
|
| 138 |
+
off_diags = [[] for _ in range(num_blocks - 1)]
|
| 139 |
+
for k, block_row in enumerate(block_rows[1:]):
|
| 140 |
+
for i in range(k + 1):
|
| 141 |
+
off_diags[i].append(
|
| 142 |
+
jnp.swapaxes(
|
| 143 |
+
a=block_row[Ellipsis, i * block_size : (i + 1) * block_size],
|
| 144 |
+
axis1=-1,
|
| 145 |
+
axis2=-2,
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return jnp.block(
|
| 150 |
+
[row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@functools.partial(jax.jit, static_argnames=("num_blocks"))
|
| 155 |
+
def materialize_matrix_from_concat(
|
| 156 |
+
block_rows_concat,
|
| 157 |
+
num_blocks,
|
| 158 |
+
):
|
| 159 |
+
"""Returns a materialized symmetric matrix from concatenated slices.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
block_rows_concat: The matrix represented as the concatenated
|
| 163 |
+
lower-triangular blocks.
|
| 164 |
+
num_blocks: The number of block-rows used to represent the symmetric matrix.
|
| 165 |
+
"""
|
| 166 |
+
block_size = block_rows_concat.shape[-2]
|
| 167 |
+
|
| 168 |
+
block_rows = [
|
| 169 |
+
block_rows_concat[
|
| 170 |
+
Ellipsis,
|
| 171 |
+
(k * (k + 1))
|
| 172 |
+
// 2
|
| 173 |
+
* block_size : (((k + 1) * (k + 2)) // 2 + 1)
|
| 174 |
+
* block_size,
|
| 175 |
+
]
|
| 176 |
+
for k in range(num_blocks)
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@functools.partial(jax.jit, static_argnames=("alpha", "beta"))
|
| 183 |
+
def update_sliced_rows(
|
| 184 |
+
symmetric_matrix,
|
| 185 |
+
mat,
|
| 186 |
+
alpha,
|
| 187 |
+
beta,
|
| 188 |
+
):
|
| 189 |
+
"""Implements the blocked equivalent of SYRK.
|
| 190 |
+
|
| 191 |
+
Specifically, the symmetric matrix (represented using lower-triangular block
|
| 192 |
+
rows) is updated using the sliced product of mat.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
symmetric_matrix: The symmetric matrix to update.
|
| 196 |
+
mat: The matrix to use for the update = mat * mat^T. The number of rows
|
| 197 |
+
should match that of symmetric_matrix.
|
| 198 |
+
alpha: The weight for the update.
|
| 199 |
+
beta: The weight for the original symmetric matrix.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
|
| 203 |
+
"""
|
| 204 |
+
block_size = symmetric_matrix.block_rows[0].shape[-2]
|
| 205 |
+
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
|
| 206 |
+
return SlicedSymmetricMatrix(
|
| 207 |
+
block_rows=[
|
| 208 |
+
update * alpha + row * beta
|
| 209 |
+
for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
|
| 210 |
+
]
|
| 211 |
+
)
|
tools/train/train.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
-
# Copyright 2021-2022 The HuggingFace & DALL·E Mini
|
| 4 |
#
|
| 5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
# you may not use this file except in compliance with the License.
|
|
@@ -37,7 +37,6 @@ import optax
|
|
| 37 |
import transformers
|
| 38 |
import wandb
|
| 39 |
from datasets import Dataset
|
| 40 |
-
from distributed_shampoo import GraftingType, distributed_shampoo
|
| 41 |
from flax.core.frozen_dict import FrozenDict, freeze
|
| 42 |
from flax.serialization import from_bytes, to_bytes
|
| 43 |
from flax.training import train_state
|
|
@@ -46,6 +45,7 @@ from google.cloud import storage
|
|
| 46 |
from jax.experimental import PartitionSpec, maps
|
| 47 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 48 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
|
|
| 49 |
from tqdm import tqdm
|
| 50 |
from transformers import HfArgumentParser
|
| 51 |
|
|
@@ -57,7 +57,7 @@ from dalle_mini.model import (
|
|
| 57 |
set_partitions,
|
| 58 |
)
|
| 59 |
|
| 60 |
-
cc.initialize_cache("./jax_cache", max_cache_size_bytes=
|
| 61 |
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|
|
@@ -203,6 +203,12 @@ class DataTrainingArguments:
|
|
| 203 |
"help": "Whether to shard data files by host in multi-host environments."
|
| 204 |
},
|
| 205 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
max_train_samples: Optional[int] = field(
|
| 207 |
default=None,
|
| 208 |
metadata={
|
|
@@ -314,10 +320,6 @@ class TrainingArguments:
|
|
| 314 |
default=1024,
|
| 315 |
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
| 316 |
)
|
| 317 |
-
start_preconditioning_step: int = field(
|
| 318 |
-
default=100,
|
| 319 |
-
metadata={"help": "Number of steps before starting to update preconditioner."},
|
| 320 |
-
)
|
| 321 |
preconditioning_compute_steps: int = field(
|
| 322 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
| 323 |
)
|
|
@@ -325,6 +327,12 @@ class TrainingArguments:
|
|
| 325 |
default=4096,
|
| 326 |
metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
|
| 327 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
optim_quantized: bool = field(
|
| 329 |
default=False,
|
| 330 |
metadata={
|
|
@@ -413,11 +421,28 @@ class TrainingArguments:
|
|
| 413 |
dp_devices: int = field(init=False)
|
| 414 |
|
| 415 |
def __post_init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
assert self.optim in [
|
| 417 |
"distributed_shampoo",
|
| 418 |
"adam",
|
| 419 |
"adafactor",
|
| 420 |
], f"Selected optimizer not supported: {self.optim}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
if self.per_device_eval_batch_size is None:
|
| 422 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
| 423 |
if (
|
|
@@ -430,6 +455,9 @@ class TrainingArguments:
|
|
| 430 |
f"Output directory ({self.output_dir}) already exists and is not empty."
|
| 431 |
"Use --overwrite_output_dir to overcome."
|
| 432 |
)
|
|
|
|
|
|
|
|
|
|
| 433 |
assert (
|
| 434 |
jax.device_count() % self.mp_devices == 0
|
| 435 |
), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
|
|
@@ -514,10 +542,6 @@ def main():
|
|
| 514 |
|
| 515 |
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
| 516 |
logger.info(f"Global TPUs: {jax.device_count()}")
|
| 517 |
-
if training_args.assert_TPU_available:
|
| 518 |
-
assert (
|
| 519 |
-
jax.local_device_count() == 8
|
| 520 |
-
), "TPUs in use, please check running processes"
|
| 521 |
|
| 522 |
# Set up wandb run
|
| 523 |
if jax.process_index() == 0:
|
|
@@ -544,8 +568,7 @@ def main():
|
|
| 544 |
config=config,
|
| 545 |
seed=training_args.seed_model,
|
| 546 |
dtype=getattr(jnp, model_args.dtype),
|
| 547 |
-
abstract_init=True,
|
| 548 |
-
load_on_cpu=True,
|
| 549 |
# initializing params with gradient checkpointing creates issues
|
| 550 |
# we correctly set it later per training_args
|
| 551 |
gradient_checkpointing=False,
|
|
@@ -555,29 +578,23 @@ def main():
|
|
| 555 |
config,
|
| 556 |
seed=training_args.seed_model,
|
| 557 |
dtype=getattr(jnp, model_args.dtype),
|
| 558 |
-
|
| 559 |
)
|
| 560 |
|
| 561 |
-
#
|
| 562 |
-
|
| 563 |
-
# This is still considered correctly during training as function is pjitted
|
| 564 |
-
model.config.gradient_checkpointing = training_args.gradient_checkpointing
|
| 565 |
-
|
| 566 |
if training_args.gradient_checkpointing:
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
eval_config,
|
| 572 |
seed=training_args.seed_model,
|
| 573 |
dtype=getattr(jnp, model_args.dtype),
|
| 574 |
-
|
| 575 |
-
load_on_cpu=True,
|
| 576 |
)
|
| 577 |
-
|
| 578 |
-
eval_fn = eval_model.__call__
|
| 579 |
else:
|
| 580 |
-
|
| 581 |
|
| 582 |
# get model metadata
|
| 583 |
model_metadata = model_args.get_metadata()
|
|
@@ -620,7 +637,7 @@ def main():
|
|
| 620 |
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
|
| 621 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 622 |
steps_per_epoch = (
|
| 623 |
-
len_train_dataset //
|
| 624 |
if len_train_dataset is not None
|
| 625 |
else None
|
| 626 |
)
|
|
@@ -633,7 +650,7 @@ def main():
|
|
| 633 |
logger.info(f" Num examples = {len_train_dataset}")
|
| 634 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 635 |
logger.info(
|
| 636 |
-
f" Batch size per device = {training_args.per_device_train_batch_size}"
|
| 637 |
)
|
| 638 |
logger.info(f" Number of devices = {jax.device_count()}")
|
| 639 |
logger.info(
|
|
@@ -701,22 +718,32 @@ def main():
|
|
| 701 |
# create adam optimizer
|
| 702 |
if training_args.optim == "distributed_shampoo":
|
| 703 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
optimizer = distributed_shampoo(
|
| 705 |
learning_rate_fn,
|
| 706 |
block_size=training_args.block_size,
|
| 707 |
beta1=training_args.beta1,
|
| 708 |
beta2=training_args.beta2,
|
| 709 |
diagonal_epsilon=1e-10,
|
| 710 |
-
matrix_epsilon=1e-
|
| 711 |
-
start_preconditioning_step=
|
|
|
|
|
|
|
| 712 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
| 713 |
statistics_compute_steps=1,
|
| 714 |
best_effort_shape_interpretation=True,
|
| 715 |
-
graft_type=
|
| 716 |
nesterov=False,
|
| 717 |
exponent_override=0,
|
| 718 |
-
statistics_partition_spec=PartitionSpec(None, "
|
| 719 |
-
preconditioner_partition_spec=PartitionSpec("
|
| 720 |
num_devices_for_pjit=training_args.dp_devices,
|
| 721 |
shard_optimizer_states=True,
|
| 722 |
inverse_failure_threshold=0.1,
|
|
@@ -779,7 +806,7 @@ def main():
|
|
| 779 |
opt_state_spec = opt_fn.pspec_fn(
|
| 780 |
params=model.params,
|
| 781 |
params_partition_spec=param_spec,
|
| 782 |
-
partition_spec_for_statistics=PartitionSpec(None, "
|
| 783 |
)
|
| 784 |
else:
|
| 785 |
raise NotImplementedError
|
|
@@ -790,7 +817,8 @@ def main():
|
|
| 790 |
# create a mesh
|
| 791 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
| 792 |
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
|
| 793 |
-
mesh = maps.Mesh(devices, ("
|
|
|
|
| 794 |
|
| 795 |
# define state spec
|
| 796 |
state_spec = TrainState(
|
|
@@ -801,28 +829,39 @@ def main():
|
|
| 801 |
epoch=None,
|
| 802 |
train_time=None,
|
| 803 |
train_samples=None,
|
| 804 |
-
apply_fn=
|
| 805 |
tx=optimizer,
|
| 806 |
)
|
| 807 |
|
| 808 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
|
|
|
| 810 |
if not model_args.restore_state:
|
| 811 |
|
| 812 |
def init_state(params):
|
| 813 |
return TrainState.create(
|
| 814 |
-
apply_fn=
|
| 815 |
tx=optimizer,
|
| 816 |
-
params=params,
|
| 817 |
dropout_rng=dropout_rng,
|
| 818 |
)
|
| 819 |
|
| 820 |
state = pjit(
|
| 821 |
init_state,
|
| 822 |
-
in_axis_resources=(param_spec,)
|
|
|
|
|
|
|
| 823 |
out_axis_resources=state_spec,
|
| 824 |
donate_argnums=(0,),
|
| 825 |
-
)(model.params)
|
| 826 |
|
| 827 |
else:
|
| 828 |
# load opt_state
|
|
@@ -836,7 +875,7 @@ def main():
|
|
| 836 |
|
| 837 |
def restore_state(params, opt_state):
|
| 838 |
return TrainState(
|
| 839 |
-
apply_fn=
|
| 840 |
tx=optimizer,
|
| 841 |
params=params,
|
| 842 |
opt_state=opt_state,
|
|
@@ -846,7 +885,10 @@ def main():
|
|
| 846 |
|
| 847 |
state = pjit(
|
| 848 |
restore_state,
|
| 849 |
-
in_axis_resources=(
|
|
|
|
|
|
|
|
|
|
| 850 |
out_axis_resources=state_spec,
|
| 851 |
donate_argnums=(0, 1),
|
| 852 |
)(model.params, opt_state)
|
|
@@ -854,37 +896,32 @@ def main():
|
|
| 854 |
# remove opt_state from CPU
|
| 855 |
del opt_state
|
| 856 |
|
| 857 |
-
# free memory
|
| 858 |
del model._params, opt_state_spec, opt_state_shape
|
| 859 |
|
| 860 |
# define batch specs
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
|
| 864 |
|
| 865 |
-
#
|
| 866 |
def loss_fn(logits, labels):
|
| 867 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
| 868 |
loss = loss.mean()
|
| 869 |
return loss
|
| 870 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
# Define gradient update step fn
|
| 872 |
def train_step(state, batch, delta_time):
|
| 873 |
-
# we reshape to (gradient_accumulation_steps, dp_devices, ...)
|
| 874 |
-
# allows feeding partial batch size per node for full model parallel
|
| 875 |
-
batch = jax.tree_map(
|
| 876 |
-
lambda x: x.reshape(
|
| 877 |
-
(
|
| 878 |
-
training_args.gradient_accumulation_steps,
|
| 879 |
-
training_args.dp_devices,
|
| 880 |
-
training_args.per_device_train_batch_size,
|
| 881 |
-
)
|
| 882 |
-
+ x.shape[2:]
|
| 883 |
-
),
|
| 884 |
-
batch,
|
| 885 |
-
)
|
| 886 |
-
# ensure data is sharded correctly per dp device
|
| 887 |
-
batch = with_sharding_constraint(batch, grad_batch_spec)
|
| 888 |
|
| 889 |
# get a minibatch (one gradient accumulation slice)
|
| 890 |
def get_minibatch(batch, grad_idx):
|
|
@@ -904,62 +941,71 @@ def main():
|
|
| 904 |
grad_fn = jax.value_and_grad(compute_loss)
|
| 905 |
|
| 906 |
def loss_and_grad(grad_idx, dropout_rng):
|
| 907 |
-
# minibatch at grad_idx
|
| 908 |
-
minibatch =
|
| 909 |
-
|
| 910 |
-
dropout_rng, _ = jax.random.split(dropout_rng)
|
| 911 |
-
# ensure inputs are sharded per device
|
| 912 |
-
minibatch = jax.tree_map(
|
| 913 |
-
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
| 914 |
-
minibatch,
|
| 915 |
-
)
|
| 916 |
-
# only 1 single rng per grad step, let us handle larger batch size
|
| 917 |
-
loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
|
| 918 |
-
state.params, minibatch, dropout_rng
|
| 919 |
)
|
| 920 |
-
# ensure
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
# return loss and grads
|
| 928 |
-
return
|
| 929 |
|
| 930 |
if training_args.gradient_accumulation_steps == 1:
|
| 931 |
-
|
| 932 |
else:
|
| 933 |
# create initial state for cumul_minibatch_step loop
|
| 934 |
init_minibatch_step = (
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
jax.tree_map(jnp.zeros_like, state.params),
|
| 938 |
),
|
| 939 |
state.dropout_rng,
|
| 940 |
)
|
| 941 |
|
| 942 |
# accumulate gradients
|
| 943 |
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
|
|
|
|
|
|
|
|
|
| 948 |
|
| 949 |
# loop over gradients
|
| 950 |
-
|
| 951 |
0,
|
| 952 |
training_args.gradient_accumulation_steps,
|
| 953 |
cumul_minibatch_step,
|
| 954 |
init_minibatch_step,
|
| 955 |
)
|
|
|
|
| 956 |
# sum -> mean
|
| 957 |
-
|
| 958 |
-
lambda x: x / training_args.gradient_accumulation_steps,
|
| 959 |
)
|
| 960 |
|
| 961 |
# update state
|
| 962 |
-
|
| 963 |
state = state.apply_gradients(
|
| 964 |
grads=grads,
|
| 965 |
dropout_rng=dropout_rng,
|
|
@@ -976,37 +1022,32 @@ def main():
|
|
| 976 |
|
| 977 |
# Define eval fn
|
| 978 |
def eval_step(state, batch):
|
| 979 |
-
# we reshape to (dp_devices, ...)
|
| 980 |
-
batch = jax.tree_map(
|
| 981 |
-
lambda x: x.reshape(
|
| 982 |
-
(
|
| 983 |
-
training_args.dp_devices,
|
| 984 |
-
training_args.per_device_eval_batch_size,
|
| 985 |
-
)
|
| 986 |
-
+ x.shape[1:]
|
| 987 |
-
),
|
| 988 |
-
batch,
|
| 989 |
-
)
|
| 990 |
-
# ensure data is sharded correctly per dp device
|
| 991 |
-
batch = with_sharding_constraint(batch, batch_spec)
|
| 992 |
-
|
| 993 |
def compute_eval_loss(batch):
|
| 994 |
batch, labels = batch.pop("labels")
|
| 995 |
logits = eval_fn(**batch, params=state.params, train=False)[0]
|
| 996 |
return loss_fn(logits, labels)
|
| 997 |
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
|
|
|
|
|
|
|
|
|
| 1004 |
return loss
|
| 1005 |
|
| 1006 |
# Create parallel version of the train and eval step
|
| 1007 |
p_train_step = pjit(
|
| 1008 |
train_step,
|
| 1009 |
-
in_axis_resources=(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1010 |
out_axis_resources=(state_spec, None),
|
| 1011 |
donate_argnums=(0,),
|
| 1012 |
)
|
|
@@ -1022,7 +1063,10 @@ def main():
|
|
| 1022 |
step = int(state.step)
|
| 1023 |
metrics_logger = MetricsLogger(step)
|
| 1024 |
epochs = tqdm(
|
| 1025 |
-
range(state.epoch, num_epochs),
|
|
|
|
|
|
|
|
|
|
| 1026 |
)
|
| 1027 |
|
| 1028 |
def run_evaluation():
|
|
@@ -1041,6 +1085,7 @@ def main():
|
|
| 1041 |
position=2,
|
| 1042 |
leave=False,
|
| 1043 |
total=eval_steps,
|
|
|
|
| 1044 |
):
|
| 1045 |
# need to keep only eval_batch_size_per_node items relevant to the node
|
| 1046 |
batch = jax.tree_map(
|
|
@@ -1050,6 +1095,17 @@ def main():
|
|
| 1050 |
batch,
|
| 1051 |
)
|
| 1052 |
batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
# freeze batch to pass safely to jax transforms
|
| 1054 |
batch = freeze(batch)
|
| 1055 |
# accumulate losses async
|
|
@@ -1166,6 +1222,7 @@ def main():
|
|
| 1166 |
)
|
| 1167 |
wandb.run.log_artifact(artifact_state)
|
| 1168 |
|
|
|
|
| 1169 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
| 1170 |
for epoch in epochs:
|
| 1171 |
state.replace(epoch=epoch)
|
|
@@ -1186,21 +1243,33 @@ def main():
|
|
| 1186 |
position=1,
|
| 1187 |
leave=False,
|
| 1188 |
total=steps_per_epoch,
|
|
|
|
| 1189 |
):
|
| 1190 |
# calculate delta time (we have a lag of one step but it's ok)
|
| 1191 |
new_time = time.perf_counter()
|
| 1192 |
delta_time = new_time - last_time
|
| 1193 |
last_time = new_time
|
| 1194 |
|
| 1195 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1196 |
batch = jax.tree_map(
|
| 1197 |
-
lambda x: x.reshape(
|
| 1198 |
-
(
|
| 1199 |
-
training_args.gradient_accumulation_steps,
|
| 1200 |
-
batch_size_per_node_per_grad_step,
|
| 1201 |
-
)
|
| 1202 |
-
+ x.shape[1:]
|
| 1203 |
-
),
|
| 1204 |
batch,
|
| 1205 |
)
|
| 1206 |
# freeze batch to pass safely to jax transforms
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
+
# Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
|
| 4 |
#
|
| 5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 37 |
import transformers
|
| 38 |
import wandb
|
| 39 |
from datasets import Dataset
|
|
|
|
| 40 |
from flax.core.frozen_dict import FrozenDict, freeze
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
|
|
|
| 45 |
from jax.experimental import PartitionSpec, maps
|
| 46 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 47 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
| 48 |
+
from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
|
| 49 |
from tqdm import tqdm
|
| 50 |
from transformers import HfArgumentParser
|
| 51 |
|
|
|
|
| 57 |
set_partitions,
|
| 58 |
)
|
| 59 |
|
| 60 |
+
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
|
| 61 |
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|
|
|
|
| 203 |
"help": "Whether to shard data files by host in multi-host environments."
|
| 204 |
},
|
| 205 |
)
|
| 206 |
+
blank_caption_prob: Optional[float] = field(
|
| 207 |
+
default=0.0,
|
| 208 |
+
metadata={
|
| 209 |
+
"help": "Probability of removing some captions for classifier-free guidance."
|
| 210 |
+
},
|
| 211 |
+
)
|
| 212 |
max_train_samples: Optional[int] = field(
|
| 213 |
default=None,
|
| 214 |
metadata={
|
|
|
|
| 320 |
default=1024,
|
| 321 |
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
| 322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
preconditioning_compute_steps: int = field(
|
| 324 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
| 325 |
)
|
|
|
|
| 327 |
default=4096,
|
| 328 |
metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
|
| 329 |
)
|
| 330 |
+
graft_type: str = field(
|
| 331 |
+
default="rmsprop_normalized",
|
| 332 |
+
metadata={
|
| 333 |
+
"help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
|
| 334 |
+
},
|
| 335 |
+
)
|
| 336 |
optim_quantized: bool = field(
|
| 337 |
default=False,
|
| 338 |
metadata={
|
|
|
|
| 421 |
dp_devices: int = field(init=False)
|
| 422 |
|
| 423 |
def __post_init__(self):
|
| 424 |
+
if self.assert_TPU_available:
|
| 425 |
+
assert (
|
| 426 |
+
jax.local_device_count() == 8
|
| 427 |
+
), "TPUs in use, please check running processes"
|
| 428 |
assert self.optim in [
|
| 429 |
"distributed_shampoo",
|
| 430 |
"adam",
|
| 431 |
"adafactor",
|
| 432 |
], f"Selected optimizer not supported: {self.optim}"
|
| 433 |
+
assert self.graft_type in [
|
| 434 |
+
"rmsprop_normalized",
|
| 435 |
+
"rmsprop",
|
| 436 |
+
"adagrad",
|
| 437 |
+
"adagrad_normalized",
|
| 438 |
+
"sgd",
|
| 439 |
+
"sqrt_n",
|
| 440 |
+
], f"Selected graft type not supported: {self.graft_type}"
|
| 441 |
+
assert self.lr_decay in [
|
| 442 |
+
None,
|
| 443 |
+
"linear",
|
| 444 |
+
"exponential",
|
| 445 |
+
], f"Selected learning rate decay not supported: {self.lr_decay}"
|
| 446 |
if self.per_device_eval_batch_size is None:
|
| 447 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
| 448 |
if (
|
|
|
|
| 455 |
f"Output directory ({self.output_dir}) already exists and is not empty."
|
| 456 |
"Use --overwrite_output_dir to overcome."
|
| 457 |
)
|
| 458 |
+
assert (
|
| 459 |
+
self.mp_devices > 0
|
| 460 |
+
), f"Number of devices for model parallelism must be > 0"
|
| 461 |
assert (
|
| 462 |
jax.device_count() % self.mp_devices == 0
|
| 463 |
), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
|
|
|
|
| 542 |
|
| 543 |
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
| 544 |
logger.info(f"Global TPUs: {jax.device_count()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
|
| 546 |
# Set up wandb run
|
| 547 |
if jax.process_index() == 0:
|
|
|
|
| 568 |
config=config,
|
| 569 |
seed=training_args.seed_model,
|
| 570 |
dtype=getattr(jnp, model_args.dtype),
|
| 571 |
+
abstract_init=True, # we overwrite them with loaded checkpoint
|
|
|
|
| 572 |
# initializing params with gradient checkpointing creates issues
|
| 573 |
# we correctly set it later per training_args
|
| 574 |
gradient_checkpointing=False,
|
|
|
|
| 578 |
config,
|
| 579 |
seed=training_args.seed_model,
|
| 580 |
dtype=getattr(jnp, model_args.dtype),
|
| 581 |
+
abstract_init=True,
|
| 582 |
)
|
| 583 |
|
| 584 |
+
# define model eval and train functions
|
| 585 |
+
eval_fn = model.__call__
|
|
|
|
|
|
|
|
|
|
| 586 |
if training_args.gradient_checkpointing:
|
| 587 |
+
remat_config = copy.deepcopy(model.config)
|
| 588 |
+
remat_config.gradient_checkpointing = True
|
| 589 |
+
remat_model = DalleBart(
|
| 590 |
+
remat_config,
|
|
|
|
| 591 |
seed=training_args.seed_model,
|
| 592 |
dtype=getattr(jnp, model_args.dtype),
|
| 593 |
+
init_weights=False,
|
|
|
|
| 594 |
)
|
| 595 |
+
train_fn = remat_model.__call__
|
|
|
|
| 596 |
else:
|
| 597 |
+
train_fn = model.__call__
|
| 598 |
|
| 599 |
# get model metadata
|
| 600 |
model_metadata = model_args.get_metadata()
|
|
|
|
| 637 |
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
|
| 638 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 639 |
steps_per_epoch = (
|
| 640 |
+
len_train_dataset // batch_size_per_node
|
| 641 |
if len_train_dataset is not None
|
| 642 |
else None
|
| 643 |
)
|
|
|
|
| 650 |
logger.info(f" Num examples = {len_train_dataset}")
|
| 651 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 652 |
logger.info(
|
| 653 |
+
f" Batch size per dp device = {training_args.per_device_train_batch_size}"
|
| 654 |
)
|
| 655 |
logger.info(f" Number of devices = {jax.device_count()}")
|
| 656 |
logger.info(
|
|
|
|
| 718 |
# create adam optimizer
|
| 719 |
if training_args.optim == "distributed_shampoo":
|
| 720 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
| 721 |
+
graft_type = {
|
| 722 |
+
"sgd": GraftingType.SGD,
|
| 723 |
+
"adagrad": GraftingType.ADAGRAD,
|
| 724 |
+
"rmsprop": GraftingType.RMSPROP,
|
| 725 |
+
"rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
|
| 726 |
+
"sqrt_n": GraftingType.SQRT_N,
|
| 727 |
+
"adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
|
| 728 |
+
}[training_args.graft_type]
|
| 729 |
optimizer = distributed_shampoo(
|
| 730 |
learning_rate_fn,
|
| 731 |
block_size=training_args.block_size,
|
| 732 |
beta1=training_args.beta1,
|
| 733 |
beta2=training_args.beta2,
|
| 734 |
diagonal_epsilon=1e-10,
|
| 735 |
+
matrix_epsilon=1e-6,
|
| 736 |
+
start_preconditioning_step=max(
|
| 737 |
+
training_args.preconditioning_compute_steps + 1, 101
|
| 738 |
+
),
|
| 739 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
| 740 |
statistics_compute_steps=1,
|
| 741 |
best_effort_shape_interpretation=True,
|
| 742 |
+
graft_type=graft_type,
|
| 743 |
nesterov=False,
|
| 744 |
exponent_override=0,
|
| 745 |
+
statistics_partition_spec=PartitionSpec(None, "dp", None),
|
| 746 |
+
preconditioner_partition_spec=PartitionSpec("dp", None, None),
|
| 747 |
num_devices_for_pjit=training_args.dp_devices,
|
| 748 |
shard_optimizer_states=True,
|
| 749 |
inverse_failure_threshold=0.1,
|
|
|
|
| 806 |
opt_state_spec = opt_fn.pspec_fn(
|
| 807 |
params=model.params,
|
| 808 |
params_partition_spec=param_spec,
|
| 809 |
+
partition_spec_for_statistics=PartitionSpec(None, "dp", None),
|
| 810 |
)
|
| 811 |
else:
|
| 812 |
raise NotImplementedError
|
|
|
|
| 817 |
# create a mesh
|
| 818 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
| 819 |
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
|
| 820 |
+
mesh = maps.Mesh(devices, ("dp", "mp"))
|
| 821 |
+
logger.info(f" Mesh shape: {mesh_shape}")
|
| 822 |
|
| 823 |
# define state spec
|
| 824 |
state_spec = TrainState(
|
|
|
|
| 829 |
epoch=None,
|
| 830 |
train_time=None,
|
| 831 |
train_samples=None,
|
| 832 |
+
apply_fn=train_fn,
|
| 833 |
tx=optimizer,
|
| 834 |
)
|
| 835 |
|
| 836 |
+
# init params if not available yet
|
| 837 |
+
def maybe_init_params(params):
|
| 838 |
+
if model_args.model_name_or_path:
|
| 839 |
+
# model params are correctly loaded
|
| 840 |
+
return params
|
| 841 |
+
else:
|
| 842 |
+
# params have not been initialized yet
|
| 843 |
+
return model.init_weights()
|
| 844 |
+
|
| 845 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
| 846 |
+
logger.info(" Creating state")
|
| 847 |
if not model_args.restore_state:
|
| 848 |
|
| 849 |
def init_state(params):
|
| 850 |
return TrainState.create(
|
| 851 |
+
apply_fn=train_fn,
|
| 852 |
tx=optimizer,
|
| 853 |
+
params=maybe_init_params(params),
|
| 854 |
dropout_rng=dropout_rng,
|
| 855 |
)
|
| 856 |
|
| 857 |
state = pjit(
|
| 858 |
init_state,
|
| 859 |
+
in_axis_resources=(param_spec,)
|
| 860 |
+
if model_args.model_name_or_path
|
| 861 |
+
else None,
|
| 862 |
out_axis_resources=state_spec,
|
| 863 |
donate_argnums=(0,),
|
| 864 |
+
)(model.params if model_args.model_name_or_path else None)
|
| 865 |
|
| 866 |
else:
|
| 867 |
# load opt_state
|
|
|
|
| 875 |
|
| 876 |
def restore_state(params, opt_state):
|
| 877 |
return TrainState(
|
| 878 |
+
apply_fn=train_fn,
|
| 879 |
tx=optimizer,
|
| 880 |
params=params,
|
| 881 |
opt_state=opt_state,
|
|
|
|
| 885 |
|
| 886 |
state = pjit(
|
| 887 |
restore_state,
|
| 888 |
+
in_axis_resources=(
|
| 889 |
+
param_spec,
|
| 890 |
+
opt_state_spec,
|
| 891 |
+
),
|
| 892 |
out_axis_resources=state_spec,
|
| 893 |
donate_argnums=(0, 1),
|
| 894 |
)(model.params, opt_state)
|
|
|
|
| 896 |
# remove opt_state from CPU
|
| 897 |
del opt_state
|
| 898 |
|
| 899 |
+
# free CPU memory
|
| 900 |
del model._params, opt_state_spec, opt_state_shape
|
| 901 |
|
| 902 |
# define batch specs
|
| 903 |
+
batch_spec = PartitionSpec("dp")
|
| 904 |
+
grad_batch_spec = PartitionSpec(None, "dp")
|
|
|
|
| 905 |
|
| 906 |
+
# define loss
|
| 907 |
def loss_fn(logits, labels):
|
| 908 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
| 909 |
loss = loss.mean()
|
| 910 |
return loss
|
| 911 |
|
| 912 |
+
# "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
|
| 913 |
+
# lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
|
| 914 |
+
use_vmap_trick = True
|
| 915 |
+
|
| 916 |
+
# make grad_param_spec for vmap
|
| 917 |
+
if use_vmap_trick:
|
| 918 |
+
grad_param_spec = jax.tree_map(
|
| 919 |
+
lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
|
| 920 |
+
param_spec,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
# Define gradient update step fn
|
| 924 |
def train_step(state, batch, delta_time):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
|
| 926 |
# get a minibatch (one gradient accumulation slice)
|
| 927 |
def get_minibatch(batch, grad_idx):
|
|
|
|
| 941 |
grad_fn = jax.value_and_grad(compute_loss)
|
| 942 |
|
| 943 |
def loss_and_grad(grad_idx, dropout_rng):
|
| 944 |
+
# minibatch at grad_idx for gradient accumulation (None otherwise)
|
| 945 |
+
minibatch = (
|
| 946 |
+
get_minibatch(batch, grad_idx) if grad_idx is not None else batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
)
|
| 948 |
+
# ensure it is sharded properly
|
| 949 |
+
minibatch = with_sharding_constraint(minibatch, batch_spec)
|
| 950 |
+
# only 1 single rng per grad step, let us handle larger batch size (not sure why)
|
| 951 |
+
dropout_rng, _ = jax.random.split(dropout_rng)
|
| 952 |
+
|
| 953 |
+
if use_vmap_trick:
|
| 954 |
+
# "vmap trick", calculate loss and grads independently per dp_device
|
| 955 |
+
loss, grads = jax.vmap(
|
| 956 |
+
grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
|
| 957 |
+
)(state.params, minibatch, dropout_rng)
|
| 958 |
+
# ensure they are sharded correctly
|
| 959 |
+
loss = with_sharding_constraint(loss, batch_spec)
|
| 960 |
+
grads = with_sharding_constraint(grads, grad_param_spec)
|
| 961 |
+
# average across all devices
|
| 962 |
+
# Note: we could average per device only after gradient accumulation, right before params update
|
| 963 |
+
loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))
|
| 964 |
+
else:
|
| 965 |
+
# "vmap trick" does not work in multi-hosts and requires too much hbm
|
| 966 |
+
loss, grads = grad_fn(state.params, minibatch, dropout_rng)
|
| 967 |
+
# ensure grads are sharded
|
| 968 |
+
grads = with_sharding_constraint(grads, param_spec)
|
| 969 |
# return loss and grads
|
| 970 |
+
return loss, grads, dropout_rng
|
| 971 |
|
| 972 |
if training_args.gradient_accumulation_steps == 1:
|
| 973 |
+
loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
|
| 974 |
else:
|
| 975 |
# create initial state for cumul_minibatch_step loop
|
| 976 |
init_minibatch_step = (
|
| 977 |
+
0.0,
|
| 978 |
+
with_sharding_constraint(
|
| 979 |
+
jax.tree_map(jnp.zeros_like, state.params), param_spec
|
| 980 |
),
|
| 981 |
state.dropout_rng,
|
| 982 |
)
|
| 983 |
|
| 984 |
# accumulate gradients
|
| 985 |
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
| 986 |
+
cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
|
| 987 |
+
loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
|
| 988 |
+
cumul_loss, cumul_grads = jax.tree_map(
|
| 989 |
+
jnp.add, (cumul_loss, cumul_grads), (loss, grads)
|
| 990 |
+
)
|
| 991 |
+
cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
|
| 992 |
+
return cumul_loss, cumul_grads, dropout_rng
|
| 993 |
|
| 994 |
# loop over gradients
|
| 995 |
+
loss, grads, dropout_rng = jax.lax.fori_loop(
|
| 996 |
0,
|
| 997 |
training_args.gradient_accumulation_steps,
|
| 998 |
cumul_minibatch_step,
|
| 999 |
init_minibatch_step,
|
| 1000 |
)
|
| 1001 |
+
grads = with_sharding_constraint(grads, param_spec)
|
| 1002 |
# sum -> mean
|
| 1003 |
+
loss, grads = jax.tree_map(
|
| 1004 |
+
lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
|
| 1005 |
)
|
| 1006 |
|
| 1007 |
# update state
|
| 1008 |
+
grads = with_sharding_constraint(grads, param_spec)
|
| 1009 |
state = state.apply_gradients(
|
| 1010 |
grads=grads,
|
| 1011 |
dropout_rng=dropout_rng,
|
|
|
|
| 1022 |
|
| 1023 |
# Define eval fn
|
| 1024 |
def eval_step(state, batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1025 |
def compute_eval_loss(batch):
|
| 1026 |
batch, labels = batch.pop("labels")
|
| 1027 |
logits = eval_fn(**batch, params=state.params, train=False)[0]
|
| 1028 |
return loss_fn(logits, labels)
|
| 1029 |
|
| 1030 |
+
if use_vmap_trick:
|
| 1031 |
+
loss = jax.vmap(compute_eval_loss)(batch)
|
| 1032 |
+
# ensure they are sharded correctly
|
| 1033 |
+
loss = with_sharding_constraint(loss, batch_spec)
|
| 1034 |
+
# average across all devices
|
| 1035 |
+
loss = jnp.mean(loss)
|
| 1036 |
+
else:
|
| 1037 |
+
loss = compute_eval_loss(batch)
|
| 1038 |
+
|
| 1039 |
return loss
|
| 1040 |
|
| 1041 |
# Create parallel version of the train and eval step
|
| 1042 |
p_train_step = pjit(
|
| 1043 |
train_step,
|
| 1044 |
+
in_axis_resources=(
|
| 1045 |
+
state_spec,
|
| 1046 |
+
grad_batch_spec
|
| 1047 |
+
if training_args.gradient_accumulation_steps > 1
|
| 1048 |
+
else batch_spec,
|
| 1049 |
+
None,
|
| 1050 |
+
),
|
| 1051 |
out_axis_resources=(state_spec, None),
|
| 1052 |
donate_argnums=(0,),
|
| 1053 |
)
|
|
|
|
| 1063 |
step = int(state.step)
|
| 1064 |
metrics_logger = MetricsLogger(step)
|
| 1065 |
epochs = tqdm(
|
| 1066 |
+
range(state.epoch, num_epochs),
|
| 1067 |
+
desc=f"Epoch ... (1/{num_epochs})",
|
| 1068 |
+
position=0,
|
| 1069 |
+
disable=jax.process_index() > 0,
|
| 1070 |
)
|
| 1071 |
|
| 1072 |
def run_evaluation():
|
|
|
|
| 1085 |
position=2,
|
| 1086 |
leave=False,
|
| 1087 |
total=eval_steps,
|
| 1088 |
+
disable=jax.process_index() > 0,
|
| 1089 |
):
|
| 1090 |
# need to keep only eval_batch_size_per_node items relevant to the node
|
| 1091 |
batch = jax.tree_map(
|
|
|
|
| 1095 |
batch,
|
| 1096 |
)
|
| 1097 |
batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
|
| 1098 |
+
|
| 1099 |
+
# add dp dimension when using "vmap trick"
|
| 1100 |
+
if use_vmap_trick:
|
| 1101 |
+
bs_shape = (
|
| 1102 |
+
jax.local_device_count() // training_args.mp_devices,
|
| 1103 |
+
training_args.per_device_eval_batch_size,
|
| 1104 |
+
)
|
| 1105 |
+
batch = jax.tree_map(
|
| 1106 |
+
lambda x: x.reshape(bs_shape + x.shape[1:]), batch
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
# freeze batch to pass safely to jax transforms
|
| 1110 |
batch = freeze(batch)
|
| 1111 |
# accumulate losses async
|
|
|
|
| 1222 |
)
|
| 1223 |
wandb.run.log_artifact(artifact_state)
|
| 1224 |
|
| 1225 |
+
logger.info(" Ready to start training")
|
| 1226 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
| 1227 |
for epoch in epochs:
|
| 1228 |
state.replace(epoch=epoch)
|
|
|
|
| 1243 |
position=1,
|
| 1244 |
leave=False,
|
| 1245 |
total=steps_per_epoch,
|
| 1246 |
+
disable=jax.process_index() > 0,
|
| 1247 |
):
|
| 1248 |
# calculate delta time (we have a lag of one step but it's ok)
|
| 1249 |
new_time = time.perf_counter()
|
| 1250 |
delta_time = new_time - last_time
|
| 1251 |
last_time = new_time
|
| 1252 |
|
| 1253 |
+
# set correct shape to batch
|
| 1254 |
+
# - add grad_step dim if gradient_accumulation_steps > 1
|
| 1255 |
+
# - split per dp device if not multi-host for vmap trick (does not work in multi-host)
|
| 1256 |
+
bs_shape = (
|
| 1257 |
+
(batch_size_per_node_per_grad_step,)
|
| 1258 |
+
if not use_vmap_trick
|
| 1259 |
+
else (
|
| 1260 |
+
jax.local_device_count()
|
| 1261 |
+
// training_args.mp_devices, # local dp devices
|
| 1262 |
+
training_args.per_device_train_batch_size,
|
| 1263 |
+
)
|
| 1264 |
+
)
|
| 1265 |
+
if training_args.gradient_accumulation_steps > 1:
|
| 1266 |
+
# reshape data into (gradient_accumulation_steps, batch_per_node, ...)
|
| 1267 |
+
# to avoid any data redistribution when sharding
|
| 1268 |
+
bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
|
| 1269 |
+
|
| 1270 |
+
# reshape batch
|
| 1271 |
batch = jax.tree_map(
|
| 1272 |
+
lambda x: x.reshape(bs_shape + x.shape[1:]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1273 |
batch,
|
| 1274 |
)
|
| 1275 |
# freeze batch to pass safely to jax transforms
|