# CoMemo-9B / mixin_lm.py
# CLLBJ16's picture
# Upload folder using huggingface_hub
# a12c31a verified
"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
import torch.nn as nn
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.utils import ModelOutput
import torch
class MixinLayer(nn.Module):
    """
    Wrap a decoder layer together with an optional gated cross-attention block.

    When a cross-attention block is attached and visual features have been
    set through ``condition_vis_x``, ``forward`` first runs the hidden states
    through the cross-attention block and then through the wrapped decoder
    layer. Otherwise only the decoder layer runs.
    """

    def __init__(
        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        """
        Args:
            gated_cross_attn_layer: Cross-attention module to run before the
                decoder layer, or ``None`` for a plain decoder layer.
            decoder_layer: The original decoder layer being wrapped.
            gradient_checkpointing: Value stored as
                ``_use_gradient_checkpointing`` on the wrapped modules.
        """
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        # Give the flag a default so that forward() does not fail with an
        # AttributeError when vis_x was conditioned but
        # condition_use_cached_media() was never called.
        self.use_cached_media = False
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        """Store the visual features used by the cross-attention block (``None`` clears them)."""
        self.vis_x = vis_x

    def condition_media(self, media, text_position_ids):
        """
        Forward media-related state to the cross-attention block.

        Sets ``media`` on the block and ``text_position_ids`` on its
        ``cross_attn`` submodule. Does nothing when this layer has no
        cross-attention block.
        """
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer.media = media
            self.gated_cross_attn_layer.cross_attn.text_position_ids = text_position_ids

    def condition_use_cached_media(self, use_cached_media):
        """Store the ``use_cached_media`` flag passed to the cross-attention block."""
        self.use_cached_media = use_cached_media

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        """
        Run the optional cross-attention block, then the wrapped decoder layer.

        The cross-attention block is skipped when it is absent or when no
        visual features are conditioned, so text-only passes go straight to
        the decoder layer. All remaining arguments are handed to the decoder
        layer as keywords.

        Returns:
            Whatever the wrapped decoder layer returns, unmodified.
        """
        # Cross attention
        if self.gated_cross_attn_layer is not None and self.vis_x is not None:
            hidden_states = self.gated_cross_attn_layer(
                hidden_states,
                self.vis_x,
                use_cached_media=self.use_cached_media,
            )

        # Normal decoder layer
        return self.decoder_layer(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
class LMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    Meant to be combined with a language-model class through multiple
    inheritance: ``forward`` delegates to the next class in the MRO, and the
    decoder layer list is located through the dotted attribute name set with
    ``set_decoder_layers_attr_name``.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Record the attribute path under which the decoder layers are stored."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder layer container found at ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layer container found at ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_mixin(
        self,
        config,
        gradient_checkpointing,
    ):
        """
        Initialize Mixin by adding a new gated cross attn to the decoder.

        A ``GatedCrossAttentionBlock`` is created for every
        ``config.mixin_every_n_layers``-th decoder layer (1-based); the other
        layers get ``None``. Each decoder layer is then wrapped in a
        ``MixinLayer``.

        Args:
            config: Configuration object; ``mixin_every_n_layers`` is read
                here and the whole object is passed to
                ``GatedCrossAttentionBlock``.
            gradient_checkpointing: Forwarded to every ``MixinLayer``.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        mixin_every_n_layers = config.mixin_every_n_layers
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                GatedCrossAttentionBlock(config)
                if (layer_idx + 1) % mixin_every_n_layers == 0
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        self.init_mixin_layers(gradient_checkpointing)
        # The wrapped MixinLayers now hold both module lists; drop the
        # temporary references kept on the model itself.
        self.old_decoder_blocks = None
        self.gated_cross_attn_layers = None
        self.initialized_mixin = True
        self._use_cached_vision_x = False

    def init_mixin_layers(self, gradient_checkpointing):
        """
        Rebuild the decoder layer list as ``MixinLayer`` wrappers.

        Pairs each entry of ``self.gated_cross_attn_layers`` with the matching
        entry of ``self.old_decoder_blocks``. Both attributes must be populated
        when this is called; ``init_mixin`` resets them to ``None`` once it has
        finished.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    MixinLayer(
                        gated_cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for gated_cross_attn_layer, decoder_layer in zip(
                        self.gated_cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    def forward(self, position_ids=None, **kwargs):
        """
        Forward pass that passes ``position_ids`` through to the parent model.

        Raises:
            ValueError: If ``init_mixin`` has not been called yet.
        """
        # The flag only exists after init_mixin(); default to False so an
        # uninitialized model raises the intended ValueError rather than an
        # AttributeError from nn.Module.__getattr__.
        if not getattr(self, "initialized_mixin", False):
            raise ValueError(
                "Mixin layers are not initialized. Please call `init_mixin` first."
            )
        kwargs["position_ids"] = position_ids
        return super().forward(**kwargs)  # Call the other parent's forward method

    def _update_model_kwargs_for_generation(
        self,
        outputs: "ModelOutput",
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """
        Update ``model_kwargs`` in place for the next generation step.

        Refreshes the cache, extends ``token_type_ids`` and the relevant
        attention mask by one position, and advances ``position_ids``.

        NOTE(review): relies on ``self._extract_past_from_model_output``, which
        is not defined in this class -- presumably supplied by the language
        model class this mixin is combined with; confirm against the installed
        transformers version.

        Returns:
            The same ``model_kwargs`` dict, updated.
        """
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat(
                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
            )

        if not is_encoder_decoder:
            # update attention mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                    dim=-1,
                )
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [
                        decoder_attention_mask,
                        decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)),
                    ],
                    dim=-1,
                )

        # To support RoPE-DHR's position_ids calculation method:
        # once a cache exists, the next step only needs the last position + 1.
        if model_kwargs['past_key_values'] and 'position_ids' in model_kwargs:
            new_pos_ids = model_kwargs['position_ids'][:, -1:] + 1
            model_kwargs['position_ids'] = new_pos_ids
        return model_kwargs

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset visual features, the cached-media flag and media state on every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_use_cached_media(False)
            layer.condition_media(None, None)