BootsofLagrangian's picture
Upload OrthoViTForImageClassification
77cd05d verified
import torch
def _identity(x: torch.Tensor, f_x: torch.Tensor) -> torch.Tensor:
"""
linear residual connection
x : residual stream
f_x : attention/MLP/conv(if channel-wise) output
"""
return f_x
def _orthogonal_channel(x: torch.Tensor, f_x: torch.Tensor, dim: int, eps: float) -> torch.Tensor:
"""
orthogonal residual connection
x : residual stream
f_x : attention/MLP/conv(if channel-wise) output
"""
dot = (x * f_x).sum(dim, keepdim=True)
norm_x2 = (x * x ).sum(dim, keepdim=True) + eps
scale = (dot / norm_x2).to(x.dtype) # for amp
proj_out = scale * x
return f_x - proj_out
def _orthogonal_global(x: torch.Tensor, f_x: torch.Tensor, dim: int, eps: float) -> torch.Tensor:
"""
orthogonal residual connection
x : residual stream
f_x : conv output
"""
original_shape = x.shape
positive_dim = dim if dim >= 0 else len(original_shape) + dim
x_view = x.flatten(dim) # [B, CHW...]
f_view = f_x.flatten(dim) # same
dot = (x_view * f_view).sum(dim=dim, keepdim=True) # [B,1]
norm_sq = (x_view * x_view).sum(dim=dim, keepdim=True) + eps
scale = (dot / norm_sq).to(x.dtype) # for amp
unsqueeze_times = len(original_shape) - positive_dim - 1
for _ in range(unsqueeze_times):
scale = scale.unsqueeze(-1)
proj_out = scale * x # broadcast
return f_x - proj_out
def connect(x, f_x, *,
method="linear", orthogonal_method="global",
dim=-1, eps=1e-6, perturbation=None):
if perturbation is not None:
raise NotImplementedError("perturbation is not implemented yet")
if method == "linear":
return x + _identity(x, f_x)
elif method == "orthogonal":
if orthogonal_method == "global":
return x + _orthogonal_global(x, f_x, dim, eps)
elif orthogonal_method == "channel":
return x + _orthogonal_channel(x, f_x, dim, eps)
else:
raise ValueError(f"unknown orthogonal method: {method}")
else:
raise ValueError(f"unknown connect method: {method}")