Model Card for OrthoViT-B ImageNet-1k
This model is a Vision Transformer (ViT-B) trained on ImageNet-1k, incorporating Orthogonal Residual Updates as proposed in the paper Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks. The core idea is to decompose a module's output relative to the input stream and add only the component orthogonal to this stream, aiming for richer feature learning and more efficient training.
This specific checkpoint was trained for approximately 90,000 steps (roughly 270 epochs out of a planned 300).
Model Details
Evaluation
Note: The validation accuracies below were measured on the step-90k checkpoint, not the final model, so they may differ slightly from those reported in the paper.
| Steps | Connection | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Link |
|---|---|---|---|---|
| 90k | Orthogonal | 74.62 | 92.26 | here |
| 90k | Linear | 71.23 | 90.29 | link |
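The gaps in this table can be computed directly. This is plain arithmetic on the values copied from the table, not a new measurement. At step 90k the orthogonal update leads by 3.39 points in top-1 and 1.97 points in top-5. The top-1 gap is below the +4.3%p gain the paper reports for ViT-B on ImageNet-1k, which fits the note that these numbers come from an intermediate checkpoint.

```python
# Top-1 / Top-5 accuracies (%) at step 90k, copied from the table above
results = {
    "Orthogonal": {"top1": 74.62, "top5": 92.26},
    "Linear": {"top1": 71.23, "top5": 90.29},
}

# Gain of the orthogonal update over the linear (standard) residual, in percentage points
gains = {
    metric: round(results["Orthogonal"][metric] - results["Linear"][metric], 2)
    for metric in ("top1", "top5")
}
print(gains)
```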
Abstract
Residual connections are pivotal for deep neural networks, enabling greater depth by mitigating vanishing gradients. However, in standard residual updates, the module's output is directly added to the input stream. This can lead to updates that predominantly reinforce or modulate the existing stream direction, potentially underutilizing the module's capacity for learning entirely novel features. In this work, we introduce Orthogonal Residual Update: we decompose the module's output relative to the input stream and add only the component orthogonal to this stream. This design aims to guide modules to contribute primarily new representational directions, fostering richer feature learning while promoting more efficient training. We demonstrate that our orthogonal update strategy improves generalization accuracy and training stability across diverse architectures (ResNetV2, Vision Transformers) and datasets (CIFARs, TinyImageNet, ImageNet-1k), achieving, for instance, a +4.3%p top-1 accuracy gain for ViT-B on ImageNet-1k.
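The abstract says a standard update may mostly reinforce or modulate the existing stream direction. One way to quantify this is to measure how much of an update's energy lies along the stream. The helper below, `parallel_energy_fraction`, is a hypothetical diagnostic and is not taken from the paper. It relies on the identity $\|f_{\parallel}\|^2 = \cos^2(x, f)\,\|f\|^2$, and it assumes that vectors are compared along the last (feature) dimension.

```python
import torch
import torch.nn.functional as F


def parallel_energy_fraction(x: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    """Fraction of ||update||^2 that lies along x, per vector on the last dim.

    Equals the squared cosine similarity: 1.0 means the update only rescales or
    opposes x, 0.0 means it is fully orthogonal to x.
    """
    return F.cosine_similarity(x, update, dim=-1).pow(2)


x = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
update = torch.tensor([[2.0, 1.0], [-3.0, 0.0]])  # second update purely opposes x
print(parallel_energy_fraction(x, update))
```

Under the orthogonal update, the component this fraction measures is exactly the part that gets discarded.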
Method Overview
Our core idea is to modify the standard residual update $x_{n+1} = x_n + f(\sigma(x_n))$ by projecting out the component of $f(\sigma(x_n))$ that is parallel to $x_n$. The update then becomes $x_{n+1} = x_n + f_{\perp}(x_n)$, where $f_{\perp}(x_n)$ is the component of $f(\sigma(x_n))$ orthogonal to $x_n$.
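The decomposition can be written out with the standard vector projection, using $u_n = f(\sigma(x_n))$ for the module output. This sketch assumes the inner product is taken over the feature dimension (for example, per token in a ViT). The reference implementation may differ in details, such as a small $\epsilon$ added to the denominator for numerical stability.

$$
\begin{aligned}
u_n &= f(\sigma(x_n)), \\
f_{\parallel}(x_n) &= \frac{\langle x_n, u_n \rangle}{\langle x_n, x_n \rangle}\, x_n, \\
f_{\perp}(x_n) &= u_n - f_{\parallel}(x_n), \qquad \langle x_n, f_{\perp}(x_n) \rangle = 0, \\
x_{n+1} &= x_n + f_{\perp}(x_n).
\end{aligned}
$$

As an example, take $x_n = (1, 0)$ and $u_n = (2, 1)$. Then $f_{\parallel}(x_n) = (2, 0)$ and $f_{\perp}(x_n) = (0, 1)$. The standard update gives $(3, 1)$, while the orthogonal update gives $(1, 1)$: the part of $u_n$ that would only have rescaled $x_n$ is dropped.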
Figure 1: (Left) Standard residual update. (Right) Our Orthogonal Residual Update, which discards the parallel component $f_{\parallel}$ and adds only the orthogonal component $f_{\perp}$.
This approach aims to ensure that each module primarily contributes new information to the residual stream, enhancing representational diversity and mitigating potential interference from updates that merely rescale or oppose the existing stream.
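Below is a minimal PyTorch sketch of the update, and it is not the repository's implementation. It assumes the projection is taken along the last (feature) dimension and that `eps` is a small stabilizer. `OrthoPreNormBlock` is a hypothetical pre-norm block in which σ is taken to be LayerNorm and f an MLP.

```python
import torch
import torch.nn as nn


def orthogonal_residual_update(x: torch.Tensor, update: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return x + (component of `update` orthogonal to x), projecting along the last dim.

    For ViT activations of shape (batch, tokens, dim) each token is projected separately.
    `eps` is an assumed constant guarding against division by zero.
    """
    # Projection coefficient <x, u> / <x, x> for every vector along the last dim
    coeff = (x * update).sum(dim=-1, keepdim=True) / (x.pow(2).sum(dim=-1, keepdim=True) + eps)
    parallel = coeff * x            # f_parallel: the part that only rescales x
    orthogonal = update - parallel  # f_perp: the part pointing in new directions
    return x + orthogonal


class OrthoPreNormBlock(nn.Module):
    """Hypothetical pre-norm block computing x + f_perp, with sigma = LayerNorm and f = MLP."""

    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return orthogonal_residual_update(x, self.mlp(self.norm(x)))


torch.manual_seed(0)
x = torch.randn(2, 4, 8)  # (batch, tokens, dim)
block = OrthoPreNormBlock(dim=8, hidden=32)
delta = block(x) - x      # what the block actually added to the stream
# Each token's update should be (numerically) orthogonal to that token's input
print((delta * x).sum(dim=-1).abs().max())
```

Only the module output is modified, so the block keeps the same input and output shapes as a standard residual block `x + f(norm(x))`.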
Key Results: Stable and Efficient Learning
Our Orthogonal Residual Update strategy leads to more stable training dynamics and improved learning efficiency. For example, models trained with our method often converge faster and reach better generalization performance, as the comparative training curves illustrate.
Figure 2: Example comparison (e.g., ViT-B on ImageNet-1k) showing Orthogonal Residual Update (blue) achieving lower training loss and higher validation accuracy in less wall-clock time compared to linear residual updates (red).
Model Sources
- Repository (Original Implementation): https://github.com/BootsofLagrangian/ortho-residual
- Paper: Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks (arXiv:2505.11881)
Evaluation Code
```python
import argparse
from typing import List, Tuple

import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForImageClassification


def accuracy_counts(
    logits: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1, 5),
) -> List[float]:
    """
    Given model outputs and targets, return a list of correct-counts
    for each k in topk.
    """
    maxk = max(topk)
    # Indices of the maxk highest logits per sample: (batch, maxk) -> (maxk, batch)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    # correct[r, j] is True if the rank-r prediction for sample j equals its target
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.item())
    return res


def evaluate_model(args: argparse.Namespace) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # trust_remote_code=True is needed because the orthogonal-update ViT is a custom architecture
    model = AutoModelForImageClassification.from_pretrained(
        "BootsofLagrangian/ortho-vit-b-imagenet1k-hf",
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()

    # ImageNet preprocessing: bicubic resize, center crop, mean/std normalization
    img_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transform_eval = transforms.Compose([
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # This dataset stores images under 'jpg' and class indices under 'cls'
    val_dataset = load_dataset("timm/imagenet-1k-wds", split="validation")

    def collate_fn(batch):
        images = torch.stack([transform_eval(item["jpg"]) for item in batch])
        labels = torch.tensor([item["cls"] for item in batch])
        return images, labels

    # The nested collate_fn (and the lambda in transform_eval) cannot be pickled, so
    # worker processes rely on the fork start method (Linux default). On Windows or
    # macOS, pass --num-workers 0.
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=(device.type == "cuda"),  # pinned memory only helps host-to-GPU copies
    )

    total_samples, correct_top1, correct_top5 = 0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(pixel_values=images)
            logits = outputs.logits
            counts = accuracy_counts(logits, labels, topk=(1, 5))
            correct_top1 += counts[0]
            correct_top5 += counts[1]
            total_samples += images.size(0)

    top1_accuracy = (correct_top1 / total_samples) * 100
    top5_accuracy = (correct_top5 / total_samples) * 100
    print("\n--- Evaluation Results ---")
    print(f"Total samples evaluated: {total_samples}")
    print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
    print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate OrthoViT-B on the ImageNet-1k validation set")
    parser.add_argument("--cpu", action="store_true", help="Force evaluation on CPU")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=4)
    evaluate_model(parser.parse_args())
```
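The `accuracy_counts` helper in the script transposes the top-k predictions and sums matches rank by rank. A hedged alternative is sketched below as a hypothetical helper, `topk_correct`, which is not part of the original script. It checks whether any of a sample's first k predictions equals the target. Because the top-k indices are distinct, this gives the same counts.

```python
from typing import Dict, Sequence

import torch


def topk_correct(logits: torch.Tensor, target: torch.Tensor, topk: Sequence[int] = (1, 5)) -> Dict[int, int]:
    """Count samples whose target is among the k highest-scoring classes, for each k."""
    maxk = max(topk)
    # (batch, maxk) indices of the highest logits, best first
    pred = logits.topk(maxk, dim=1).indices
    # (batch, maxk) boolean: does the prediction at each rank equal the target?
    hits = pred == target.unsqueeze(1)
    # A sample counts for k if any of its first k predictions is a hit
    return {k: int(hits[:, :k].any(dim=1).sum().item()) for k in topk}


logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.7, 0.1, 0.2]])
target = torch.tensor([1, 2])
print(topk_correct(logits, target, topk=(1, 2)))
```

The helper returns a dict keyed by k, so callers do not need to track where each k sits in the `topk` tuple.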
Citation
```bibtex
@article{oh2025revisitingresidualconnectionsorthogonal,
  title={Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks},
  author={Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Younjae Yu},
  year={2025},
  journal={arXiv preprint arXiv:2505.11881},
  eprint={2505.11881},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2505.11881}
}
```