import torch
from einops import rearrange
from torch import nn
import torch.nn.utils as utils
from torch.distributed.fsdp.wrap import (
    enable_wrap,
    wrap,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)

try:
    from .helpers import TransformerEncoder
    from .utils import apply_with_stopping_condition
except ImportError:
    # fall back to absolute imports when run outside the package
    from helpers import TransformerEncoder
    from utils import apply_with_stopping_condition


class Flamingo(nn.Module):
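    """Flamingo-style audio-language model: a CLAP audio encoder feeds an
    audio transformer whose outputs condition a causal language model through
    gated cross-attention layers.
    """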

    def __init__(
        self,
        clap,
        unfreeze_clap,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        sep_token_id: int,
        clap_embed_dim: int,
        audio_transformer_kwargs: dict,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
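        """
        Args:
            clap: CLAP audio encoder module
            unfreeze_clap (bool): whether the CLAP weights receive gradients
            lang_encoder (nn.Module): causal language model with Flamingo hooks
            eoc_token_id (int): token id of the end-of-chunk token
            media_token_id (int): token id of the audio placeholder token
            sep_token_id (int): token id of the separator token
            clap_embed_dim (int): dimension of the CLAP embeddings
            audio_transformer_kwargs (dict): hyperparameters for the audio
                transformer (n_head, n_layers, d_inner, max_num_media,
                max_window_per_audio)
            cross_attn_every_n_layers (int): how often to insert gated
                cross-attention layers into the language model
            gradient_checkpointing (bool): whether to use gradient checkpointing
        """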
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.sep_token_id = sep_token_id

        # initialize embedding dimensions
        self.clap_embed_dim = clap_embed_dim
        # initialize CLAP; freeze it unless unfreeze_clap is set
        self.clap = clap
        self.unfreeze_clap = unfreeze_clap
        self.clap.requires_grad_(unfreeze_clap)
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        # define just one type of audio transformer
        n_head = audio_transformer_kwargs["n_head"]
        n_layers = audio_transformer_kwargs["n_layers"]
        d_inner = audio_transformer_kwargs["d_inner"]
        max_num_media = audio_transformer_kwargs["max_num_media"]
        max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
        common_encoder_embed_dim = clap_embed_dim

        # define the transformers
        assert common_encoder_embed_dim % n_head == 0
        self.audio_transformer_clap = TransformerEncoder(
            d_word_vec=common_encoder_embed_dim,
            n_layers=n_layers,
            n_head=n_head,
            d_k=common_encoder_embed_dim // n_head,
            d_v=common_encoder_embed_dim // n_head,
            d_model=common_encoder_embed_dim,
            d_inner=d_inner,
            dropout=0.0,
            n_position=max_num_media,
            scale_emb=True,
        )
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            audio_hidden_size=common_encoder_embed_dim,
            max_window_per_audio=max_window_per_audio,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        # enable gradient checkpointing for the audio transformer
        self.audio_transformer_clap._use_gradient_checkpointing = gradient_checkpointing
        # enable gradient checkpointing for the CLAP encoder
        self.clap._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ) -> CausalLMOutputWithPast:
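        """Forward pass through the model.

        Encodes audio with CLAP (unless media was precached via cache_media()),
        conditions the gated cross-attention layers on the resulting
        embeddings, then runs the language model.

        Args:
            audio_x (torch.Tensor): audio of shape (B, num_window, window_length)
            audio_x_mask (torch.Tensor): mask over the audio windows, (B, num_window)
            lang_x (torch.Tensor): language input ids
            attention_mask (torch.Tensor, optional): attention mask for lang_x
            labels (torch.Tensor, optional): labels for the language-modeling loss
            clear_conditioned_layers (bool): clear the conditioned
                cross-attention layers after the forward pass
            past_key_values: cached key/values for fast autoregressive decoding
            use_cache (bool): whether to return past_key_values
        """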
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
        assert (
            self.lang_encoder._use_cached_audio_x or audio_x is not None
        ), "Must provide either audio_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_audio_x:
            assert (
                audio_x is None
            ), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()
        else:
            self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
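        """Generate text conditioned on audio and a language prompt.

        Accepts the usual HuggingFace generation kwargs (e.g. max_new_tokens,
        num_beams, temperature); eos_token_id defaults to the end-of-chunk
        token.
        """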
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            # repeat the audio for each beam so batch sizes stay aligned
            audio_x = audio_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_audio_x = True
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)

        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
        return output

    def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """
        Encode audio with CLAP, run the embeddings through the audio
        transformer, and condition the language model's cross-attention layers.

        Rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"
        #------------------------------------------------------------------------#
        # get embeddings from CLAP
        audio_embeds = self.clap(audio_x)
        B, L, H, D = audio_embeds.shape  # L windows, H embeddings per window, D feature dim
        assert D == self.clap_embed_dim
        audio_x_out = rearrange(audio_embeds, 'b l h d -> b (l h) d')

        # expand the window mask to cover all H embeddings per window
        expanded_speech_mask = audio_x_mask.repeat_interleave(H, dim=1)  # B, (L*H)
        # broadcast a single shared mask across the batch if needed
        if B > 1 and expanded_speech_mask.shape[0] == 1:
            expanded_speech_mask = expanded_speech_mask.repeat(B, 1)
        assert expanded_speech_mask.shape[0] == B and expanded_speech_mask.shape[1] == L * H, \
            "{} != ({},{})".format(expanded_speech_mask.shape, B, L * H)
        #------------------------------------------------------------------------#
        audio_x_out = self.audio_transformer_clap(audio_x_out, causal_mask=expanded_speech_mask)  # B, L*H, D

        # unsqueeze to match the shapes the Flamingo layers expect
        audio_x_out = audio_x_out.unsqueeze(2)  # B, L*H, n=1, D
        audio_x_mask = expanded_speech_mask.unsqueeze(2)  # B, L*H, n=1
        #------------------------------------------------------------------------#

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_audio_x(audio_x_out, audio_x_mask)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
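        """Wrap the trainable submodules in FSDP.

        The language model's original decoder blocks are unfrozen and wrapped
        but later excluded from the optimizer; CLAP is left unwrapped and
        replicated on each GPU.
        """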
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.audio_transformer_clap = wrap(wrap(self.audio_transformer_clap))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers_sound = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers_sound
            )
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )
            if hasattr(self.lang_encoder, 'set_output_embeddings'):
                self.lang_encoder.set_output_embeddings(
                    wrap(wrap(self.lang_encoder.get_output_embeddings()))
                )
            else:
                print('skip wrapping output embeddings')

        # manually move non-FSDP-managed parameters to device_id;
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # CLAP shouldn't be wrapped; it should live on each GPU
        if self.unfreeze_clap:
            apply_with_stopping_condition(
                module=self.clap,
                apply_fn=lambda m: m.to(device_id),
                apply_condition=lambda m: len(list(m.children())) == 0,
                stopping_condition=lambda m: isinstance(m, FSDP),
            )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            utils.clip_grad_norm_(self.clap.nvclap.parameters(), max_norm=max_norm)
            self.audio_transformer_clap.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers_sound:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
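        """Mark the positions of media (audio placeholder) tokens in input_ids
        so the cross-attention layers know which text attends to which audio.
        """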
        media_locations = (input_ids == self.media_token_id)
        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
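        """Precompute and cache the audio conditioning so that subsequent
        forward passes can be run with audio_x=None.
        """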
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_audio_x = True

    def uncache_media(self):
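        """Clear the cached audio conditioning set by cache_media()."""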
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
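

# --------------------------------------------------------------------------- #
# Minimal usage sketch (not part of the model code). Everything below is a
# placeholder: `clap_model`, `lang_encoder`, the token ids, and the embedding
# dimension must come from a real CLAP checkpoint and a causal LM whose class
# has been extended with init_flamingo.
#
#   model = Flamingo(
#       clap=clap_model,
#       unfreeze_clap=False,
#       lang_encoder=lang_encoder,
#       eoc_token_id=eoc_token_id,
#       media_token_id=media_token_id,
#       sep_token_id=sep_token_id,
#       clap_embed_dim=2048,                      # placeholder value
#       audio_transformer_kwargs=dict(
#           n_head=8, n_layers=3, d_inner=4096,
#           max_num_media=128, max_window_per_audio=16,
#       ),
#       cross_attn_every_n_layers=1,
#   )
#   out = model(
#       audio_x=audio_x,                          # (B, num_window, window_length)
#       audio_x_mask=audio_x_mask,                # (B, num_window)
#       lang_x=input_ids,
#       attention_mask=attention_mask,
#       labels=labels,
#   )
# --------------------------------------------------------------------------- #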