import dataclasses
import gc
import json
import logging
from contextlib import contextmanager
from enum import Enum

import accelerate
import psutil
import pynvml
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from accelerate.state import AcceleratorState
from PIL import Image
from transformers import (  # AddedToken is needed for the eval of the tokenizer params # noqa: F401
    AddedToken,
    AutoTokenizer,
)

IMAGE_TOKEN = "<image>"
FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"
FAKE_TOKEN_AROUND_IMAGE_V1 = "\n\n"

# Originally taken from the values used in OpenCLIP
IMAGE_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

logger = logging.getLogger(__name__)


class LoggingTypes(Enum):
    """Types of logging to use for the gradient and parameter statistics"""

    JSONL = "jsonl"
    WANDB = "wandb"
    PRINT = "print"


class JSONEncoderForDataclasses(json.JSONEncoder):
    """
    Use to serialize dataclass objects, like so:
    json.dump(data, fp, indent=2, cls=JSONEncoderForDataclasses)
    """

    def default(self, obj):
        if dataclasses.is_dataclass(obj):
            return dataclasses.asdict(obj)
        return super().default(obj)


def freeze_model(model, module_exceptions=[]):
    mapping = {
        "LayerNorm": nn.LayerNorm,
        "Linear": nn.Linear,
        "Embedding": nn.Embedding,
    }
    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
    for module in model.modules():
        if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped):
            module.requires_grad_(True)  # Explicitly setting it to true to avoid any mistakes
        else:
            module.requires_grad_(False)
    return model


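# Usage sketch (an illustrative helper, not part of the original module): freeze every
# parameter except those of LayerNorm modules and report which parameters stay trainable.
def _example_freeze_all_but_layernorm(model):
    model = freeze_model(model, module_exceptions=["LayerNorm"])
    return [name for name, param in model.named_parameters() if param.requires_grad]

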
def _convert_to_rgb(image):
    # `image.convert("RGB")` alone only works for images without transparency (e.g. .jpg),
    # as it produces a wrong background for transparent images. Compositing onto a white
    # background via `alpha_composite` handles this case.
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


# TODO(aps): Take parameters from config
def build_image_transform(image_size=224, eval=False):
    return transforms.Compose(
        [
            _convert_to_rgb,
            (
                transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC)
                if eval
                else transforms.RandomResizedCrop(
                    (image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
                )
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_DATASET_MEAN, std=IMAGE_DATASET_STD),
        ]
    )


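# Usage sketch (an illustrative helper; the image path is an assumed argument): at eval time
# images are deterministically resized, at train time they are randomly cropped, and the
# result is a normalized tensor of shape (3, image_size, image_size).
def _example_preprocess_image(image_path, image_size=224, eval=True):
    transform = build_image_transform(image_size=image_size, eval=eval)
    return transform(Image.open(image_path))

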
def get_tokenizer(
    tokenizer_name: str,
    tokenizer_add_tokens,
    tokenizer_add_special_tokens,
    tokenizer_params,
    additional_vocab_size,
    model_vocab_size=None,
):
| """ | |
| We artificially separate `tokenizer_add_tokens` and `tokenizer_add_special_tokens` is a dictionary whose keys only takes into account special tokens (eos, pad, cls, etc.). | |
| On the contrary, `tokenizer_add_tokens` is a list of string of `AddedToken`. | |
| In practise, we use `tokenizer.add_special_tokens` to add all of these new special tokens or update the existing ones. | |
| NB: we constraint to tokenizer to be a fast tokenizer because with the slow tokenizer, we can't set the arguments of the added tokens (cf `.add_tokens`) and by default, the separators are stripped. | |
| """ | |
    tokenizer_params = eval(tokenizer_params)
    assert isinstance(tokenizer_params, dict)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_params)

    if model_vocab_size is not None:
        if model_vocab_size > len(tokenizer):
            logger.warning(
                f"The model vocabulary size ({model_vocab_size}) is larger than the tokenizer vocabulary size "
                f"({len(tokenizer)}). Updating the tokenizer to match."
            )
            if "additional_special_tokens" in tokenizer_params:
                raise ValueError(
                    "You can't use `additional_special_tokens` in `tokenizer_params` with a model vocab "
                    "size > tokenizer vocab size. We need to adjust tokenizer before adding special "
                    "tokens. Please use `tokenizer_add_tokens` instead."
                )
            # We need to pad the tokenizer vocab with fake tokens
            tokenizer.add_tokens(["<fake_token_{}>".format(i) for i in range(model_vocab_size - len(tokenizer))])

    assert str(eval(tokenizer_add_tokens)[-1]) == IMAGE_TOKEN
    assert str(eval(tokenizer_add_tokens)[-2]) == FAKE_TOKEN_AROUND_IMAGE_V2
    # This check ensures that the image token and the fake token around it will be in the `DecoupledEmbedding.additional_weight`.
    existing_special_tokens = (
        [*tokenizer.special_tokens_map_extended["additional_special_tokens"]]
        if "additional_special_tokens" in tokenizer.special_tokens_map_extended
        else []
    )
    add_special_tokens_dict = {"additional_special_tokens": existing_special_tokens + eval(tokenizer_add_tokens)}
    if tokenizer_add_special_tokens is not None:
        add_special_tokens_dict.update(eval(tokenizer_add_special_tokens))

    tokenizer.add_special_tokens(add_special_tokens_dict)

    assert IMAGE_TOKEN in tokenizer.convert_ids_to_tokens(
        [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
    )
    assert FAKE_TOKEN_AROUND_IMAGE_V2 in tokenizer.convert_ids_to_tokens(
        [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
    )
    # This verifies that `<image>` was correctly added to the tokenizer vocabulary
    # XXX: opt-1.3b fails here
    # assert tokenizer.is_fast == tokenizer_params.get("use_fast", True)

    return tokenizer


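# Usage sketch (the tokenizer name and token arguments below are illustrative assumptions,
# not values mandated by this module). Note that the string arguments are `eval`-ed inside
# `get_tokenizer`, and that the last two added tokens must be FAKE_TOKEN_AROUND_IMAGE_V2
# and IMAGE_TOKEN, in that order.
def _example_build_tokenizer():
    return get_tokenizer(
        tokenizer_name="gpt2",
        tokenizer_add_tokens=(
            '[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False),'
            ' AddedToken("<image>", rstrip=False, lstrip=False)]'
        ),
        tokenizer_add_special_tokens=None,
        tokenizer_params='{"use_fast": True}',
        additional_vocab_size=2,
    )

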
def pynmvl_handle(accelerator):
    if not torch.cuda.is_available():
        return None

    pynvml.nvmlInit()
    return pynvml.nvmlDeviceGetHandleByIndex(accelerator.local_process_index)


def pynvml_get_total_energy_in_joules(handle):
    if not torch.cuda.is_available():
        return 0

    return pynvml.nvmlDeviceGetTotalEnergyConsumption(handle) / 1000


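# Illustrative helper (an assumption about how the two functions above are meant to be
# combined, not part of the original module): NVML reports cumulative energy, so meaningful
# numbers come from deltas around the region of interest.
def _example_gpu_energy_delta(accelerator, run_fn):
    handle = pynmvl_handle(accelerator)
    start_joules = pynvml_get_total_energy_in_joules(handle)
    run_fn()  # e.g. one training step
    return pynvml_get_total_energy_in_joules(handle) - start_joules

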
def compute_tflops_per_batch_per_gpu(
    num_layers,
    batch_size,
    q_seq_len,
    k_seq_len,
    hidden_size,
    kv_in_dim,
    ff_exp_factor=None,
    grad_acc_size=1,
    swiglu=False,
    vocab_size=None,
    count_backward=False,
    use_grad_checkpointing=False,
):
    multiply_add_factor = torch.tensor(2)
    query_transformation = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
    # k_seq_len == v_seq_len
    key_value_transformation = multiply_add_factor * batch_size * k_seq_len * (2 * hidden_size * kv_in_dim)
    attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
    attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
    att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
    post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * hidden_size**2

    # There are usually 2 expansion_linear_layers because the first one expands and the second one projects back to hidden_size
    # When using a classic decoder, some blocks don't have those feed-forward layers
    # SwiGLU duplicates the first linear layer, so we have to account for 3 of them instead of 2
    if ff_exp_factor and swiglu:
        expansion_linear_layers = 3 * (
            multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
        )
    elif ff_exp_factor:
        expansion_linear_layers = 2 * (
            multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
        )
    else:
        expansion_linear_layers = torch.tensor(0)

    transformer_block_flops = (
        query_transformation
        + key_value_transformation
        + attention_matrix_computation
        + attention_softmax
        + att_over_values_computation
        + post_attention_linear_proj
        + expansion_linear_layers
    )

    # This computation should only be added if the model has a language head
    if vocab_size:
        language_head_computation = multiply_add_factor * batch_size * q_seq_len * hidden_size * vocab_size
    else:
        language_head_computation = torch.tensor(0)

    forward_factor = 1
    backward_factor = 2 if count_backward else 0
    grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
    model_flops = (forward_factor + backward_factor + grad_checkpointing_factor) * (
        num_layers * transformer_block_flops + language_head_computation
    )
    model_tflops = model_flops / (10**12)

    return model_tflops


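# Illustrative call (assumed GPT-2-small-like shapes; the numbers only demonstrate the
# expected arguments and are not tied to any model trained with this code).
def _example_decoder_tflops_estimate():
    return compute_tflops_per_batch_per_gpu(
        num_layers=12,
        batch_size=8,
        q_seq_len=1024,
        k_seq_len=1024,
        hidden_size=768,
        kv_in_dim=768,
        ff_exp_factor=4,
        vocab_size=50257,
        count_backward=True,
        use_grad_checkpointing=False,
    )

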
def compute_perceiver_tflops_per_batch_per_gpu(
    num_layers,
    batch_size,
    q_seq_len,
    vision_embed_seq_len,
    q_k_v_input_dim,
    attention_hidden_size,
    ff_exp_factor=None,
    count_backward=False,
    use_grad_checkpointing=False,
):
    multiply_add_factor = torch.tensor(2)
    query_transformation = multiply_add_factor * batch_size * q_seq_len * q_k_v_input_dim * attention_hidden_size
    # k_seq_len == v_seq_len
    key_value_transformation = (
        multiply_add_factor * batch_size * vision_embed_seq_len * (2 * attention_hidden_size * q_k_v_input_dim)
    )
    k_seq_len = vision_embed_seq_len + q_seq_len
    attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
    attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
    att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
    post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * attention_hidden_size * q_k_v_input_dim

    # There are usually 2 expansion_linear_layers because the first one expands and the second one projects back to hidden_size
    # When using a classic decoder, some blocks don't have those feed-forward layers
    if ff_exp_factor:
        expansion_linear_layers = 2 * (
            multiply_add_factor * batch_size * q_seq_len * (q_k_v_input_dim * ff_exp_factor) * q_k_v_input_dim
        )
    else:
        expansion_linear_layers = torch.tensor(0)

    transformer_block_flops = (
        query_transformation
        + key_value_transformation
        + attention_matrix_computation
        + attention_softmax
        + att_over_values_computation
        + post_attention_linear_proj
        + expansion_linear_layers
    )

    forward_factor = 1
    backward_factor = 2 if count_backward else 0
    grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
    model_flops = (forward_factor + backward_factor + grad_checkpointing_factor) * (num_layers * transformer_block_flops)
    model_tflops = model_flops / (10**12)

    return model_tflops


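# Illustrative call (assumed perceiver-resampler-like shapes: 64 latent queries attending
# over 257 vision embeddings; purely an example of the expected arguments).
def _example_perceiver_tflops_estimate():
    return compute_perceiver_tflops_per_batch_per_gpu(
        num_layers=6,
        batch_size=8,
        q_seq_len=64,
        vision_embed_seq_len=257,
        q_k_v_input_dim=1024,
        attention_hidden_size=1024,
        ff_exp_factor=4,
        count_backward=True,
        use_grad_checkpointing=False,
    )

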
def mem_usage_formatted(logging_type=LoggingTypes.PRINT):
    # adapted from deepspeed's see_memory_usage

    torch.cuda.empty_cache()

    # python doesn't do real-time garbage collection so do it explicitly to get the correct usage reports
    gc.collect()
    vm_stats = psutil.virtual_memory()

    mem = {
        "gpu mem alloc": f"{torch.cuda.memory_allocated()/2**30:0.2f}GB",
        "max alloc": f"{torch.cuda.max_memory_allocated()/2**30:0.2f}GB",
        "reserv": f"{torch.cuda.memory_reserved()/2**30:0.2f}GB",
        "max reserv": f"{torch.cuda.max_memory_reserved()/2**30:0.2f}GB",
        "cpu vm used": f"{(vm_stats.total-vm_stats.available)/2**30:0.2f}GB {vm_stats.percent}%",
    }

    if logging_type == LoggingTypes.PRINT:
        mem = " | ".join([f"{k}: {v}" for k, v in mem.items()]) + " | "

    # get the peak memory to report correct data, so reset the max_memory_allocated counter for the next call
    torch.cuda.reset_peak_memory_stats()

    return mem


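# Usage sketch (the `step` variable is an assumed counter from the caller's training loop):
# with LoggingTypes.PRINT the result is a single pre-formatted string, with the other
# logging types it is the raw dict.
def _example_log_memory(step):
    logger.info(f"[step {step}] {mem_usage_formatted(LoggingTypes.PRINT)}")

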
def is_deepspeed_used():
    deepspeed_plugin = get_deepspeed_plugin()
    return deepspeed_plugin is not None


def get_deepspeed_stage():
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is None:
        return 0
    ds_config = deepspeed_plugin.deepspeed_config
    stage = ds_config.get("zero_optimization", {}).get("stage", 0)
    # from accelerate>=0.17.1 can do instead:
    # stage = deepspeed_plugin.zero_stage
    return stage


def is_deepspeed_zero3_used():
    return get_deepspeed_stage() == 3


def accelerate_torch_dtype():
    """
    Derive and return the `torch_dtype` to be used in `from_pretrained`: from the Deepspeed config if
    Deepspeed is used, otherwise from the accelerator state.
    """
    if not is_accelerate_initialized():
        return None

    accelerator_state = AcceleratorState()

    if is_deepspeed_used():
        deepspeed_plugin = accelerator_state.deepspeed_plugin
        ds_config = deepspeed_plugin.deepspeed_config
        if ds_config.get("fp16", {}).get("enabled", False):
            torch_dtype = torch.float16
        elif ds_config.get("bf16", {}).get("enabled", False):
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = None
    else:  # no Deepspeed
        if accelerator_state.mixed_precision == "fp16":
            torch_dtype = torch.float16
        elif accelerator_state.mixed_precision == "bf16":
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = None

    return torch_dtype


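# Usage sketch ("gpt2" is a stand-in model name): passing the derived dtype lets
# `from_pretrained` allocate the weights directly in the precision implied by the Deepspeed
# config or the accelerator's mixed-precision setting.
def _example_load_model_in_training_dtype():
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=accelerate_torch_dtype())

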
def is_accelerate_initialized():
    return accelerate.state.is_initialized()


def get_deepspeed_plugin():
    if is_accelerate_initialized():
        return AcceleratorState().deepspeed_plugin
    else:
        return None


def get_deepspeed_engine(accelerator):
    return accelerator.deepspeed_engine_wrapped.engine


def is_deepspeed_zero_init_enabled():
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is not None:
        return deepspeed_plugin.is_zero3_init_enabled()
    else:
        return False


@contextmanager
def hf_trainer_disable_zero3_init_context_manager():
    # monkey patch hack to emulate a context that has zero_init disabled, as it's used in
    # modeling_utils.py in transformers for from_config and from_pretrained.
    import transformers.modeling_utils  # noqa

    orig = transformers.modeling_utils.is_deepspeed_zero3_enabled
    transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: False
    yield
    transformers.modeling_utils.is_deepspeed_zero3_enabled = orig


def deepspeed_zero_init_disabled_context_manager():
    """
    Returns a list with a single context manager that disables zero.Init: the Deepspeed plugin's own
    context manager when a Deepspeed plugin is available, otherwise a monkey-patched fallback.
    """
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is not None:
        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
    else:
        return [hf_trainer_disable_zero3_init_context_manager()]


def deepspeed_gathered_parameters_context_manager(params, modify=True):
    """
    Under zero.Init, returns a list with a context manager that gathers the sharded param(s); otherwise
    returns an empty list.

    If `modify` is `True`, the shards are gathered and, once the context exits, updated with the
    modified data - which is what you want when modifying the gathered param. If you only need to
    read the param and make no modifications, use `modify=False`, as it's more efficient.

    `params` - can be a single parameter, a list, or a tuple of parameters to collect.

    Example:

    from transformers.utils import ContextManagers
    from m4.training.utils import deepspeed_gathered_parameters_context_manager

    with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    """
    if is_deepspeed_zero_init_enabled():
        import deepspeed

        # 0 is for updating `params` shards after modifying it, `None` is for read-only (only gather)
        modifier_rank = 0 if modify else None
        return [deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank)]
    else:
        return []


# adapted from https://github.com/huggingface/transformers/blob/a081f292ca8479eaf66d7396186021268f128829/src/transformers/modeling_utils.py#L438-L496
# as it appears to be a private function
def load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero_init_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs


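# Usage sketch ("checkpoint.bin" is a hypothetical path; `start_prefix` stays empty when the
# state_dict keys already match the model's parameter names).
def _example_load_checkpoint(model, checkpoint_path="checkpoint.bin"):
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return load_state_dict_into_model(model, state_dict, start_prefix="")

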
def get_stats(var, ctx):
    if var is None:
        return {}
    var = var.float()
    abs_var = var.abs()
    return {
        f"{ctx}_var_min": var.min().item(),
        f"{ctx}_var_max": var.max().item(),
        f"{ctx}_var_mean": var.mean().item(),
        f"{ctx}_var_std": var.std().item(),
        f"{ctx}_abs_var_min": abs_var.min().item(),
        f"{ctx}_abs_var_max": abs_var.max().item(),
        f"{ctx}_abs_var_mean": abs_var.mean().item(),
        f"{ctx}_abs_var_std": abs_var.std().item(),
        f"{ctx}_var_norm_2": (var.norm(p=2) / var.numel()).item(),
        f"{ctx}_var_norm_1": (var.norm(p=1) / var.numel()).item(),
        f"{ctx}_nonzero": (var != 0).sum().item(),
    }


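# Usage sketch (an illustrative helper): collect per-parameter gradient statistics, e.g.
# right after `loss.backward()`, for JSONL/wandb/print logging. `get_stats` returns an
# empty dict for parameters without gradients.
def _example_collect_grad_stats(model):
    stats = {}
    for name, param in model.named_parameters():
        stats.update(get_stats(param.grad, ctx=f"grad/{name}"))
    return stats

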
def get_stats_format(ctx):
    return {
        f"{ctx}_var_min": "e",
        f"{ctx}_var_max": "e",
        f"{ctx}_var_mean": "e",
        f"{ctx}_var_std": "e",
        f"{ctx}_abs_var_min": "e",
        f"{ctx}_abs_var_max": "e",
        f"{ctx}_abs_var_mean": "e",
        f"{ctx}_abs_var_std": "e",
        f"{ctx}_var_norm_2": "e",
        f"{ctx}_var_norm_1": "e",
        f"{ctx}_nonzero": "",
    }