# TiM: tim/models/utils/rope.py
# Author: Julien Blanchon
# Clean Space repo (code only, checkpoints in model repo)
# commit d0e893e
# --------------------------------------------------------
# FiT: A Flexible Vision Transformer for Image Generation
#
# Based on the following repository
# https://github.com/lucidrains/rotary-embedding-torch
# https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope
# https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37
# --------------------------------------------------------
import math
from math import pi
from typing import Optional, Any, Union, Tuple
import torch
from torch import nn
from einops import rearrange, repeat
from functools import lru_cache
#################################################################################
# NTK Operations #
#################################################################################
def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Inverse RoPE dim formula: dimension index at which a frequency completes
    `num_rotations` full rotations over `max_position_embeddings` positions."""
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(ratio) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Integer dimension-index band over which NTK-by-parts / YaRN blends
    between interpolation and extrapolation, clamped to [0, dim - 1]."""
    lo = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
    hi = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case the formula lands outside the valid index range.
    return max(lo, 0), min(hi, dim - 1)
def linear_ramp_mask(min, max, dim):
    """Length-`dim` float tensor ramping linearly 0 -> 1 over indices [min, max],
    clamped to 0 below the band and 1 above it.

    Note: parameter names shadow the builtins; kept for caller compatibility.
    """
    span = max - min
    if span == 0:
        span = 0.001  # prevent division-by-zero singularity on a degenerate band
    ramp = (torch.arange(dim, dtype=torch.float32) - min) / span
    return torch.clamp(ramp, 0, 1)
def find_newbase_ntk(dim, base=10000, scale=1):
    """NTK-aware base change: enlarge the RoPE base so the lowest frequency
    spans `scale` times the original context."""
    exponent = dim / (dim - 2)
    return base * scale ** exponent
def get_mscale(scale=None):
    """Return YaRN's attention-magnitude correction ("mscale"), elementwise.

    For scale <= 1 (no extrapolation) the correction is 1.0; otherwise it is
    0.1 * ln(scale) + 1.0. `scale` may be a tensor of any shape.

    Fixes vs. original: the default was `scale=torch.Tensor` — the class object
    itself, which crashed on use — and the constant branch was a fresh CPU
    float32 tensor that could mismatch the device/dtype of `scale`;
    `torch.ones_like` preserves both.
    """
    if scale is None:
        scale = torch.tensor(1.0)  # neutral scale -> correction of 1.0
    return torch.where(scale <= 1., torch.ones_like(scale), 0.1 * torch.log(scale) + 1.0)
def get_proportion(L_test, L_train):
    """Entropy-based attention scaling for train-short/test-long extrapolation:
    sqrt(ln(2 * L_test) / ln(L_train)) when the (doubled) test length exceeds
    the training length, else 1.0. The doubling of L_test is empirical."""
    doubled = L_test * 2
    ratio = torch.tensor(doubled / L_train)
    scaled = torch.sqrt(torch.log(torch.tensor(doubled)) / torch.log(torch.tensor(L_train)))
    return torch.where(ratio <= 1., torch.tensor(1.0), scaled)
