|
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
from ._ops import ops
|
|
|
|
|
|
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor = None
) -> torch.Tensor:
    """
    Multiply two 2-D tensors through the custom persistent matmul op.

    A fresh ``(M, N)`` output buffer is allocated on ``a``'s device with
    ``a``'s dtype and handed, together with the operands and the optional
    bias, to ``ops.matmul_persistent``, which fills it in place.

    Args:
        a: Left operand of shape (M, K)
        b: Right operand of shape (K, N); must share ``a``'s dtype
        bias: Optional 1-D bias forwarded to the op
            (presumably of length N and added per output column;
            confirm against the kernel)

    Returns:
        Output tensor of shape (M, N)

    Raises:
        AssertionError: If the inner dimensions differ, the dtypes differ,
            or ``bias`` is given but is not 1-D.
    """
    inner_a = a.shape[1]
    inner_b = b.shape[0]
    assert inner_a == inner_b, "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    if bias is not None:
        assert bias.dim() == 1, "Bias must be 1D"

    # Tuple-unpacking the shapes also enforces that both operands are 2-D.
    rows, _ = a.shape
    _, cols = b.shape

    out = torch.empty((rows, cols), dtype=a.dtype, device=a.device)
    ops.matmul_persistent(a, b, out, bias)
    return out
