Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
| import math | |
| from comfy.ldm.modules.attention import optimized_attention_for_device | |
| def process_qwen2vl_images( | |
| images: torch.Tensor, | |
| min_pixels: int = 3136, | |
| max_pixels: int = 12845056, | |
| patch_size: int = 14, | |
| temporal_patch_size: int = 2, | |
| merge_size: int = 2, | |
| image_mean: list = None, | |
| image_std: list = None, | |
| ): | |
| if image_mean is None: | |
| image_mean = [0.48145466, 0.4578275, 0.40821073] | |
| if image_std is None: | |
| image_std = [0.26862954, 0.26130258, 0.27577711] | |
| batch_size, height, width, channels = images.shape | |
| device = images.device | |
| # dtype = images.dtype | |
| images = images.permute(0, 3, 1, 2) | |
| grid_thw_list = [] | |
| img = images[0] | |
| factor = patch_size * merge_size | |
| h_bar = round(height / factor) * factor | |
| w_bar = round(width / factor) * factor | |
| if h_bar * w_bar > max_pixels: | |
| beta = math.sqrt((height * width) / max_pixels) | |
| h_bar = max(factor, math.floor(height / beta / factor) * factor) | |
| w_bar = max(factor, math.floor(width / beta / factor) * factor) | |
| elif h_bar * w_bar < min_pixels: | |
| beta = math.sqrt(min_pixels / (height * width)) | |
| h_bar = math.ceil(height * beta / factor) * factor | |
| w_bar = math.ceil(width * beta / factor) * factor | |
| img_resized = F.interpolate( | |
| img.unsqueeze(0), | |
| size=(h_bar, w_bar), | |
| mode='bilinear', | |
| align_corners=False | |
| ).squeeze(0) | |
| normalized = img_resized.clone() | |
| for c in range(3): | |
| normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] | |
| grid_h = h_bar // patch_size | |
| grid_w = w_bar // patch_size | |
| grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long) | |
| pixel_values = normalized | |
| grid_thw_list.append(grid_thw) | |
| image_grid_thw = torch.stack(grid_thw_list) | |
| grid_t = 1 | |
| channel = pixel_values.shape[0] | |
| pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1) | |
| patches = pixel_values.reshape( | |
| grid_t, | |
| temporal_patch_size, | |
| channel, | |
| grid_h // merge_size, | |
| merge_size, | |
| patch_size, | |
| grid_w // merge_size, | |
| merge_size, | |
| patch_size, | |
| ) | |
| patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) | |
| flatten_patches = patches.reshape( | |
| grid_t * grid_h * grid_w, | |
| channel * temporal_patch_size * patch_size * patch_size | |
| ) | |
| return flatten_patches, image_grid_thw | |
| class VisionPatchEmbed(nn.Module): | |
| def __init__( | |
| self, | |
| patch_size: int = 14, | |
| temporal_patch_size: int = 2, | |
| in_channels: int = 3, | |
| embed_dim: int = 3584, | |
| device=None, | |
| dtype=None, | |
| ops=None, | |
| ): | |
| super().__init__() | |
| self.patch_size = patch_size | |
| self.temporal_patch_size = temporal_patch_size | |
| self.in_channels = in_channels | |
| self.embed_dim = embed_dim | |
| kernel_size = [temporal_patch_size, patch_size, patch_size] | |
| self.proj = ops.Conv3d( | |
| in_channels, | |
| embed_dim, | |
| kernel_size=kernel_size, | |
| stride=kernel_size, | |
| bias=False, | |
| device=device, | |
| dtype=dtype | |
| ) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| hidden_states = hidden_states.view( | |
| -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size | |
| ) | |
| hidden_states = self.proj(hidden_states) | |
| return hidden_states.view(-1, self.embed_dim) | |
| def rotate_half(x): | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
| def apply_rotary_pos_emb_vision(q, k, cos, sin): | |
| cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() | |
| q_embed = (q * cos) + (rotate_half(q) * sin) | |
| k_embed = (k * cos) + (rotate_half(k) * sin) | |
| return q_embed, k_embed | |
| class VisionRotaryEmbedding(nn.Module): | |
| def __init__(self, dim: int, theta: float = 10000.0): | |
| super().__init__() | |
| self.dim = dim | |
| self.theta = theta | |
| def forward(self, seqlen: int, device) -> torch.Tensor: | |
| inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim)) | |
| seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) | |
| freqs = torch.outer(seq, inv_freq) | |
| return freqs | |
| class PatchMerger(nn.Module): | |
| def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None): | |
| super().__init__() | |
| self.hidden_size = context_dim * (spatial_merge_size ** 2) | |
| self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype) | |
| self.mlp = nn.Sequential( | |
| ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype), | |
| nn.GELU(), | |
| ops.Linear(self.hidden_size, dim, device=device, dtype=dtype), | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.ln_q(x).reshape(-1, self.hidden_size) | |
| x = self.mlp(x) | |
| return x | |
| class VisionAttention(nn.Module): | |
| def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None): | |
| super().__init__() | |
| self.hidden_size = hidden_size | |
| self.num_heads = num_heads | |
| self.head_dim = hidden_size // num_heads | |
| self.scaling = self.head_dim ** -0.5 | |
| self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) | |
| self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| cu_seqlens=None, | |
| optimized_attention=None, | |
| ) -> torch.Tensor: | |
| if hidden_states.dim() == 2: | |
| seq_length, _ = hidden_states.shape | |
| batch_size = 1 | |
| hidden_states = hidden_states.unsqueeze(0) | |
| else: | |
| batch_size, seq_length, _ = hidden_states.shape | |
| qkv = self.qkv(hidden_states) | |
| qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim) | |
| query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) | |
| if position_embeddings is not None: | |
| cos, sin = position_embeddings | |
| query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) | |
| query_states = query_states.transpose(0, 1).unsqueeze(0) | |
| key_states = key_states.transpose(0, 1).unsqueeze(0) | |
| value_states = value_states.transpose(0, 1).unsqueeze(0) | |
| lengths = cu_seqlens[1:] - cu_seqlens[:-1] | |
| splits = [ | |
| torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) | |
| ] | |
| attn_outputs = [ | |
| optimized_attention(q, k, v, self.num_heads, skip_reshape=True) | |
| for q, k, v in zip(*splits) | |
| ] | |
| attn_output = torch.cat(attn_outputs, dim=1) | |
| attn_output = attn_output.reshape(seq_length, -1) | |
| attn_output = self.proj(attn_output) | |
| return attn_output | |
| class VisionMLP(nn.Module): | |
| def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None): | |
| super().__init__() | |
| self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) | |
| self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) | |
| self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) | |
| self.act_fn = nn.SiLU() | |
| def forward(self, hidden_state): | |
| return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) | |
| class VisionBlock(nn.Module): | |
| def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None): | |
| super().__init__() | |
| self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) | |
| self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) | |
| self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) | |
| self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| cu_seqlens=None, | |
| optimized_attention=None, | |
| ) -> torch.Tensor: | |
| residual = hidden_states | |
| hidden_states = self.norm1(hidden_states) | |
| hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention) | |
| hidden_states = residual + hidden_states | |
| residual = hidden_states | |
| hidden_states = self.norm2(hidden_states) | |
| hidden_states = self.mlp(hidden_states) | |
| hidden_states = residual + hidden_states | |
| return hidden_states | |
| class Qwen2VLVisionTransformer(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_size: int = 3584, | |
| output_hidden_size: int = 3584, | |
| intermediate_size: int = 3420, | |
| num_heads: int = 16, | |
| num_layers: int = 32, | |
| patch_size: int = 14, | |
| temporal_patch_size: int = 2, | |
| spatial_merge_size: int = 2, | |
| window_size: int = 112, | |
| device=None, | |
| dtype=None, | |
| ops=None | |
| ): | |
| super().__init__() | |
| self.hidden_size = hidden_size | |
| self.patch_size = patch_size | |
| self.spatial_merge_size = spatial_merge_size | |
| self.window_size = window_size | |
| self.fullatt_block_indexes = [7, 15, 23, 31] | |
| self.patch_embed = VisionPatchEmbed( | |
| patch_size=patch_size, | |
| temporal_patch_size=temporal_patch_size, | |
| in_channels=3, | |
| embed_dim=hidden_size, | |
| device=device, | |
| dtype=dtype, | |
| ops=ops, | |
| ) | |
| head_dim = hidden_size // num_heads | |
| self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) | |
| self.blocks = nn.ModuleList([ | |
| VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops) | |
| for _ in range(num_layers) | |
| ]) | |
| self.merger = PatchMerger( | |
| dim=output_hidden_size, | |
| context_dim=hidden_size, | |
| spatial_merge_size=spatial_merge_size, | |
| device=device, | |
| dtype=dtype, | |
| ops=ops, | |
| ) | |
| def get_window_index(self, grid_thw): | |
| window_index = [] | |
| cu_window_seqlens = [0] | |
| window_index_id = 0 | |
| vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size | |
| for grid_t, grid_h, grid_w in grid_thw: | |
| llm_grid_h = grid_h // self.spatial_merge_size | |
| llm_grid_w = grid_w // self.spatial_merge_size | |
| index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) | |
| pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size | |
| pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size | |
| num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size | |
| num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size | |
| index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) | |
| index_padded = index_padded.reshape( | |
| grid_t, | |
| num_windows_h, | |
| vit_merger_window_size, | |
| num_windows_w, | |
| vit_merger_window_size, | |
| ) | |
| index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( | |
| grid_t, | |
| num_windows_h * num_windows_w, | |
| vit_merger_window_size, | |
| vit_merger_window_size, | |
| ) | |
| seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) | |
| index_padded = index_padded.reshape(-1) | |
| index_new = index_padded[index_padded != -100] | |
| window_index.append(index_new + window_index_id) | |
| cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1] | |
| cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) | |
| window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() | |
| window_index = torch.cat(window_index, dim=0) | |
| return window_index, cu_window_seqlens | |
| def get_position_embeddings(self, grid_thw, device): | |
| pos_ids = [] | |
| for t, h, w in grid_thw: | |
| hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) | |
| hpos_ids = hpos_ids.reshape( | |
| h // self.spatial_merge_size, | |
| self.spatial_merge_size, | |
| w // self.spatial_merge_size, | |
| self.spatial_merge_size, | |
| ) | |
| hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() | |
| wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) | |
| wpos_ids = wpos_ids.reshape( | |
| h // self.spatial_merge_size, | |
| self.spatial_merge_size, | |
| w // self.spatial_merge_size, | |
| self.spatial_merge_size, | |
| ) | |
| wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() | |
| pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) | |
| pos_ids = torch.cat(pos_ids, dim=0) | |
| max_grid_size = grid_thw[:, 1:].max() | |
| rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device) | |
| return rotary_pos_emb_full[pos_ids].flatten(1) | |
| def forward( | |
| self, | |
| pixel_values: torch.Tensor, | |
| image_grid_thw: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) | |
| hidden_states = self.patch_embed(pixel_values) | |
| window_index, cu_window_seqlens = self.get_window_index(image_grid_thw) | |
| cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device) | |
| cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) | |
| position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device) | |
| seq_len, _ = hidden_states.size() | |
| spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size | |
| hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) | |
| hidden_states = hidden_states[window_index, :, :] | |
| hidden_states = hidden_states.reshape(seq_len, -1) | |
| position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) | |
| position_embeddings = position_embeddings[window_index, :, :] | |
| position_embeddings = position_embeddings.reshape(seq_len, -1) | |
| position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1) | |
| position_embeddings = (position_embeddings.cos(), position_embeddings.sin()) | |
| cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( | |
| dim=0, | |
| dtype=torch.int32, | |
| ) | |
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) | |
| for i, block in enumerate(self.blocks): | |
| if i in self.fullatt_block_indexes: | |
| cu_seqlens_now = cu_seqlens | |
| else: | |
| cu_seqlens_now = cu_window_seqlens | |
| hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) | |
| hidden_states = self.merger(hidden_states) | |
| return hidden_states | |