# VoRA-7B-Base / vision_embedding.py
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.interpolate in AIMv2ViTPreprocessor

from .configuration_vora import VoRAConfig


def _get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor, device: torch.device
) -> torch.Tensor:
    """Return an (M, embed_dim) sin/cos table for the M positions in `pos`."""
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D / 2,) frequencies, geometrically spaced
    pos = pos.reshape(-1)  # (M,) flattened positions
    out = pos[:, None] * omega[None, :]  # (M, D / 2), outer product
    emb_sin, emb_cos = torch.sin(out), torch.cos(out)  # (M, D / 2) each
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
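
# Illustrative shape check (a sketch, not part of the model): position 0 maps
# to all-zero sines and all-one cosines, since sin(0) = 0 and cos(0) = 1.
#
#   emb = _get_1d_sincos_pos_embed_from_grid(8, torch.arange(4), torch.device("cpu"))
#   emb.shape  # torch.Size([4, 8])
#   emb[0]     # tensor([0., 0., 0., 0., 1., 1., 1., 1.])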


def get_sincos_pos_embed(h: int, w: int, embed_dim: int, device: torch.device) -> torch.Tensor:
    """Return a fixed 2-D sin/cos positional embedding of shape (h * w, embed_dim)."""
    assert embed_dim % 2 == 0, embed_dim
    grid_h = torch.arange(h, dtype=torch.float32, device=device)
    grid_w = torch.arange(w, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # two (h, w) index maps
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, h, w])
    # Each half of the channels encodes one grid axis.
    emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], device)
    emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], device)
    pos_embed = torch.cat([emb_h, emb_w], dim=1)  # (H * W, D)
    return pos_embed
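
# Illustrative usage (sketch): a 2x3 patch grid with a 12-dim embedding gives
# one row per patch, with the two 6-dim halves covering the two grid axes.
#
#   pe = get_sincos_pos_embed(2, 3, 12, torch.device("cpu"))
#   pe.shape  # torch.Size([6, 12])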


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), taken over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
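
# Illustrative behavior (sketch): each feature vector is scaled to unit RMS
# before the learned per-channel gain (initialized to ones) is applied.
#
#   norm = RMSNorm(4)
#   norm(torch.tensor([[2.0, 2.0, 2.0, 2.0]]))
#   # -> approximately [[1., 1., 1., 1.]], since the row's RMS is 2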


class VisionEmbedding(nn.Module):
    def __init__(
        self,
        config: VoRAConfig = None,
        hidden_size: int = 4096,
    ):
        super().__init__()
        self.patch_size = config.patch_size
        # Non-overlapping patchify: kernel and stride both equal patch_size.
        self.proj = nn.Conv2d(
            3,
            hidden_size,
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
            bias=True,
        )
        self.norm = RMSNorm(hidden_size, eps=1e-05)
        self.embed_dim = hidden_size

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        _, _, H, W = pixel_values.shape
        # (B, 3, H, W) -> (B, N, D) patch tokens, then RMS-normalize.
        tokens = self.norm(self.proj(pixel_values).flatten(2).transpose(1, 2))
        # Fixed sin/cos positions, built on the fly for the current grid size.
        pos_embed = get_sincos_pos_embed(
            H // self.patch_size, W // self.patch_size, embed_dim=self.embed_dim, device=tokens.device
        )
        tokens = tokens + pos_embed  # (N, D) broadcasts over the batch
        return tokens
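
# Illustrative token count (sketch; patch_size=14 and the 224x224 input are
# assumptions for the example, not values read from a real VoRAConfig):
#
#   # (B, 3, 224, 224) with 14x14 patches -> (224 / 14)^2 = 256 tokens
#   # VisionEmbedding(config, hidden_size=4096)(pixel_values).shape  # (B, 256, 4096)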


class AIMv2PatchEmbed(nn.Module):
    def __init__(self, config: VoRAConfig):
        super().__init__()
        self.proj = nn.Conv2d(
            3,
            config.vision_embedding_intermediate_size,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.vision_embedding_intermediate_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 3, H, W) -> (B, N, intermediate_size)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
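
# Illustrative shape (sketch): same patchify-and-flatten pattern as
# VisionEmbedding above, but projecting into the intermediate width:
# (B, 3, H, W) -> (B, (H / p) * (W / p), vision_embedding_intermediate_size).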


class AIMv2ViTPreprocessor(nn.Module):
    def __init__(
        self,
        config: VoRAConfig = None,
        hidden_size: int = 4096,
    ):
        super().__init__()
        num_patches = (config.image_size // config.patch_size) ** 2
        self.config = config
        self.patchifier = AIMv2PatchEmbed(config)
        # Learned positional table, sized for the training resolution.
        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.vision_embedding_intermediate_size)))
        self.out_proj = nn.Linear(config.vision_embedding_intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h_token = H // self.config.patch_size
        w_token = W // self.config.patch_size
        tokens = self.patchifier(x)
        _, N, _ = tokens.shape
        pos_embed = self.pos_embed.to(tokens.device)
        if N <= pos_embed.size(1):
            # N fits in the trained table: add the (possibly truncated) table directly.
            tokens = tokens + pos_embed[:, :N]
        else:
            # N exceeds the trained table: resize it with bilinear interpolation.
            # Reshape (1, num_patches, C) to (1, C, side, side); this assumes the
            # table was trained on a square patch grid.
            side = int(pos_embed.size(1) ** 0.5)
            pos_embed = pos_embed.view(1, side, side, -1).permute(0, 3, 1, 2)
            pos_embed = F.interpolate(
                pos_embed, size=(h_token, w_token), mode="bilinear", align_corners=False
            ).permute(0, 2, 3, 1)
            # The permute above makes the tensor non-contiguous, so use reshape
            # (the original .view would raise here) to get back to (1, N, C).
            pos_embed = pos_embed.reshape(1, N, pos_embed.size(-1))
            tokens = tokens + pos_embed
        return self.out_proj(tokens)
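
# Illustrative resize (sketch): if the table was trained for a 16x16 grid
# (256 patches; the channel width of 32 is an assumption for the example) and
# the input yields a 20x20 grid, the else-branch above resizes it before the add:
#
#   table = torch.zeros(1, 256, 32).view(1, 16, 16, 32).permute(0, 3, 1, 2)
#   resized = F.interpolate(table, size=(20, 20), mode="bilinear", align_corners=False)
#   resized.shape  # torch.Size([1, 32, 20, 20]) -> 400 positions after reshape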


def build_vision_embedding(config: VoRAConfig, hidden_size):
    """Pick the vision embedding module based on config.vision_embedding_type."""
    if config.vision_embedding_type == "AIMv2":
        return AIMv2ViTPreprocessor(config, hidden_size)
    # Default: conv patchifier with fixed sin/cos positions.
    return VisionEmbedding(config, hidden_size)
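

# A minimal smoke-test sketch. The VoRAConfig fields below (patch_size,
# image_size, vision_embedding_intermediate_size, rms_norm_eps,
# vision_embedding_type) are assumptions inferred from their uses above;
# check configuration_vora.py for the real names and defaults. Run this as a
# module inside the package so the relative import resolves.
if __name__ == "__main__":
    cfg = VoRAConfig(
        patch_size=14,
        image_size=224,
        vision_embedding_intermediate_size=1024,
        rms_norm_eps=1e-5,
        vision_embedding_type="AIMv2",
    )
    model = build_vision_embedding(cfg, hidden_size=4096)
    out = model(torch.zeros(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 256, 4096])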