|
|
|
|
|
|
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Apply log-softmax over the last axis via the custom op.

    An output buffer shaped like ``input`` is allocated and passed with
    ``input`` to ``ops.log_softmax``, which writes the result into it.

    Args:
        input: Input tensor
        dim: Axis to normalise over; must refer to the last axis
            (either ``-1`` or ``input.ndim - 1``)

    Returns:
        Tensor with the same shape and dtype as ``input``

    Raises:
        ValueError: If ``dim`` does not designate the last dimension.
    """
    targets_last_axis = dim == -1 or dim == input.ndim - 1
    if not targets_last_axis:
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    result = torch.empty_like(input)
    ops.log_softmax(input, result)
    return result
|
|
|
|
|
|
def mean_dim(
    input: torch.Tensor, dim: int, keepdim: bool = False, dtype: torch.dtype = None
) -> torch.Tensor:
    """
    Reduce one axis of a CUDA tensor to its mean using the custom op.

    The input is cast to the result dtype first when the two differ, an
    output buffer of the reduced shape is allocated on the input's device,
    and both are passed with the normalised axis to ``ops.mean_dim``.

    Args:
        input: Input tensor; must live on a CUDA device
        dim: Axis to reduce; negative values count from the end
        keepdim: If True the reduced axis is kept with size 1, otherwise
            it is removed from the output shape
        dtype: Result dtype. When omitted, int8/int16/int32/int64 inputs
            produce float32 and every other input keeps its own dtype

    Returns:
        Tensor holding the mean along ``dim``

    Raises:
        AssertionError: If ``input`` is not a CUDA tensor or ``dim`` is
            outside ``[-input.ndim, input.ndim)``.
    """
    assert input.is_cuda, "Input must be a CUDA tensor"
    rank = input.ndim
    assert -rank <= dim < rank, f"Invalid dimension {dim}"

    # Map a negative axis onto its non-negative equivalent.
    axis = dim + rank if dim < 0 else dim

    if dtype is None:
        signed_ints = (torch.int8, torch.int16, torch.int32, torch.int64)
        dtype = torch.float32 if input.dtype in signed_ints else input.dtype

    if input.dtype != dtype:
        input = input.to(dtype)

    sizes = list(input.shape)
    kept_axis = [1] if keepdim else []
    reduced_shape = sizes[:axis] + kept_axis + sizes[axis + 1 :]

    result = torch.empty(reduced_shape, dtype=dtype, device=input.device)
    ops.mean_dim(input, result, axis)
    return result
|
|
|
|
|
|
|
|
def mm_batch_invariant(a, b):
    """Return ``matmul_persistent(a, b)``, i.e. the product without a bias."""
    product = matmul_persistent(a, b)
    return product
|
|
|
|
|
|
def addmm_batch_invariant(bias, a, b):
    """
    Call ``matmul_persistent`` on ``a`` and ``b`` with ``bias`` attached.

    The argument order (bias first) differs from ``matmul_persistent``;
    this wrapper only reorders the arguments and forwards them.
    """
    result = matmul_persistent(a, b, bias=bias)
    return result
|
|
|
|
|
|
def _log_softmax_batch_invariant(input, dim, _half_to_float):
    """
    Forward to ``log_softmax`` after rejecting the half-to-float mode.

    Args:
        input: Tensor to normalise
        dim: Axis passed through to ``log_softmax``
        _half_to_float: Must be falsy; the truthy mode is not implemented

    Returns:
        Whatever ``log_softmax(input, dim=dim)`` returns
    """
    assert not _half_to_float, "not implemented"
    result = log_softmax(input, dim=dim)
    return result
|
|
|
|
|
|
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype = None):
    """
    Compute the mean over one or several dimensions.

    A single axis is delegated to ``mean_dim``. Several axes are reduced
    with ``torch.sum`` and the total is divided by the number of reduced
    elements.

    Args:
        input: Input tensor
        dim: Axis or axes to reduce. Accepts an int, a sequence of ints,
            or ``None`` / an empty sequence meaning "all dimensions"
        keepdim: Whether to keep the reduced dimensions with size 1
        dtype: Result dtype. When omitted, floating inputs keep their own
            dtype and all other inputs produce float32

    Returns:
        Tensor with mean values along the requested dimensions
    """
    # Normalise `dim` to a concrete list of axes. An empty/None request
    # means "reduce everything"; it is spelled out explicitly so the
    # element count below matches the axes that are actually summed.
    if dim is None:
        dims = list(range(input.ndim))
    elif isinstance(dim, int):
        dims = [dim]
    else:
        dims = list(dim)
        if not dims:
            dims = list(range(input.ndim))

    if len(dims) == 1:
        return mean_dim(input, dims[0], keepdim=keepdim, dtype=dtype)

    if dtype is None:
        dtype = input.dtype if input.is_floating_point() else torch.float32

    # Accumulate in float32 (float64 only when that is the requested
    # result precision), then cast the mean to the result dtype.
    acc_dtype = torch.float64 if dtype == torch.float64 else torch.float32

    n_elems = 1
    for d in dims:
        n_elems *= input.shape[d]

    total = torch.sum(input, dim=dims, keepdim=keepdim, dtype=acc_dtype)
    return (total / n_elems).to(dtype)
|
|
|
|
class BatchInvariantAttention(nn.Module):
    """
    Batch invariant multi-head attention implementation.
    Compatible with transformers library integration.

    The four projections run through ``matmul_persistent`` and the softmax
    through the custom ``log_softmax``; the score and context products use
    ``torch.matmul``. The projection layers carry no bias.
    """

    def __init__(self, config):
        """
        Build the bias-free q/k/v/o projection layers.

        Args:
            config: Object exposing ``hidden_size`` and
                ``num_attention_heads``; ``max_position_embeddings`` is
                optional and defaults to 2048.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the number
                of heads.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = getattr(config, "max_position_embeddings", 2048)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.Tensor = None,
        **kwargs,
    ):
        """
        Run scaled dot-product attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size)
            attention_mask: Optional tensor added to the attention scores
                before the softmax
            position_ids: Accepted for interface compatibility; not used
            past_key_value: Not consumed here; returned unchanged when
                ``use_cache`` is True
            output_attentions: Append the attention probabilities to the
                returned tuple
            use_cache: Append ``past_key_value`` to the returned tuple
            cache_position: Accepted for interface compatibility; not used
            **kwargs: Ignored

        Returns:
            Tuple ``(attn_output,)`` optionally followed by the attention
            probabilities and then ``past_key_value``.
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Q/K/V projections through the batch invariant matmul.
        query_states = self._batch_invariant_linear(hidden_states, self.q_proj.weight)
        key_states = self._batch_invariant_linear(hidden_states, self.k_proj.weight)
        value_states = self._batch_invariant_linear(hidden_states, self.v_proj.weight)

        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Scaled attention scores.
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Probabilities are recovered by exponentiating the custom log-softmax.
        attn_weights_log = log_softmax(attn_weights, dim=-1)
        attn_weights = torch.exp(attn_weights_log)

        attn_output = torch.matmul(attn_weights, value_states)

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self._batch_invariant_linear(attn_output, self.o_proj.weight)

        outputs = (attn_output,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (past_key_value,)

        return outputs

    def _batch_invariant_linear(
        self, input_tensor: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply linear transformation using batch invariant matrix multiplication.

        Args:
            input_tensor: Tensor whose last dimension is the feature axis
            weight: Weight matrix; its transpose is the right operand

        Returns:
            Tensor with the leading dimensions of ``input_tensor`` and the
            feature width produced by the matmul.
        """
        leading_shape = input_tensor.shape[:-1]
        # reshape (unlike view) also accepts non-contiguous inputs, copying
        # only when the data cannot be flattened in place.
        input_2d = input_tensor.reshape(-1, input_tensor.shape[-1])
        output_2d = matmul_persistent(input_2d, weight.t())
        # Use the explicit output width rather than -1 so the final shape
        # never has to be inferred.
        return output_2d.view(*leading_shape, output_2d.shape[-1])
