# -*- coding: utf-8 -*- | |
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | |
from typing import Optional | |
import torch | |
def naive_parallel_rebased(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True,
) -> torch.Tensor:
    """Compute naive quadratic-time causal attention with squared dot-product scores.

    Builds the full score matrix ``(scale * q) @ k^T`` and squares it
    elementwise. It then zeroes every entry above the main diagonal, so query
    position ``i`` only sees key positions ``<= i``, and applies the result to
    ``v``. This materializes the whole ``T x T`` matrix, so it is a simple
    reference rather than a fast kernel.

    Args:
        q: Query tensor. Its last two dims are read as (sequence, feature).
            Leading dims are assumed to broadcast through ``@``; the intended
            layout (e.g. batch/heads) should be confirmed with callers.
        k: Key tensor with the same feature size as ``q``. Its sequence length
            must equal ``q``'s, because the causal mask is ``T x T`` with
            ``T = q.shape[-2]``.
        v: Value tensor whose second-to-last dim matches ``k``'s sequence length.
        scale: Multiplier applied to ``q`` before the product. ``None`` selects
            ``q.shape[-1] ** -0.5``.
        use_norm: If True, divide each output row by the sum of its masked,
            squared scores plus ``1e-6``. Otherwise return the raw weighted sum.

    Returns:
        The (optionally normalized) ``scores @ v``, shaped ``(..., T, v.shape[-1])``.
    """
    seq_len = q.shape[-2]
    if scale is None:
        scale = q.shape[-1] ** -0.5

    scores = (q * scale) @ k.transpose(-2, -1)
    scores = scores ** 2
    # Strictly-upper-triangular entries correspond to "future" keys for each query row.
    future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).triu(diagonal=1)
    scores.masked_fill_(future, 0)

    out = scores @ v
    if not use_norm:
        return out
    # The epsilon keeps rows whose scores are all zero from dividing by zero.
    denom = scores.sum(dim=-1, keepdim=True) + 1e-6
    return out / denom