|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from math import pi |
|
from typing import Optional, Any, Union, Tuple |
|
import torch |
|
from torch import nn |
|
|
|
from einops import rearrange, repeat |
|
from functools import lru_cache |
|
|
|
|
|
|
|
|
|
|
|
def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) frequency index that makes ``num_rotations`` turns.

    For rotary frequencies ``base ** (-2 * i / dim)``, this solves for the
    index ``i`` whose rotation completes exactly ``num_rotations`` full turns
    over ``max_position_embeddings`` positions.

    Args:
        num_rotations: Number of full rotations to solve for.
        dim: Rotary dimension used in the frequency exponent.
        base: Rotary base (theta).
        max_position_embeddings: Number of positions the rotations span.

    Returns:
        The index as a Python float (not rounded, not clamped).
    """
    # Positions covered per rotation, expressed in units of 2*pi.
    span_per_turn = max_position_embeddings / (num_rotations * 2 * math.pi)
    scaled_log = dim * math.log(span_per_turn)
    return scaled_log / (2 * math.log(base))
|
|
|
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return integer index bounds for the two rotation counts, clamped to ``[0, dim - 1]``.

    Args:
        low_rot: Rotation count used for the lower bound (rounded down).
        high_rot: Rotation count used for the upper bound (rounded up).
        dim: Rotary dimension; also sets the upper clamp ``dim - 1``.
        base: Rotary base (theta).
        max_position_embeddings: Number of positions the rotations span.

    Returns:
        Tuple ``(low, high)`` of ints.
    """
    lower_exact = find_correction_factor(low_rot, dim, base, max_position_embeddings)
    upper_exact = find_correction_factor(high_rot, dim, base, max_position_embeddings)

    # Widen outwards to whole indices, then keep both ends inside the valid range.
    lower_index = max(math.floor(lower_exact), 0)
    upper_index = min(math.ceil(upper_exact), dim - 1)
    return lower_index, upper_index
|
|
|
def linear_ramp_mask(min, max, dim):
    """Build a float32 ramp of length ``dim`` rising linearly from 0 at ``min`` to 1 at ``max``.

    Entries at index <= ``min`` are 0, entries at index >= ``max`` are 1.

    Args:
        min: Index at which the ramp starts (value 0).
        max: Index at which the ramp reaches 1.
        dim: Number of entries in the returned tensor.

    Returns:
        1-D ``torch.float32`` tensor with values in ``[0, 1]``.
    """
    # Nudge the upper end when both ends coincide to avoid dividing by zero.
    upper = max + 0.001 if min == max else max

    indices = torch.arange(dim, dtype=torch.float32)
    slope = (indices - min) / (upper - min)
    return slope.clamp(0, 1)
|
|
|
def find_newbase_ntk(dim, base=10000, scale=1):
    """Return the rescaled rotary base ``base * scale ** (dim / (dim - 2))``.

    Args:
        dim: Rotary dimension; ``dim == 2`` divides by zero.
        base: Original rotary base (theta).
        scale: Length-extension factor; a number or a tensor (the result then
            has the tensor's shape).

    Returns:
        The new base, of the same kind (number or tensor) as ``scale``.
    """
    exponent = dim / (dim - 2)
    stretched = scale ** exponent
    return base * stretched
|
|
|
def get_mscale(scale: Union[torch.Tensor, float]):
    """Return the magnitude correction ``0.1 * ln(scale) + 1`` (1 where ``scale <= 1``).

    The original signature used ``scale=torch.Tensor`` (the class as a default
    *value*) where a type annotation was meant; ``scale`` is now a proper
    required argument and plain Python numbers are accepted as well.

    Args:
        scale: Length-extension factor(s); a tensor of any shape or a number.

    Returns:
        Tensor with the same shape as ``scale``.
    """
    scale = torch.as_tensor(scale)
    # torch.where evaluates both branches; the log of values <= 1 is simply discarded.
    return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
