| """ | |
| patch to add noisy embeddings per https://arxiv.org/abs/2310.05914 | |
| """ | |

import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging

logger = logging.get_logger(__name__)


def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
    # pylint: disable=duplicate-code
    def noised_embed(orig_embed, noise_alpha, model):
        def new_func(input_ids):
            # during training, we add noise to the embedding
            # during generation, we don't add noise to the embedding
            if model.training:
                embed_init = orig_embed(input_ids)
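                # NEFTune scales uniform noise in [-1, 1] by alpha / sqrt(L * d),
                # where L is the sequence length and d the embedding dimension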
                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
                mag_norm = noise_alpha / torch.sqrt(dims)
                return embed_init + torch.zeros_like(embed_init).uniform_(
                    -mag_norm, mag_norm
                )
            return orig_embed(input_ids)

        return new_func
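
    # wrap LlamaModel.post_init so that, once the model's weights have been
    # initialized, the embedding layer's forward is swapped for the noised one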
    def post_init(orig_post_init):
        def new_func(self):
            orig_post_init(self)
            self.embed_tokens.forward = noised_embed(
                self.embed_tokens.forward, noise_alpha, self
            )

        return new_func

    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
        transformers.models.llama.modeling_llama.LlamaModel.post_init
    )
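

# Minimal usage sketch (not part of the patch itself): since the patch rebinds
# LlamaModel.post_init on the class, it must be applied *before* the model is
# instantiated; the checkpoint name below is only a placeholder.
#
#     from transformers import AutoModelForCausalLM
#
#     replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     model.train()  # noise is added only while model.training is True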