# File size: 14,673 Bytes
# d0e893e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# --------------------------------------------------------
# FiT: A Flexible Vision Transformer for Image Generation
#
# Based on the following repository
# https://github.com/lucidrains/rotary-embedding-torch
# https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope
# https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37
# --------------------------------------------------------

import math
from math import pi
from typing import Optional, Any, Union, Tuple
import torch
from torch import nn

from einops import rearrange, repeat
from functools import lru_cache

#################################################################################
#                                 NTK Operations                                #
#################################################################################

def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    '''
    Invert the RoPE frequency formula to locate a (fractional) dimension index.

    Returns the index ``d`` for which a frequency ``base ** (-2 * d / dim)``
    completes ``num_rotations`` full turns over ``max_position_embeddings``
    positions. The result is a float and is not clamped to ``[0, dim)``.
    '''
    log_ratio = math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    numerator = dim * log_ratio
    denominator = 2 * math.log(base)
    return numerator / denominator

def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    '''
    Return the integer dimension range ``(low, high)`` bracketing two rotation counts.

    The index for ``low_rot`` is rounded down and floored at 0; the index for
    ``high_rot`` is rounded up and capped at ``dim - 1``.
    '''
    lower_index = find_correction_factor(low_rot, dim, base, max_position_embeddings)
    upper_index = find_correction_factor(high_rot, dim, base, max_position_embeddings)
    # Clamp to valid dimension indices just in case.
    first = max(math.floor(lower_index), 0)
    last = min(math.ceil(upper_index), dim - 1)
    return first, last

def linear_ramp_mask(min, max, dim):
    '''
    Build a float32 ramp of length ``dim`` that is 0 up to ``min`` and 1 from ``max`` on.

    Values in between rise linearly. When both bounds coincide the upper one
    is nudged by 0.001 so the division below never hits zero.
    '''
    if max == min:
        max += 0.001  # Prevent singularity

    positions = torch.arange(dim, dtype=torch.float32)
    ramp = (positions - min) / (max - min)
    return ramp.clamp(0, 1)

def find_newbase_ntk(dim, base=10000, scale=1):
    '''
    Return the enlarged RoPE base used by NTK-aware scaling.

    Applies the base-change formula ``base * scale ** (dim / (dim - 2))``.
    ``scale`` may be a number or a tensor; the result has the matching kind.
    '''
    exponent = dim / (dim - 2)
    return base * scale ** exponent

def get_mscale(scale: torch.Tensor):
    '''
    Return the magnitude correction ``0.1 * ln(scale) + 1`` for scales above 1.

    Entries with ``scale <= 1`` map to exactly 1.0. Works element-wise, so
    ``scale`` may be 0-dim or batched; Python numbers are accepted as well and
    are converted to a tensor first.

    Args:
        scale: ratio between the target and the original sequence length.

    Returns:
        A tensor shaped like ``scale`` holding the per-entry factor.
    '''
    # ``scale`` was previously declared as ``scale=torch.Tensor`` -- an
    # annotation mistyped as a default value.
    scale = torch.as_tensor(scale)
    boosted = 0.1 * torch.log(scale) + 1.0
    # Build the constant branch from ``boosted`` so dtype and device follow the
    # input instead of a freshly created CPU float32 tensor.
    return torch.where(scale <= 1., torch.ones_like(boosted), boosted)

def get_proportion(L_test, L_train):
    '''
    Return ``sqrt(log(2 * L_test) / log(L_train))``, or 1.0 when ``2 * L_test <= L_train``.

    The test length is doubled before both the comparison and the logarithm.
    The result is a tensor.
    '''
    doubled = L_test * 2
    ratio = torch.tensor(doubled / L_train)
    log_test = torch.log(torch.tensor(doubled))
    log_train = torch.log(torch.tensor(L_train))
    return torch.where(ratio <= 1., torch.tensor(1.0), torch.sqrt(log_test / log_train))



#################################################################################
#                                 Rotate Q or K                                 #
#################################################################################

def rotate_half(x):
    '''
    Rotate every adjacent channel pair ``(a, b)`` of the last axis to ``(-b, a)``.

    The last axis is viewed as consecutive pairs, each pair is rotated, and
    the pairs are flattened back so the output has the same shape as ``x``.
    '''
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d r -> ... (d r)')



#################################################################################
#                               Core Vision RoPE                                #
#################################################################################