|
|
|
|
|
|
class BatchInvariantMLP(nn.Module):
    """
    Batch invariant MLP implementation.

    Gated feed-forward block: ``down(SiLU(gate(x)) * up(x))`` where the
    three bias-free projections run through ``matmul_persistent``.
    """

    def __init__(self, config):
        """
        Build the gate/up/down projections and the SiLU activation.

        Args:
            config: Object exposing ``hidden_size`` and ``intermediate_size``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated MLP to ``x``.

        Args:
            x: Tensor whose last dimension is the feature axis fed to the
                gate and up projections

        Returns:
            Output of the down projection, with the same leading
            dimensions as ``x``.
        """
        gate = self._batch_invariant_linear(x, self.gate_proj.weight)
        up = self._batch_invariant_linear(x, self.up_proj.weight)

        # The activated gate scales the up projection element-wise.
        intermediate = self.act_fn(gate) * up

        output = self._batch_invariant_linear(intermediate, self.down_proj.weight)
        return output

    def _batch_invariant_linear(
        self, input_tensor: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply linear transformation using batch invariant matrix multiplication.

        Args:
            input_tensor: Tensor whose last dimension is the feature axis
            weight: Weight matrix; its transpose is the right operand

        Returns:
            Tensor with the leading dimensions of ``input_tensor`` and the
            feature width produced by the matmul.
        """
        leading_shape = input_tensor.shape[:-1]
        # reshape (unlike view) also accepts non-contiguous inputs, copying
        # only when the data cannot be flattened in place.
        input_2d = input_tensor.reshape(-1, input_tensor.shape[-1])
        output_2d = matmul_persistent(input_2d, weight.t())
        # Use the explicit output width rather than -1 so the final shape
        # never has to be inferred.
        return output_2d.view(*leading_shape, output_2d.shape[-1])
|
|
|
|
|
|
class BatchInvariantRMSNorm(nn.Module):
    """
    Batch invariant RMS normalization implementation.

    The mean of squares over the last axis is obtained from ``mean_dim``;
    the statistics are computed in float32 and the normalised values are
    cast back to the input dtype before the learned scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Length of the learned scale vector (initialised to ones)
            eps: Constant added to the mean of squares before the reciprocal
                square root
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last axis and apply the scale."""
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)

        mean_square = mean_dim(states_fp32.pow(2), dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = states_fp32 * inverse_rms

        return self.weight * normalised.to(original_dtype)
|
|
|
|
|
|
|
|
# Publish the module classes. ``__all__`` is created on demand (seeded with
# the public functions of this module) so that extending it cannot raise
# NameError when no export list was defined earlier in the file.
if "__all__" not in globals():
    __all__ = [
        "matmul_persistent",
        "log_softmax",
        "mean_dim",
        "mm_batch_invariant",
        "addmm_batch_invariant",
        "mean_batch_invariant",
    ]
__all__ += ["BatchInvariantAttention", "BatchInvariantMLP", "BatchInvariantRMSNorm"]