import os

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ContextManagers

from m4.training.setup_vision_model import vision_model_name_to_model
from m4.training.utils import (
    deepspeed_zero_init_disabled_context_manager,
    is_deepspeed_zero_init_enabled,
    load_state_dict_into_model,
)

# from pathlib import Path


class VLOOMPreTrainedModelBase(PreTrainedModel):
    # The problem we are trying to solve: fetching the vision model via from_pretrained(vision_model_name)
    # triggers one zero.Init, and then overriding it with another from_pretrained(vision_model_name), as the
    # original code did, triggers a second, nested zero.Init - this breaks deepspeed zero3 w/ zero.Init.
    # So one solution is this:
    # a. replace from_pretrained(vision_model_name) with from_config(vision_model_name) while disabling the
    #    zero.Init context
    # b. instead of straight replacement of model.vision_model = from_pretrained(vision_model_name) when it gets
    #    updated, we first do from_pretrained(vision_model_name) and then update the existing model with those
    #    weights, using the already zero.Init'ed, pre-sharded parameters
    #
    # there are a few variations to get_vision_model_from_config - all need to bypass zero.Init under zero3
    # 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the vision model
    # 2. the other variant is to override the _from_config method with our version that doesn't do zero.Init

    @classmethod
    def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
        # 1. fetch the pretrained vision model w/o zero.Init
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model = AutoModel.from_pretrained(
                vision_model_name, **vision_model_params, torch_dtype=torch_dtype
            )

        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)

        # 2. now override the weights already sharded by zero.Init with the weights from real_vision_model,
        # by gradually gathering the sharded weights and replacing them with the new weights
        if is_deepspeed_zero_init_enabled():
            state_dict = real_vision_model.state_dict()
            load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
        else:
            model.vision_model = real_vision_model

    @classmethod
    def from_config(cls, config, **kwargs):
        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        vision_model_name = config.vision_model_name
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
            # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
            kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)

        # 2. create the main class's model, passing the uninitialized vision_model to it
        model = cls(config, **kwargs)

        return model
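
    # Usage sketch for from_config. `MyVLOOMConfig`/`MyVLOOMModel` and the path are hypothetical
    # placeholders, not names from this codebase; the config is assumed to expose vision_model_name
    # and vision_model_params, as the code above requires:
    #
    #   config = MyVLOOMConfig.from_pretrained("path/to/config")
    #   model = MyVLOOMModel.from_config(config, torch_dtype=torch.bfloat16)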

    @classmethod
    def from_pretrained_models(cls, *args, **kwargs):
        """
        Use this method when creating a new vloom model that hasn't been trained yet and will be
        composed of 2 pre-trained models - hence `pretrained_models`.
        """
        return cls.from_pretrained(*args, **kwargs, new_model=True)

    @classmethod
    def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
        """
        Use this method when loading an already pretrained vloom model - either from a checkpoint or from the hub.
        For creating an untrained model use `from_pretrained_models` instead.
        """
        is_untrained_vloom_model = False
        is_pretrained_vloom_model_resumed = False
        is_pretrained_vloom_model_from_hub_or_path = False

        # we have 3 use cases:
        # 1. is_untrained_vloom_model - a totally new vloom model
        # 2. is_pretrained_vloom_model_resumed - a pretrained vloom model being resumed from a
        #    checkpoint (instantiate a random empty model in this case)
        # 3. is_pretrained_vloom_model_from_hub_or_path - a pretrained vloom model loaded from hub or local path
        if new_model:
            is_untrained_vloom_model = True
        elif is_resume:
            is_pretrained_vloom_model_resumed = True
        else:
            is_pretrained_vloom_model_from_hub_or_path = True
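
        # For illustration only, the three use cases above correspond roughly to these calls
        # (`MyVLOOMModel` and the path arguments are hypothetical placeholders):
        #
        #   MyVLOOMModel.from_pretrained_models("lm_name_or_path", config=config)          # 1. untrained vloom model
        #   MyVLOOMModel.from_pretrained("checkpoint_dir", config=config, is_resume=True)  # 2. resumed from a checkpoint
        #   MyVLOOMModel.from_pretrained("hub_id_or_local_path", config=config)            # 3. loaded from hub or local path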

        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        # config is:
        # 1. either not passed, in which case we use the model's default config (used by tests)
        # 2. passed, in which case it's one of:
        #    2a. a `PretrainedConfig` object (a new m4 model)
        #    2b. a path to a json config (an already pretrained m4 model, usually resumed training)
        config = kwargs.get("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
        elif not isinstance(config, PretrainedConfig):
            # adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
            assert isinstance(config, os.PathLike)
            config_path = str(config)
            config = cls.config_class.from_pretrained(
                config_path,
                return_unused_kwargs=False,
                **kwargs,
            )

        vision_model_name = config.vision_model_name
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
            # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
            kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)

        # 2. create the vloom model
        if is_untrained_vloom_model or is_pretrained_vloom_model_from_hub_or_path:
            model = super().from_pretrained(*model_args, **kwargs)
        elif is_pretrained_vloom_model_resumed:
            # in the case of resume under deepspeed we create an empty model, and get deepspeed
            # to load the weights from the checkpoint
            # but not all models have these keys so handle the case they don't have them
            _ = kwargs.pop("config", None)
            model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)

        # 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one with pretrained weights
        if is_untrained_vloom_model:
            cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)

        return model


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.

    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `num_additional_embeddings > 0`, then it will create `num_additional_embeddings` additional parameters
    that are always trained.

    If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_embedding.weight` is never frozen.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding`, such as `max_norm`
        or `norm_type`. We are not supporting these.
        """
        if padding_idx is not None and padding_idx > num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        We have 2 embeddings with different indices - one pretrained (`self.weight`) and another
        (`self.additional_embedding.weight`) that is being trained.

        In order to look up the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings),
           since the 2nd embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding: we overwrite indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        Note: for the 1st embedding lookup we could have looked up only the low indices and skipped
        the padding, but then we would have to create a new tensor and populate it with 2 tensors that are
        spread out across various indices - i.e. not a simple concat. I haven't benchmarked whether the
        more complex approach is any faster; given that seqlens are usually relatively short it's probably
        not faster, or if faster, not by much - but it might be a good idea to measure.
        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.embedding_dim,
            self.partially_freeze,
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        raise NotImplementedError


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.

    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `out_additional_features > 0`, then it will create `out_additional_features * in_features` additional
    parameters that are always trained.

    If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze
        self.in_features = in_features
        self.out_features = out_features

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            additional_features = F.linear(input, self.additional_fc.weight, self.additional_fc.bias)
            output = torch.cat((output, additional_features), -1)

        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            self.bias is not None,
            self.partially_freeze,
        )
| if __name__ == "__main__": | |
| emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True) | |
| for n, p in emb.named_parameters(): | |
| print(n, p.requires_grad) | |
| idx = torch.tensor([[11, 1, 3]]) | |
| y = emb(idx) | |
| loss = y.sum() | |
| loss.backward() | |
| print(emb.weight, emb.weight.grad) | |
| print(emb.additional_embedding, emb.additional_embedding.grad) | |
| lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True) | |
| for n, p in lin.named_parameters(): | |
| print(n, p.requires_grad) | |
| x = torch.randn(12, 3) | |
| y = lin(x) | |
| loss = y.sum() | |
| loss.backward() | |
| print("Weight w and grad:", lin.weight, lin.weight.grad) | |
| print("bias w and grad:", lin.bias, lin.bias.grad) | |
| print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad) | |
| print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad) | |