zaydzuhri's picture
Add files using upload-large-folder tool
183cbc0 verified
raw
history blame
661 Bytes
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
def naive_parallel_rebased(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Reference (naive, quadratic-memory) causal ReBased attention.

    Scores are ``((scale * q) @ k^T) ** 2``, causally masked, and used as
    weights over ``v``. Optionally each output row is normalised by the sum
    of its weights.

    Args:
        q: Queries, ``(..., L_q, D)``. Leading dims are handled by batched
            matmul (presumably batch/heads -- confirm against callers).
        k: Keys, ``(..., L_k, D)``.
        v: Values, ``(..., L_k, D_v)``.
        scale: Multiplier applied to ``q`` before the dot product; defaults
            to ``D ** -0.5``. Because scores are squared, the effective
            score multiplier is ``scale ** 2``.
        use_norm: If True, divide each output row by the sum of its masked
            weights plus ``eps``.
        eps: Stabiliser added to the normaliser; only used when
            ``use_norm`` is True.

    Returns:
        Tensor of shape ``(..., L_q, D_v)``.

    Notes:
        The causal mask is aligned to the bottom-right corner: query ``i``
        sits at key position ``L_k - L_q + i`` and attends to keys up to and
        including that position. For ``L_q == L_k`` this is the usual
        lower-triangular mask. For ``L_q == 1`` the single query sees every
        key. If ``L_q > L_k``, the first ``L_q - L_k`` queries see no keys
        and produce zeros.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    q_len, k_len = q.shape[-2], k.shape[-2]

    attn = ((q * scale) @ k.transpose(-2, -1)) ** 2

    # The mask must be (L_q, L_k) to match the score matrix; a square
    # (L_q, L_q) mask only broadcasts when L_q == L_k or L_q == 1.
    # Shifting the diagonal by (L_k - L_q) aligns the last query with the
    # last key, which yields exactly the square mask when L_q == L_k and
    # an all-True mask when L_q == 1.
    causal = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device)
    causal = causal.tril(diagonal=k_len - q_len)
    attn.masked_fill_(~causal, 0)

    o = attn @ v
    if not use_norm:
        return o
    z = attn.sum(-1, keepdim=True)
    return o / (z + eps)