|
|
|
def get_proportion(L_test, L_train):
    """Return ``sqrt(ln(2 * L_test) / ln(L_train))``, or 1 when ``2 * L_test <= L_train``.

    Args:
        L_test: Target length; it is doubled before any comparison or logarithm.
        L_train: Reference length.

    Returns:
        0-dim tensor holding the scaling factor.
    """
    doubled = L_test * 2

    within_reference = torch.tensor(doubled / L_train) <= 1.
    log_ratio = torch.log(torch.tensor(doubled)) / torch.log(torch.tensor(L_train))
    return torch.where(within_reference, torch.tensor(1.0), torch.sqrt(log_ratio))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Map every consecutive channel pair ``(a, b)`` of the last axis to ``(-b, a)``.

    Args:
        x: Tensor whose last dimension has even length.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Expose the pairs as a trailing axis of length 2.
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs[..., 0], pairs[..., 1]

    swapped = torch.stack((-second, first), dim = -1)
    # Fold the pair axis back into the channel axis.
    return rearrange(swapped, '... d r -> ... (d r)')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VisionRotaryEmbedding(nn.Module):
    """Rotary position embedding for tokens laid out on a 2-D grid.

    Each head's channels are split in two halves of ``head_dim // 2`` channels.
    The first half is rotated by angles built from ``freqs_h`` and the second
    half by angles built from ``freqs_w`` (see the individual methods for which
    grid row feeds which half).

    ``custom_freqs`` selects how the base frequencies are produced:

    * ``'normal'``, ``'scale1'``, ``'scale2'``: plain ``theta ** (-2i / dim)``;
    * ``'linear'``, ``'ntk-aware'``, ``'ntk-aware-pro1'``, ``'ntk-aware-pro2'``,
      ``'ntk-by-parts'``, ``'yarn'``: see :meth:`get_1d_rope_freqs`.

    In addition, ``'yarn'`` multiplies cos/sin by ``mscale``, the ``*pro1`` /
    ``'scale1'`` modes by ``proportion1`` and the ``*pro2`` / ``'scale2'``
    modes by ``proportion2``. These attributes exist only when
    ``max_pe_len_h``, ``max_pe_len_w`` and ``ori_max_pe_len`` are all given.

    Args:
        head_dim: Channels per attention head; ``head_dim // 2`` must be even.
        custom_freqs: Frequency mode (case-insensitive), see above.
        theta: Rotary base.
        online_rope: If True, no frequency buffers or lookup tables are built
            in ``__init__``; use :meth:`online_get_2d_rope_from_grid`.
        max_cached_len: Number of integer positions covered by the cached
            lookup tables.
        max_pe_len_h: Target grid height used for frequency scaling.
        max_pe_len_w: Target grid width used for frequency scaling.
        decouple: If True, height and width are scaled independently;
            otherwise both use the larger of the two lengths.
        ori_max_pe_len: Reference (original) maximum length.
    """

    # Modes whose base frequencies are the plain, unscaled rotary frequencies.
    _PLAIN_MODES = ('normal', 'scale1', 'scale2')

    def __init__(
        self,
        head_dim: int,
        custom_freqs: str = 'normal',
        theta: int = 10000,
        online_rope: bool = False,
        max_cached_len: int = 1024,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
    ):
        super().__init__()

        dim = head_dim // 2
        assert dim % 2 == 0, 'head_dim // 2 must be even'
        self.dim = dim
        self.custom_freqs = custom_freqs.lower()
        self.theta = theta
        self.decouple = decouple
        self.ori_max_pe_len = ori_max_pe_len

        if not online_rope:
            if self.custom_freqs in self._PLAIN_MODES:
                freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
                freqs_w = freqs_h.clone()
            elif decouple:
                freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
                freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
            else:
                max_pe_len = max(max_pe_len_h, max_pe_len_w)
                freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
                # Separate storage so the two buffers never alias each other.
                freqs_w = freqs_h.clone()

            self.register_buffer('freqs_h', freqs_h, persistent=False)
            self.register_buffer('freqs_w', freqs_w, persistent=False)

        if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
            attn_factor = 1.0
            longest = max(max_pe_len_h, max_pe_len_w)
            scale = torch.clamp_min(torch.tensor(longest) / ori_max_pe_len, 1.0)
            # Plain attributes (not buffers): they do not follow Module.to();
            # _cos_sin moves them to the right device when they are used.
            self.mscale = get_mscale(scale).to(scale) * attn_factor
            self.proportion1 = get_proportion(longest, ori_max_pe_len)
            self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)

        if not online_rope:
            # Lookup tables need the static frequencies, which exist only in
            # the offline case.
            positions = torch.arange(max_cached_len)
            self.register_buffer(
                'freqs_h_cached', self._angles(positions, self.freqs_h), persistent=False)
            self.register_buffer(
                'freqs_w_cached', self._angles(positions, self.freqs_w), persistent=False)

    @staticmethod
    def _angles(positions, freqs):
        """Return ``positions x freqs`` angles with each value duplicated for its channel pair."""
        # Outer product over the last axis, then [a, b] -> [a, a, b, b].
        return (positions[..., None] * freqs).repeat_interleave(2, dim=-1)

    def _cos_sin(self, freqs):
        """Return ``(cos, sin)`` of ``freqs`` with the mode-specific magnitude factor applied."""
        mode = self.custom_freqs
        if mode == 'yarn':
            factor = self.mscale
        elif mode in ('ntk-aware-pro1', 'scale1'):
            factor = self.proportion1
        elif mode in ('ntk-aware-pro2', 'scale2'):
            factor = self.proportion2
        else:
            return freqs.cos(), freqs.sin()

        factor = factor.to(freqs.device)
        if factor.dim() > 0:
            # Per-sample factors of shape (B,) broadcast over tokens and channels.
            factor = factor.reshape(-1, 1, 1)
        return freqs.cos() * factor, freqs.sin() * factor

    def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
        """Compute 1-D rotary frequencies for the configured scaled mode.

        Args:
            theta: Rotary base.
            dim: Rotary dimension; ``dim // 2`` frequencies are produced.
            max_pe_len: Target length; an int or a tensor (e.g. one length per
                sample).
            ori_max_pe_len: Reference length; must be an ``int``.

        Returns:
            Tensor of shape ``(*max_pe_len.shape, dim // 2)``; a scalar length
            yields shape ``(dim // 2,)``.

        Raises:
            ValueError: If ``self.custom_freqs`` is not one of the scaled modes.
        """
        assert isinstance(ori_max_pe_len, int)

        if not isinstance(max_pe_len, torch.Tensor):
            max_pe_len = torch.tensor(max_pe_len)
        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)

        # Built once on scale's device so every branch is device-consistent.
        exponents = torch.arange(0, dim, 2, device=scale.device, dtype=torch.float32) / dim
        mode = self.custom_freqs

        def interpolated():
            # Every frequency divided by the scale factor.
            return 1. / (scale[..., None] * (theta ** exponents))

        def rebased():
            # Plain frequencies computed from the NTK-adjusted base. The
            # trailing axis keeps any batch axis of ``scale`` intact, even
            # when it has size 1.
            return 1. / torch.pow(find_newbase_ntk(dim, theta, scale)[..., None], exponents)

        def blend_mask(low_rot, high_rot):
            # 1 below the ramp, 0 above it, linear in between.
            low, high = find_correction_range(low_rot, high_rot, dim, theta, ori_max_pe_len)
            return 1 - linear_ramp_mask(low, high, dim // 2).to(scale)

        if mode == 'linear':
            freqs = interpolated()
        elif mode in ('ntk-aware', 'ntk-aware-pro1', 'ntk-aware-pro2'):
            freqs = rebased()
        elif mode == 'ntk-by-parts':
            beta_0 = 1.25
            beta_1 = 0.75
            gamma_0 = 16
            gamma_1 = 2
            ntk_factor = 1
            extrapolation_factor = 1

            freqs_base = 1.0 / (theta ** exponents)

            # First blend: interpolated vs. NTK-rebased frequencies.
            freqs_mask = blend_mask(beta_0, beta_1) * ntk_factor
            freqs = interpolated() * (1 - freqs_mask) + rebased() * freqs_mask

            # Second blend: the result above vs. the unscaled frequencies.
            freqs_mask = blend_mask(gamma_0, gamma_1) * extrapolation_factor
            freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
        elif mode == 'yarn':
            beta_fast = 32
            beta_slow = 1
            extrapolation_factor = 1

            freqs_extrapolation = 1.0 / (theta ** exponents)

            freqs_mask = blend_mask(beta_fast, beta_slow).float() * extrapolation_factor
            freqs = interpolated() * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
        else:
            raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')
        return freqs

    def online_get_2d_rope_from_grid(self, grid, size):
        '''
        Compute cos/sin tables with frequencies derived per sample from ``size``.

        grid: (B, 2, N)
            N = H * W
            the first dimension represents width, and the second represents height
            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        size: (B, 1, 2), h goes first and w goes last

        Returns ``(freqs_cos, freqs_sin)``, each of shape (B, N, 2 * dim); the
        first ``dim`` channels come from ``grid[:, 1]`` and the rest from
        ``grid[:, 0]``.

        NOTE(review): the magnitude factors (mscale / proportion1 / proportion2)
        are the values computed in ``__init__``, not per-sample values derived
        from ``size`` -- confirm this is the intended behaviour.
        '''
        # reshape keeps the batch axis even when B == 1 (squeeze() would drop it).
        size = size.reshape(-1, 2)
        if self.decouple:
            size_h = size[:, 0]
            size_w = size[:, 1]
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
        else:
            size_max = torch.max(size[:, 0], size[:, 1])
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
            freqs_w = freqs_h

        # freqs_* are (B, F); insert a token axis so they broadcast against (B, N).
        freqs_w = self._angles(grid[:, 0], freqs_w[:, None, :])
        freqs_h = self._angles(grid[:, 1], freqs_h[:, None, :])

        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
        return self._cos_sin(freqs)

    # NOTE(review): lru_cache keys on tensor identity (torch.Tensor hashes by
    # id), so it only hits when the very same grid object is passed again. It
    # returns stale values if that tensor is modified in place, and it keeps
    # cached grids/outputs (and this module) alive until evicted.
    @lru_cache()
    def get_2d_rope_from_grid(self, grid):
        '''
        Compute cos/sin tables from the static ``freqs_h`` / ``freqs_w`` buffers.

        grid: (B, 2, N)
            N = H * W
            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]

        NOTE(review): here ``grid[:, 0]`` is paired with ``freqs_h`` and
        ``grid[:, 1]`` with ``freqs_w``, whereas ``online_get_2d_rope_from_grid``
        treats ``grid[:, 0]`` as width -- confirm which convention is intended.

        Returns ``(freqs_cos, freqs_sin)``, each of shape (B, N, 2 * dim).
        '''
        freqs_h = self._angles(grid[:, 0], self.freqs_h)
        freqs_w = self._angles(grid[:, 1], self.freqs_w)

        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
        return self._cos_sin(freqs)

    # NOTE(review): same identity-keyed caching caveat as get_2d_rope_from_grid.
    @lru_cache()
    def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
        '''
        Look up cos/sin tables in the precomputed per-position angle tables.

        grid: (B, 2, N) or (2, N), integer positions below ``max_cached_len``
            N = H * W
            e.g.,   [0 1 2 3 0 1 2 3 0 1 2 3]
                    [0 0 0 0 1 1 1 1 2 2 2 2]

        The first grid row indexes ``freqs_h_cached`` and the second row
        ``freqs_w_cached`` (see the NOTE on ``get_2d_rope_from_grid``).

        Returns ``(freqs_cos, freqs_sin)``, each of shape (..., N, 2 * dim).
        Raises ValueError if ``grid`` is neither 2-D nor 3-D.
        '''
        if grid.dim() == 3:
            first, second = grid[:, 0], grid[:, 1]
        elif grid.dim() == 2:
            first, second = grid[0], grid[1]
        else:
            raise ValueError(
                f'grid must have shape (B, 2, N) or (2, N), got {tuple(grid.shape)}')

        freqs = torch.cat([self.freqs_h_cached[first], self.freqs_w_cached[second]], dim=-1)
        return self._cos_sin(freqs)

    def forward(self, x, grid):
        '''
        Apply the rotary embedding to ``x``.

        x: (B, n_head, N, D)
        grid: (B, 2, N)
        '''
        freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
        # Insert the head axis so the tables broadcast over all heads.
        freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        return x * freqs_cos + rotate_half(x) * freqs_sin
|
|
|
|