|
from typing import Optional, Tuple

import torch
from einops import rearrange
|
|
|
def _segment_mask(
    cu_seqlens: torch.Tensor, seq_len: int, device: torch.device
) -> torch.Tensor:
    """Return a ``(seq_len, seq_len)`` bool mask, True where two packed
    positions fall in the same segment described by ``cu_seqlens``.

    Raises:
        ValueError: if ``cu_seqlens`` is not 1-D, has fewer than two entries,
            or does not start at 0 and end at ``seq_len``.
    """
    cu_seqlens = cu_seqlens.to(device)
    if cu_seqlens.dim() != 1 or cu_seqlens.numel() < 2:
        raise ValueError("cu_seqlens must be a 1-D tensor with at least two entries")
    if int(cu_seqlens[0]) != 0 or int(cu_seqlens[-1]) != seq_len:
        raise ValueError(
            f"cu_seqlens must start at 0 and end at the packed length {seq_len}, "
            f"got {cu_seqlens.tolist()}"
        )
    positions = torch.arange(seq_len, device=device, dtype=cu_seqlens.dtype)
    # Position p lies in segment i iff cu_seqlens[i] <= p < cu_seqlens[i + 1];
    # counting the right-hand boundaries that are <= p yields exactly i.
    segment_ids = torch.searchsorted(cu_seqlens[1:], positions, right=True)
    return segment_ids[:, None] == segment_ids[None, :]


def naive_rectified_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) causal attention with rectified scores.

    Scaled dot-product scores that are negative are excluded (set to ``-inf``)
    before the softmax, on top of a causal mask. Each query therefore only
    attends to non-negatively scoring keys at or before its own position.
    A query row in which every key is excluded gets all-zero weights and an
    all-zero output.

    Args:
        q: Queries, ``(B, Tq, H, D)``, or ``(B, H, Tq, D)`` if ``head_first``.
        k: Keys, ``(B, Tk, H, D)``, or ``(B, H, Tk, D)`` if ``head_first``.
        v: Values, same layout as ``k``; its last dim sets the output dim.
        scale: Multiplier applied to the scores; defaults to ``1 / sqrt(D)``.
        cu_seqlens: Optional cumulative boundaries ``[0, l1, l1 + l2, ..., T]``
            of sequences packed along the time axis. When given, tokens only
            attend within their own segment. Requires ``Tq == Tk == cu_seqlens[-1]``.
        head_first: Whether the head axis precedes the time axis.

    Returns:
        ``(o, attn)``. ``o`` has the same layout as ``q``, with the last dim
        taken from ``v``. ``attn`` is the ``(B, H, Tq, Tk)`` weight matrix
        (always head-first) in ``q.dtype``.

    Raises:
        ValueError: if ``cu_seqlens`` is given but is inconsistent with the
            time axes of ``q``/``k``.
    """
    head_dim = q.shape[-1]
    if scale is None:
        scale = 1.0 / (head_dim ** 0.5)

    if not head_first:
        # (B, T, H, D) -> (B, H, T, D): a pure axis swap.
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))

    q_len, k_len = q.shape[-2], k.shape[-2]

    # Causal mask aligned to the bottom-right corner: query row i sits at key
    # position i + (k_len - q_len). When q_len > k_len, the leading rows see
    # no keys and produce zeros instead of a shape error.
    offset = k_len - q_len
    rows = torch.arange(q_len, device=q.device)[:, None]
    cols = torch.arange(k_len, device=q.device)[None, :]
    allowed = cols <= rows + offset

    if cu_seqlens is not None:
        if q_len != k_len:
            raise ValueError(
                "cu_seqlens requires q and k to share the same packed time axis, "
                f"got q_len={q_len}, k_len={k_len}"
            )
        allowed = allowed & _segment_mask(cu_seqlens, k_len, q.device)

    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    # Rectification: negative (and NaN) scores are dropped just like masked-out keys.
    keep = (scores >= 0) & allowed
    scores = scores.masked_fill(~keep, float('-inf'))

    # Softmax in fp32 for stability. Rows with no kept key are all -inf, and
    # softmax turns them into NaN; map those to zero weights.
    attn = torch.softmax(scores.float(), dim=-1)
    attn = torch.nan_to_num(attn, nan=0.0).to(q.dtype)

    o = torch.matmul(attn, v)
    if not head_first:
        o = o.transpose(1, 2)
    return o, attn