# midashenglm-7b / modeling_midashenglm.py
# Source: Hugging Face Hub file page — uploaded by zhoukz ("Upload folder using
# huggingface_hub"), revision 0939826, 20.3 kB.
import collections
import collections.abc
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union, cast
import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
from torch import Tensor
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
Qwen2_5OmniTextConfig,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
Qwen2_5OmniThinkerTextModel,
)
from .configuration_midashenglm import DashengConfig, MiDashengLMConfig
_Tuple2 = Union[int, Tuple[int, int], Sequence[int]]
def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]:
if isinstance(x, collections.abc.Sequence):
assert len(x) == 2, (
f"Expected a sequence of length 2, got {x} with length {len(x)}"
)
return cast(Tuple[int, int], tuple(x))
return (x, x)
class AudioPatchEmbed(nn.Module):
    """Cut a 4-D spectrogram into patches with a strided Conv2d.

    The input layout is ``b c f t`` (per the rearrange note in ``forward``).
    ``grid_size`` is the number of patches along (freq, time) for an input
    of exactly ``input_size``, computed as ``input_size // patch_stride``.

    Args:
        input_size: Nominal (freq, time) size of the input; int means square.
        patch_size: Conv kernel size.
        patch_stride: Conv stride.
        in_chans: Input channels.
        embed_dim: Output channels (embedding size per patch).
        norm_layer: Optional factory called as ``norm_layer(embed_dim)``;
            ``nn.Identity`` is used when falsy.
        flatten: If True, return ``(b, f*t, c)`` instead of ``(b, c, f, t)``.
    """

    def __init__(
        self,
        input_size: _Tuple2 = 64,
        patch_size: _Tuple2 = 16,
        patch_stride: _Tuple2 = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
    ):
        super().__init__()
        self.input_size = _resolve_tuple2(input_size)
        self.patch_size = _resolve_tuple2(patch_size)
        self.patch_stride = _resolve_tuple2(patch_stride)
        self.grid_size = tuple(
            extent // step for extent, step in zip(self.input_size, self.patch_stride)
        )
        freq_patches, time_patches = self.grid_size
        self.num_patches = freq_patches * time_patches
        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
        )
        # NOTE(review): when flatten=False the norm sees (b, c, f, t), so a
        # LayerNorm(embed_dim) would act on the time axis; only the Identity
        # default is exercised in this file — confirm before passing norm_layer.
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        patches = self.proj(x)
        if self.flatten:
            # rearrange(x, "b c f t -> b (f t) c")
            patches = patches.flatten(2, 3).permute(0, 2, 1)
        return self.norm(patches)