class VisionRotaryEmbedding(nn.Module):
    '''
    2-D rotary position embedding (RoPE) for tokens laid out on an image grid.

    ``head_dim`` is split into two halves of ``head_dim // 2`` channels; one
    half is rotated by the position along one grid axis and the other half by
    the position along the other axis. The per-axis inverse frequencies can be
    plain RoPE (``'normal'``, ``'scale1'``, ``'scale2'``) or one of the length
    extrapolation schemes handled by ``get_1d_rope_freqs`` (``'linear'``,
    ``'ntk-aware'``, ``'ntk-aware-pro1'``, ``'ntk-aware-pro2'``,
    ``'ntk-by-parts'``, ``'yarn'``).

    Two operating modes exist:

    * offline (``online_rope=False``): frequencies and a lookup table of
      ``max_cached_len`` positions are precomputed in ``__init__`` and used by
      ``get_2d_rope_from_grid``, ``get_cached_2d_rope_from_grid`` and
      ``forward``.
    * online (``online_rope=True``): nothing is precomputed; call
      ``online_get_2d_rope_from_grid`` with the per-sample grid sizes.
      ``forward`` relies on the offline lookup table and is therefore not
      usable in this mode.
    '''

    def __init__(
        self,
        head_dim: int,  # embed dimension for each head
        custom_freqs: str = 'normal',
        theta: int = 10000,
        online_rope: bool = False,
        max_cached_len: int = 1024,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
    ):
        '''
        Args:
            head_dim: channels per attention head; ``head_dim // 2`` must be even.
            custom_freqs: frequency scheme (case-insensitive), see class docstring.
            theta: RoPE base.
            online_rope: skip all precomputation when True.
            max_cached_len: number of positions stored in the lookup tables.
            max_pe_len_h / max_pe_len_w: target grid extent per axis; needed by
                every scheme other than 'normal' in offline mode and by the
                magnitude factors ('yarn', '*-pro1/2', 'scale1/2').
            decouple: scale the two axes by their own extent instead of the
                larger of the two.
            ori_max_pe_len: grid extent the model was originally trained with.
        '''
        super().__init__()

        dim = head_dim // 2
        assert dim % 2 == 0 # actually, this is important
        self.dim = dim
        self.custom_freqs = custom_freqs.lower()
        self.theta = theta
        self.decouple = decouple
        self.ori_max_pe_len = ori_max_pe_len

        if online_rope:
            # Online mode derives everything from the per-batch sizes at call time.
            return

        if self.custom_freqs in ['normal', 'scale1', 'scale2']:
            freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
            freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        elif decouple:
            freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
        else:
            max_pe_len = max(max_pe_len_h, max_pe_len_w)
            freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)

        self.register_buffer('freqs_h', freqs_h, persistent=False)
        self.register_buffer('freqs_w', freqs_w, persistent=False)

        if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
            attn_factor = 1.0
            longest = max(max_pe_len_h, max_pe_len_w)
            scale = torch.clamp_min(torch.tensor(longest) / ori_max_pe_len, 1.0)   # dynamic scale
            # Magnitude factors applied to cos/sin by the yarn / pro1 / pro2 schemes.
            self.mscale = get_mscale(scale).to(scale) * attn_factor
            self.proportion1 = get_proportion(longest, ori_max_pe_len)
            self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)

        # Lookup tables: row p holds the angles of position p, each frequency
        # duplicated so it lines up with the channel pairs used by rotate_half.
        freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)
        freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)
        self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False)
        freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)
        freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2)
        self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False)

    @staticmethod
    def _ntk_freqs(theta, dim, scale, exponent):
        '''Inverse frequencies computed with the NTK-enlarged base, shaped ``scale.shape + (F,)``.'''
        new_base = find_newbase_ntk(dim, theta, scale).view(-1, 1)
        # Restore exactly the batch shape of ``scale``. A bare ``squeeze()``
        # would also drop a batch axis of size 1.
        return 1. / torch.pow(new_base, exponent).reshape(*scale.shape, -1)

    def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
        '''
        Compute the per-axis inverse frequencies for the configured scheme.

        Args:
            theta: RoPE base.
            dim: number of channels rotated along this axis.
            max_pe_len: target length; a number or a tensor of per-sample lengths.
            ori_max_pe_len: original training length (must be an ``int``).

        Returns:
            Tensor of shape ``(dim // 2,)`` for a scalar ``max_pe_len`` or
            ``(B, dim // 2)`` for a ``(B,)`` tensor.

        Raises:
            ValueError: if ``self.custom_freqs`` is not one of the schemes
                handled here.
        '''
        # scaling operations for extrapolation
        assert isinstance(ori_max_pe_len, int)
        if not isinstance(max_pe_len, torch.Tensor):
            max_pe_len = torch.tensor(max_pe_len)
        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)   # dynamic scale

        # Shared exponent, created on the device of ``scale`` so every branch
        # works for non-CPU inputs.
        exponent = torch.arange(0, dim, 2).to(scale).float() / dim

        if self.custom_freqs == 'linear': # equal to position interpolation
            freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** exponent)
        elif self.custom_freqs in ('ntk-aware', 'ntk-aware-pro1', 'ntk-aware-pro2'):
            freqs = self._ntk_freqs(theta, dim, scale, exponent)
        elif self.custom_freqs == 'ntk-by-parts':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_0 = 1.25
            beta_1 = 0.75
            gamma_0 = 16
            gamma_1 = 2
            ntk_factor = 1
            extrapolation_factor = 1

            #Three RoPE extrapolation/interpolation methods
            freqs_base = 1.0 / (theta ** exponent)
            freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, theta ** exponent)
            freqs_ntk = self._ntk_freqs(theta, dim, scale, exponent)

            #Combine NTK and Linear
            low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor
            freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask

            #Combine Extrapolation and NTK and Linear
            low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor
            freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask

        elif self.custom_freqs == 'yarn':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_fast = 32
            beta_slow = 1
            extrapolation_factor = 1

            freqs_extrapolation = 1.0 / (theta ** exponent)
            freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, theta ** exponent)

            low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
            freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
        else:
            raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')
        return freqs

    def _online_magnitude(self, size_h, size_w):
        '''
        Per-sample cos/sin magnitude factor for the online path, or ``None``.

        Mirrors what ``__init__`` precomputes in offline mode, evaluated on the
        ``(B,)`` size tensors: yarn uses ``get_mscale`` of the clamped length
        ratio, pro1 the proportion of the longer side, pro2 the proportion of
        the token count.
        '''
        if self.custom_freqs == 'yarn':
            scale = torch.clamp_min(torch.max(size_h, size_w) / self.ori_max_pe_len, 1.0)
            return get_mscale(scale)
        if self.custom_freqs == 'ntk-aware-pro1':
            return get_proportion(torch.max(size_h, size_w), self.ori_max_pe_len)
        if self.custom_freqs == 'ntk-aware-pro2':
            return get_proportion(size_h * size_w, self.ori_max_pe_len ** 2)
        return None

    def online_get_2d_rope_from_grid(self, grid, size):
        '''
        Compute cos/sin tables with frequencies derived from per-sample sizes.

        grid: (B, 2, N)
            N = H * W
            the first dimension represents width, and the second represents height
            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        size: (B, 1, 2), h goes first and w goes last

        Returns ``(freqs_cos, freqs_sin)``, each of shape (B, N, D).
        '''
        # (B, 1, 2) -> (B, 2). A bare squeeze() would also drop the batch axis
        # when B == 1 and break the indexing below.
        size = size.reshape(-1, 2)
        size_h, size_w = size[:, 0], size[:, 1]
        if self.decouple:
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
        else:
            size_max = torch.max(size_h, size_w)
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)

        # NOTE(review): grid row 0 drives the w-frequencies here, while
        # get_2d_rope_from_grid / get_cached_2d_rope_from_grid pair row 0 with
        # the h-frequencies. Left untouched -- confirm which layout is intended.
        freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)

        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)

        # The precomputed self.mscale / self.proportion* are absent in online
        # mode and 0-dim otherwise, so they cannot be indexed per sample;
        # derive the factors from ``size`` instead.
        factor = self._online_magnitude(size_h, size_w)
        if factor is None:
            return freqs.cos(), freqs.sin()
        factor = factor.to(freqs)[:, None, None]
        return freqs.cos() * factor, freqs.sin() * factor

    def _scaled_cos_sin(self, freqs):
        '''Return cos/sin of ``freqs`` times the precomputed magnitude factor of the scheme.'''
        if self.custom_freqs == 'yarn':
            factor = self.mscale
        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
            factor = self.proportion1
        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
            factor = self.proportion2
        else:
            return freqs.cos(), freqs.sin()
        return freqs.cos() * factor, freqs.sin() * factor

    # No lru_cache here: tensors and modules hash by identity, so a cache would
    # pin the module plus every grid/output pair it has seen and would return
    # stale tables if a grid were modified in place.
    def get_2d_rope_from_grid(self, grid):
        '''
        Compute cos/sin tables directly from (possibly non-integer) positions.

        grid: (B, 2, N)
            N = H * W
            the first dimension represents width, and the second represents height
            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]

        NOTE(review): the code multiplies grid row 0 with ``freqs_h`` and row 1
        with ``freqs_w``, which does not match the legend above -- confirm.

        Returns ``(freqs_cos, freqs_sin)``, each of shape (B, N, D).
        '''
        freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
        freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)
        return self._scaled_cos_sin(freqs)

    def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
        '''
        Look up cos/sin tables from the tables precomputed in ``__init__``.

        grid: (B, 2, N) or (2, N)
            N = H * W
            the first dimension represents width, and the second represents height
            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
            Entries are used as indices into tables of ``max_cached_len`` rows,
            so they must be integer positions below that length.

        NOTE(review): the code indexes the h-table with grid row 0 and the
        w-table with row 1, which does not match the legend above -- confirm.

        Raises:
            ValueError: if ``grid`` is neither 2- nor 3-dimensional.
        '''
        if grid.dim() == 3:    # (B, 2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]]
        elif grid.dim() == 2:  # (2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]]
        else:
            raise ValueError(f'grid must have shape (B, 2, N) or (2, N), got {tuple(grid.shape)}')
        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)
        return self._scaled_cos_sin(freqs)

    def forward(self, x, grid):
        '''
        Apply the rotary embedding to queries or keys.

        x: (B, n_head, N, D)
        grid: (B, 2, N)

        Uses the precomputed lookup tables; this gives the same result as
        ``get_2d_rope_from_grid`` for integer positions but avoids recomputing
        the angles.
        '''
        freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
        # Insert the head axis so the tables broadcast over n_head.
        freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        return  x * freqs_cos + rotate_half(x) * freqs_sin