#################################################################################
# Rotate Q or K #
#################################################################################
def rotate_half(x):
    """Rotate interleaved pairs along the last dim: (a, b) -> (-b, a).

    Treats the last dimension as consecutive (real, imag) pairs and applies a
    90-degree rotation to each pair, keeping the interleaved layout. Pure-torch
    equivalent of the einops rearrange/stack formulation.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    a, b = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-b, a), dim=-1)
    return rotated.reshape(*x.shape)
#################################################################################
# Core Vision RoPE #
#################################################################################
class VisionRotaryEmbedding(nn.Module):
    """Axial 2D rotary position embedding (RoPE) for vision transformers.

    The per-head dimension is split in half: one half rotates with the height
    coordinate, the other with the width coordinate. Several frequency schemes
    support train-short / test-long extrapolation: 'normal', 'linear'
    (position interpolation), 'ntk-aware' (plus '-pro1' / '-pro2'
    proportion-scaled variants), 'ntk-by-parts', 'yarn', and
    'scale1' / 'scale2' (normal frequencies with proportion scaling applied
    to the cos/sin terms).
    """

    def __init__(
        self,
        head_dim: int, # embed dimension for each head
        custom_freqs: str = 'normal',  # frequency scheme; see class docstring
        theta: int = 10000,  # RoPE base
        online_rope: bool = False,  # if True, skip precomputing freqs and caches (done per batch instead)
        max_cached_len: int = 1024,  # largest position index stored in the cached tables
        max_pe_len_h: Optional[int] = None,  # target (test-time) max positions along height
        max_pe_len_w: Optional[int] = None,  # target (test-time) max positions along width
        decouple: bool = False,  # scale the h and w axes independently
        ori_max_pe_len: Optional[int] = None,  # training-time max positions (extrapolation origin)
    ):
        super().__init__()
        # Half of head_dim serves each spatial axis; each frequency is emitted
        # twice (pairwise rotation), so this half must itself be even.
        dim = head_dim // 2
        assert dim % 2 == 0 # actually, this is important
        self.dim = dim
        self.custom_freqs = custom_freqs.lower()
        self.theta = theta
        self.decouple = decouple
        self.ori_max_pe_len = ori_max_pe_len
        self.custom_freqs = custom_freqs.lower()  # NOTE(review): duplicate of the assignment above
        if not online_rope:
            if self.custom_freqs in ['normal', 'scale1', 'scale2']:
                # Vanilla RoPE inverse-frequency spectrum: theta^(-2i/dim).
                freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
                freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
            else:
                # Extrapolation-aware frequencies (linear / ntk / yarn ...).
                if decouple:
                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
                else:
                    # Shared scale driven by the longer of the two axes.
                    max_pe_len = max(max_pe_len_h, max_pe_len_w)
                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
            self.register_buffer('freqs_h', freqs_h, persistent=False)
            self.register_buffer('freqs_w', freqs_w, persistent=False)
            if max_pe_len_h != None and max_pe_len_w != None and ori_max_pe_len != None:
                attn_factor = 1.0
                scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale
                self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation
                # Proportion factors for the 'pro1'/'scale1' (per-axis length) and
                # 'pro2'/'scale2' (total-area) cos/sin scaling variants.
                self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
                self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)
            # Precompute angle tables: position index n times each frequency,
            # with every frequency repeated twice for the interleaved pair layout.
            freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)
            freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)
            self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False)
            freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)
            freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2)
            self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False)

    def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
        """Compute per-dimension rotary frequencies for one axis under the
        configured extrapolation scheme (`self.custom_freqs`).

        Args:
            theta: RoPE base.
            dim: per-axis embedding dim; dim // 2 frequencies are produced.
            max_pe_len: target max positions (int or tensor; may be batched).
            ori_max_pe_len: training-time max positions. The scale factor is
                max_pe_len / ori_max_pe_len, clamped to >= 1 so inputs shorter
                than training are left untouched.

        Raises:
            ValueError: on an unrecognized `self.custom_freqs`.
        """
        # scaling operations for extrapolation
        assert isinstance(ori_max_pe_len, int)
        if not isinstance(max_pe_len, torch.Tensor):
            max_pe_len = torch.tensor(max_pe_len)
        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale
        if self.custom_freqs == 'linear': # equal to position interpolation
            freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim))
        elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2':
            # Enlarged base via find_newbase_ntk; view/squeeze keeps both
            # scalar and batched `scale` working.
            freqs = 1. / torch.pow(
                find_newbase_ntk(dim, theta, scale).view(-1, 1),
                (torch.arange(0, dim, 2).to(scale).float() / dim)
            ).squeeze()
        elif self.custom_freqs == 'ntk-by-parts':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_0 = 1.25
            beta_1 = 0.75
            gamma_0 = 16
            gamma_1 = 2
            ntk_factor = 1
            extrapolation_factor = 1
            #Three RoPE extrapolation/interpolation methods
            freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
            freqs_ntk = 1. / torch.pow(
                find_newbase_ntk(dim, theta, scale).view(-1, 1),
                (torch.arange(0, dim, 2).to(scale).float() / dim)
            ).squeeze()
            #Combine NTK and Linear
            low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor
            freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
            #Combine Extrapolation and NTK and Linear
            low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor
            freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
        elif self.custom_freqs == 'yarn':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_fast = 32
            beta_slow = 1
            extrapolation_factor = 1
            freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))
            freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
            low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
            freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
        else:
            raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')
        return freqs

    def online_get_2d_rope_from_grid(self, grid, size):
        '''
        Per-batch (resolution-dependent) cos/sin tables, recomputing the
        frequencies from each sample's actual size instead of using caches.

        grid: (B, 2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        size: (B, 1, 2), h goes first and w goes last

        Returns (freqs_cos, freqs_sin), each (B, N, D).
        '''
        # NOTE(review): squeeze() also drops the batch dim when B == 1, which
        # would break the size[:, 0] indexing below — confirm B > 1 upstream.
        size = size.squeeze() # (B, 1, 2) -> (B, 2)
        if self.decouple:
            size_h = size[:, 0]
            size_w = size[:, 1]
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
        else:
            size_max = torch.max(size[:, 0], size[:, 1])
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
        # Angle per position, then duplicate each frequency for the pair layout.
        freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
        freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        # Per-sample magnitude correction broadcast over (N, D).
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale[:, None, None]
            freqs_sin = freqs.sin() * self.mscale[:, None, None]
        elif self.custom_freqs == 'ntk-aware-pro1':
            freqs_cos = freqs.cos() * self.proportion1[:, None, None]
            freqs_sin = freqs.sin() * self.proportion1[:, None, None]
        elif self.custom_freqs == 'ntk-aware-pro2':
            freqs_cos = freqs.cos() * self.proportion2[:, None, None]
            freqs_sin = freqs.sin() * self.proportion2[:, None, None]
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin

    # NOTE(review): lru_cache on an instance method keeps `self` (and every
    # grid tensor ever passed) alive for the process lifetime, and keys
    # tensors by object identity, not by value — intended here as a
    # same-object fast path, but worth confirming.
    @lru_cache()
    def get_2d_rope_from_grid(self, grid):
        '''
        Build cos/sin tables from the precomputed freqs_h / freqs_w buffers.

        grid: (B, 2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]

        Returns (freqs_cos, freqs_sin), each (B, N, D).
        '''
        # NOTE(review): here grid[:, 0] feeds freqs_h while
        # online_get_2d_rope_from_grid pairs grid[:, 0] with freqs_w — one of
        # the two orderings contradicts the docstring; verify against callers.
        freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
        freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale
            freqs_sin = freqs.sin() * self.mscale
        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
            freqs_cos = freqs.cos() * self.proportion1
            freqs_sin = freqs.sin() * self.proportion1
        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
            freqs_cos = freqs.cos() * self.proportion2
            freqs_sin = freqs.sin() * self.proportion2
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin

    @lru_cache()
    def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
        '''
        Fast path: index the precomputed angle tables by integer grid position
        instead of multiplying grid by frequencies.

        grid: (B, 2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]

        Returns (freqs_cos, freqs_sin). Grid values must be integer indices
        below max_cached_len. An unbatched (2, N) grid is also accepted.
        '''
        if len(grid.shape) == 3: # (B, 2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]]
        elif len(grid.shape) == 2: # (2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]]
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale
            freqs_sin = freqs.sin() * self.mscale
        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
            freqs_cos = freqs.cos() * self.proportion1
            freqs_sin = freqs.sin() * self.proportion1
        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
            freqs_cos = freqs.cos() * self.proportion2
            freqs_sin = freqs.sin() * self.proportion2
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin

    def forward(self, x, grid):
        '''
        Apply the 2D rotary embedding to queries or keys.

        x: (B, n_head, N, D)
        grid: (B, 2, N)

        Returns a tensor shaped like x with positions rotated in.
        '''
        # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid)
        # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        # using cache to accelerate, this is the same with the above codes:
        freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
        # unsqueeze(1) broadcasts the (B, N, D) tables across the head dim.
        freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        return x * freqs_cos + rotate_half(x) * freqs_sin