class LayerScale(nn.Module):
    """Multiply the last dimension by a learnable per-channel scale ``gamma``.

    Args:
        dim: Size of ``gamma`` (the channel dimension it broadcasts against).
        init_values: Initial value of every ``gamma`` entry.
        inplace: If True, scale ``x`` in place with ``mul_`` and return it.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class DashengMlp(nn.Module):
    """Feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``in_features`` when falsy.
        out_features: Output size; falls back to ``in_features`` when falsy.
        drop: Dropout probability used after both the activation and ``fc2``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class DashengAttention(nn.Module):
    """Multi-head self-attention with optional causal and key-padding masking.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        attn_drop: Dropout on the attention weights.
        proj_drop: Dropout after the output projection.
        causal: If True, each query may only attend to keys at positions
            ``<= `` its own (aligned to the end when query/key lengths differ).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        causal: bool = False,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.causal = causal

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` of shape (B, N, C).

        Args:
            x: Input of shape (B, N, C).
            mask: Optional boolean key-padding mask broadcastable to (B, N)
                (e.g. (B, N) or (1, N)); True marks keys to ignore.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        if self.causal:
            i, j = attn.shape[-2:]
            # Separate name so the caller's padding `mask` is not overwritten.
            causal_mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(
                j - i + 1
            )
            attn = attn.masked_fill(causal_mask, -torch.finfo(attn.dtype).max)
        if mask is not None:
            # Broadcast the per-key padding mask to [B, 1, Target_len, Source_len].
            key_padding = mask[:, None, None, :].expand(B, 1, N, N)
            # A finite minimum (not -inf) is used, so a row whose keys are all
            # masked softmaxes to uniform weights instead of NaN.
            attn = attn.masked_fill(key_padding, torch.finfo(attn.dtype).min)
        attn = attn.softmax(dim=-1)
        # Safety net against NaNs reaching the value product.
        attn = torch.nan_to_num(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class DashengBlock(nn.Module):
    """Pre-norm transformer block: x + ls1(attn(norm1(x))), then x + ls2(mlp(norm2(x))).

    Args:
        dim: Model width.
        num_heads: Attention heads.
        mlp_ratio: Hidden size of the MLP relative to ``dim``.
        qkv_bias: Bias on the fused q/k/v projection.
        drop: Dropout used for the attention output projection and the MLP.
        attn_drop: Dropout on attention weights.
        init_values: If truthy, wrap each branch in a ``LayerScale`` with this
            initial value; otherwise the branches are unscaled.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
    ):
        super().__init__()

        def make_scale():
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = DashengAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = make_scale()
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = DashengMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )
        self.ls2 = make_scale()

    def forward(self, x, **kwargs):
        # kwargs (typically `mask`) are forwarded unchanged to the attention layer.
        attended = self.attn(self.norm1(x), **kwargs)
        x = x + self.ls1(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.ls2(transformed)
class DashengAudioTransformer(PreTrainedModel):
    """Dasheng audio encoder: mel front end, patch embedding and transformer blocks.

    Long inputs are split along time into chunks of ``target_length`` frames
    (expressed in patches). Each chunk is encoded independently and the chunk
    outputs are concatenated.
    """

    config_class = DashengConfig

    def __init__(self, config: DashengConfig):
        super().__init__(config)
        self.target_length = config.target_length
        self.embed_dim = config.embed_dim
        self.hop_length = config.hop_length
        self.front_end = nn.Sequential(
            audio_transforms.MelSpectrogram(
                f_min=config.f_min,
                f_max=config.f_max,
                center=config.center,
                win_length=config.win_length,
                hop_length=config.hop_length,
                sample_rate=config.sample_rate,
                n_fft=config.n_fft,
                n_mels=config.n_mels,
            ),
            audio_transforms.AmplitudeToDB(top_db=120),
        )
        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
        self.patch_embed = AudioPatchEmbed(
            input_size=(config.n_mels, config.target_length),
            embed_dim=config.embed_dim,
            in_chans=config.input_channels,
            patch_size=config.patch_size,
            flatten=False,
            patch_stride=config.patch_stride,
        )
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
        )
        self.pos_drop = nn.Dropout(p=config.drop_rate)
        self.blocks = nn.ModuleList(
            DashengBlock(
                dim=config.embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
            )
            for _ in range(config.depth)
        )
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.post_init()

    def forward_features(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Add positional embeddings to one ``b c f t`` chunk and run the blocks.

        ``kwargs`` (e.g. ``mask``) are forwarded to every block. Returns a tensor
        of shape ``(b, f*t, c)``.
        """
        t = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :t]
        x = (
            x + self.freq_pos_embed[:, :, :, :]
        )  # Just to support __getitem__ in posembed
        x = torch.permute(
            torch.flatten(x, 2, 3), (0, 2, 1)
        )  # rearrange(x, "b c f t -> b (f t) c")
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x, **kwargs)
        x = self.norm(x)
        return x

    def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
        """Return a (len(lengths), max_length) bool mask, True where index < length."""
        positions = torch.arange(max_length, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(-1)

    def forward(
        self,
        x: torch.Tensor,
        x_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode audio.

        Args:
            x: Input to the mel front end — presumably (batch, samples)
                waveforms; TODO confirm.
            x_length: Optional per-example lengths in the same units as the
                last axis of ``x`` (scaled by ``hop_length`` and the time patch
                stride). A tensor or any sequence accepted by ``torch.as_tensor``;
                it is moved to ``x``'s device.

        Returns:
            ``(features, mask)``: features of shape (batch, patches, embed_dim),
            and a (batch, time_patches) bool mask that is True for valid patches,
            or None when ``x_length`` is None.
            NOTE(review): mask indexes time patches only; it lines up with
            ``features`` only when the frequency grid has a single row —
            confirm for the shipped config.
        """
        x = self.front_end(x)
        # Time stride of one patch (in spectrogram frames); previously hard-coded as 4.
        time_stride = self.patch_embed.patch_stride[1]
        target_length_in_patches = self.target_length // time_stride
        x = x.unsqueeze(1)
        # BatchNorm2d is over mel bins, so mel is moved to the channel axis and back.
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.init_bn(x)
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.patch_embed(x)
        t = x.shape[-1]
        input_splits = x.split(target_length_in_patches, dim=-1)
        if x_length is not None:
            x_length = torch.as_tensor(x_length, device=x.device)
            assert x_length.ndim == 1, "Lengths are of size (B,)"
            assert len(x_length) == len(x), (
                "batchsizes of input x and x_length need to be same"
            )
            samples_per_patch = self.hop_length * time_stride
            # Exact integer floor division instead of float division + truncation.
            scaled_lengths = torch.div(
                x_length, samples_per_patch, rounding_mode="floor"
            ).long()
            mask = self._to_mask(max_length=t, lengths=scaled_lengths)
            # Attention expects True = ignore, hence the inversion.
            split_masks = mask.logical_not().split(target_length_in_patches, dim=-1)
        else:
            mask = None
            split_masks = [None] * len(input_splits)
        outputs = [
            self.forward_features(split_x, mask=split_mask)
            for split_x, split_mask in zip(input_splits, split_masks)
        ]
        x = torch.cat(outputs, dim=1)
        return x, mask
class AudioProjectorSubsample(nn.Module):
    """Project encoder frames to the text width while subsampling in time.

    Every ``downsample_rate`` consecutive frames are concatenated along the
    feature axis and fed through a two-layer MLP. This shortens the sequence
    by that factor.

    Args:
        in_dim: Feature size of each input frame.
        out_dim: Feature size of each output step.
        downsample_rate: Number of consecutive frames merged into one step.
    """

    def __init__(self, in_dim: int, out_dim: int, downsample_rate=5):
        super().__init__()
        self.k = downsample_rate
        self.net = nn.Sequential(
            nn.Linear(in_dim * self.k, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x, mask=None):
        """Subsample and project.

        Args:
            x: (batch, seq_len, in_dim) features.
            mask: Optional (batch, seq_len) per-frame mask; nonzero/True marks a
                valid frame. Defaults to all-valid.

        Returns:
            ``(projected, mask)``. ``projected`` is (batch, seq_len // k, out_dim).
            ``mask`` is a (batch, seq_len // k) long tensor that is 1 where any
            frame of the group was valid. Trailing frames that do not fill a
            whole group are dropped.
        """
        batch_size, seq_len, dim = x.shape
        num_groups = seq_len // self.k
        kept = num_groups * self.k
        x = x[:, :kept, :]
        if mask is None:
            mask = torch.ones((batch_size, kept), dtype=torch.long, device=x.device)
        else:
            mask = mask[:, :kept]
        # Explicit group count (instead of -1): inputs shorter than k yield an
        # empty sequence instead of failing to reshape a zero-element tensor.
        x = x.reshape(
            batch_size, num_groups, self.k * dim
        )  # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
        x = self.net(x)
        mask = mask.reshape(
            batch_size, num_groups, self.k
        )  # rearrange(mask, "b (s k) -> b s k", k=self.k)
        mask = mask.any(dim=-1).long()
        return x, mask
@dataclass
class Qwen25OmniTextModelOutput(ModelOutput):
    """Output of ``Qwen25OmniThinkerTextOnlyDecoder.forward`` when ``return_dict`` is truthy.

    Attributes:
        logits: Language-model head output over the vocabulary.
        past_key_values: Cache returned by the backbone text model, if any.
        hidden_states: Per-layer hidden states from the backbone, if requested.
        attentions: Per-layer attention weights from the backbone, if requested.
    """

    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
    """Qwen2.5-Omni thinker text model plus a linear LM head, usable with ``generate``.

    The attention/cache capability flags are mirrored from
    ``Qwen2_5OmniThinkerTextModel`` so the wrapper advertises the same support.
    """

    config_class = Qwen2_5OmniTextConfig
    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache

    def __init__(self, config: Qwen2_5OmniTextConfig):
        super().__init__(config)
        self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.post_init()

    def forward(
        self,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[Tuple, Qwen25OmniTextModelOutput]:
        """Run the backbone and the LM head.

        If ``position_ids`` is missing but ``attention_mask`` is given, positions
        are the running count of attended tokens minus one, and padded slots
        get 0. Remaining ``kwargs`` go straight to the backbone.

        NOTE(review): any falsy ``return_dict`` — including the default None —
        returns a tuple of the non-None items (logits, last_hidden_state,
        past_key_values, hidden_states, attentions). This is not the
        dataclass field order and ignores ``config.use_return_dict``; confirm
        whether that is intended.
        """
        if position_ids is None and attention_mask is not None:
            is_padding = attention_mask == 0
            running_count = attention_mask.long().cumsum(dim=-1)
            position_ids = running_count.masked_fill_(is_padding, 1) - 1

        backbone_out: BaseModelOutputWithPast = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
            **kwargs,
        )
        logits = self.lm_head(backbone_out.last_hidden_state)

        if return_dict:
            return Qwen25OmniTextModelOutput(
                logits=logits,
                past_key_values=backbone_out.past_key_values,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        candidates = (
            logits,
            backbone_out.last_hidden_state,
            backbone_out.past_key_values,
            backbone_out.hidden_states,
            backbone_out.attentions,
        )
        return tuple(item for item in candidates if item is not None)
class MiDashengLMModel(PreTrainedModel):
    """Audio-conditioned language model.

    Audio goes through the Dasheng encoder and ``AudioProjectorSubsample``.
    The resulting features overwrite the token embeddings at positions holding
    ``audio_token_id``, and the embeddings are then fed to the text decoder.
    """

    config_class = MiDashengLMConfig
    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache

    def __init__(self, config: MiDashengLMConfig):
        super().__init__(config)
        self.audio_encoder = DashengAudioTransformer._from_config(
            config.audio_encoder_config
        )
        self.audio_projector = AudioProjectorSubsample(
            self.audio_encoder.embed_dim,
            config.text_config.hidden_size,
            config.subsample_factor,
        )
        self.decoder = Qwen25OmniThinkerTextOnlyDecoder._from_config(
            config.text_config,
            attn_implementation=config._attn_implementation,
        )
        self.post_init()

    def _forward_audio_encoder(
        self,
        audios: torch.Tensor,
        audio_length: Optional[Iterable[int]],
    ) -> torch.Tensor:
        """Encode and project audio; returns (batch, steps, text_hidden_size).

        ``audio_length`` may be any iterable of ints or a tensor. It is turned
        into a tensor on ``audios``'s device because the encoder needs ``.ndim``
        and builds its mask on that tensor's device.
        """
        if audio_length is not None:
            if not isinstance(audio_length, torch.Tensor):
                audio_length = list(audio_length)
            audio_length = torch.as_tensor(audio_length, device=audios.device)
        encoder_out, encoder_atts = self.audio_encoder(audios, audio_length)
        # audio projector (its subsampled mask is not needed here)
        encoder_out, _ = self.audio_projector(encoder_out, encoder_atts)
        return encoder_out

    @staticmethod
    def _find_audio_spans(
        audio_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Locate contiguous runs of True in each row (last dim) of ``audio_mask``.

        Runs are detected per row. A run touching the end of a row is closed
        there, and runs in adjacent rows never merge.

        Returns:
            ``(starts, ends)``: 1-D tensors of flat indices into
            ``audio_mask.reshape(-1)``; each run covers ``[start, end)``.
        """
        rows = audio_mask.long()
        if rows.ndim == 1:
            rows = rows.unsqueeze(0)
        rows = rows.flatten(0, -2)
        seq_len = rows.shape[-1]
        zeros = rows.new_zeros((rows.shape[0], 1))
        # Zero-padding both ends gives every run a +1 edge at its first token
        # and a -1 edge one past its last token.
        edges = torch.diff(rows, dim=-1, prepend=zeros, append=zeros)
        starts = (edges == 1).nonzero()
        ends = (edges == -1).nonzero()
        return (
            starts[:, 0] * seq_len + starts[:, 1],
            ends[:, 0] * seq_len + ends[:, 1],
        )

    def _prepare_inputs_embeds(
        self,
        input_ids: Optional[torch.Tensor],
        input_values: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor],
        audio_length: Optional[Iterable[int]] = None,
        audio_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Build decoder input embeddings, splicing audio features into audio-token spans.

        The i-th run of ``audio_token_id`` tokens (row-major order) receives the
        first ``run_length`` frames of the i-th audio embedding.

        Raises:
            ValueError: on conflicting/missing inputs, or when the number of
                audio-token runs differs from the number of audio inputs.
        """
        if input_ids is None:
            if inputs_embeds is None:
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be passed."
                )
            if input_values is not None:
                raise ValueError(
                    "Cannot pass `input_values` when `inputs_embeds` is provided."
                )
            return inputs_embeds

        if inputs_embeds is not None:
            raise ValueError(
                "Both `inputs_embeds` and `input_ids` are passed. Please pass only one of them."
            )
        inputs_embeds = cast(
            torch.Tensor, self.decoder.model.embed_tokens(input_ids)
        )
        if input_values is None:
            return inputs_embeds
        if audio_token_id is None:
            raise ValueError(
                "If `input_values` is provided, `audio_token_id` must also be provided."
            )
        audio_embeddings = self._forward_audio_encoder(
            input_values,
            audio_length=audio_length,
        ).to(inputs_embeds.dtype)
        span_starts, span_ends = self._find_audio_spans(input_ids == audio_token_id)
        # Flat (tokens, hidden) view so flat span indices address rows directly;
        # writes through the view update inputs_embeds in place.
        embeds_view = inputs_embeds.view(-1, inputs_embeds.shape[-1])
        for span_start, span_end, audio in zip(
            span_starts.tolist(),
            span_ends.tolist(),
            audio_embeddings,
            strict=True,
        ):
            embeds_view[span_start:span_end] = audio[: span_end - span_start]
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
        audio_token_id: Optional[int] = None,
        **kwargs: Any,
    ):
        """Merge audio into the embeddings and run the decoder; ``kwargs`` go to the decoder."""
        inputs_embeds = self._prepare_inputs_embeds(
            input_ids=input_ids,
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
            audio_token_id=audio_token_id,
        )
        return self.decoder(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def generate(
        self,
        input_ids: Optional[Tensor] = None,
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
        audio_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Merge audio into the embeddings and call ``decoder.generate``.

        ``generation_config`` defaults to ``self.generation_config``. The other
        ``kwargs`` are passed through.
        """
        inputs_embeds = self._prepare_inputs_embeds(
            input_ids=input_ids,
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
            audio_token_id=audio_token_id,
        )
        generation_config = kwargs.pop("generation_config", self.generation_config)
        return self.decoder.generate(
            inputs_embeds=inputs_embeds,
            generation_config=generation_config,
            **kwargs,
        )