"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class MixinLayer(nn.Module):
    """
    MixinLayer is a wrapper around a GatedCrossAttentionBlock and a decoder
    layer: the gated cross-attention runs immediately before the wrapped
    decoder layer.
    """

    def __init__(
        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.use_cached_media = False
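        # Propagate the gradient-checkpointing flag to both sub-modules; the
        # gated cross-attention block is None for layers without one.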
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

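    # Stash the media tensor and the text position ids on the gated
    # cross-attention block so they are available to it during its forward pass.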
    def condition_media(self, media, text_position_ids):
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer.media = media
            self.gated_cross_attn_layer.cross_attn.text_position_ids = (
                text_position_ids
            )

    def condition_use_cached_media(self, use_cached_media):
        self.use_cached_media = use_cached_media

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
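        # Gated cross-attention runs first: the text hidden states attend to
        # the conditioned visual features before the original decoder layer.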
        if self.gated_cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")

            hidden_states = self.gated_cross_attn_layer(
                hidden_states,
                self.vis_x,
                use_cached_media=self.use_cached_media,
            )

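        # Then run the wrapped decoder layer unchanged.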
        hidden_states = self.decoder_layer(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        return hidden_states


class LMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
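
    Example (illustrative sketch; the base LM class, the attribute path, and
    the config fields are assumptions, not defined in this module):

        class MyMultimodalLM(LMMixin, MyDecoderLM):  # MyDecoderLM is hypothetical
            pass

        lm = MyMultimodalLM(lm_config)
        lm.set_decoder_layers_attr_name("model.layers")  # path to the decoder blocks
        lm.init_mixin(config, gradient_checkpointing=False)  # config.mixin_every_n_layers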
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_mixin(
        self,
        config,
        gradient_checkpointing,
    ):
        """
        Initialize the mixin by adding new gated cross-attention blocks to the decoder.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        mixin_every_n_layers = config.mixin_every_n_layers
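        # Place a GatedCrossAttentionBlock before every `mixin_every_n_layers`-th
        # decoder layer; every other position gets None (no cross-attention there).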
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                GatedCrossAttentionBlock(config)
                if (layer_idx + 1) % mixin_every_n_layers == 0
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )

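        # Wrap each (cross-attn block, decoder layer) pair in a MixinLayer and
        # install the wrappers as the model's decoder layers. The temporary
        # lists can then be dropped: the MixinLayers hold the live references.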
        self.init_mixin_layers(gradient_checkpointing)
        self.old_decoder_blocks = None
        self.gated_cross_attn_layers = None
        self.initialized_mixin = True
        self._use_cached_vision_x = False

    def init_mixin_layers(self, gradient_checkpointing):
        """
        Re-initializes the MixinLayers.
        Propagates any changes made to self.gated_cross_attn_layers or
        self.old_decoder_blocks.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    MixinLayer(
                        gated_cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for gated_cross_attn_layer, decoder_layer in zip(
                        self.gated_cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    def forward(self, position_ids=None, **kwargs):
        if not getattr(self, "initialized_mixin", False):
            raise ValueError(
                "Mixin layers are not initialized. Please call `init_mixin` first."
            )

        kwargs["position_ids"] = position_ids
        return super().forward(**kwargs)

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
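        # Largely mirrors the default `transformers` implementation of this
        # generation hook; the addition here is the `position_ids` update at
        # the end so incremental decoding keeps them in sync with the cache.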
model_kwargs["past_key_values"] = self._extract_past_from_model_output( |
|
outputs, standardize_cache_format=standardize_cache_format |
|
) |
|
if getattr(outputs, "state", None) is not None: |
|
model_kwargs["state"] = outputs.state |
|
|
|
|
|
if "token_type_ids" in model_kwargs: |
|
token_type_ids = model_kwargs["token_type_ids"] |
|
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) |
|
|
|
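        # Extend the (decoder) attention mask by one position for the token
        # generated in this step.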
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                    dim=-1,
                )
        else:
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [
                        decoder_attention_mask,
                        decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)),
                    ],
                    dim=-1,
                )

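        # With a populated cache only the newest token is fed to the model, so
        # keep just the last position id, advanced by one.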
        if model_kwargs["past_key_values"] and "position_ids" in model_kwargs:
            new_pos_ids = model_kwargs["position_ids"][:, -1:] + 1
            model_kwargs["position_ids"] = new_pos_ids

        return model_kwargs

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

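    # Remove all visual conditioning from the decoder layers so the next
    # forward pass must condition them again.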
    def clear_conditioned_layers(self):
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_use_cached_media(False)
            layer.condition_media(None, None)