Model Card for OrthoViT-B ImageNet-1k

This model is a Vision Transformer (ViT-B) trained on ImageNet-1k, incorporating Orthogonal Residual Updates as proposed in the paper Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks. The core idea is to decompose a module's output relative to the input stream and add only the component orthogonal to this stream, aiming for richer feature learning and more efficient training.

This specific checkpoint was trained for approximately 90,000 steps (roughly 270 epochs out of a planned 300).
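
For quick inference, a minimal sketch is shown below (the repository id is the one used by the evaluation script further down; `img` stands in for any PIL image you supply):

import torch
import torchvision.transforms as transforms
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "BootsofLagrangian/ortho-vit-b-imagenet1k-hf",
    trust_remote_code=True,
)
model.eval()

# Standard ImageNet-1k eval preprocessing (224x224 crop, ImageNet mean/std),
# matching the evaluation script below.
preprocess = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

with torch.no_grad():
    logits = model(pixel_values=preprocess(img).unsqueeze(0)).logits  # img: a PIL image (assumed)
predicted_class = logits.argmax(dim=-1).item()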

Model Details

Evaluation

Note: The validation accuracy below is measured on the checkpoint at step 90k (not the final model), so results may differ slightly from those reported in the paper.

Steps | Connection | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Link
------|------------|--------------------|--------------------|-----
90k   | Orthogonal | 74.62              | 92.26              | here
90k   | Linear     | 71.23              | 90.29              | link

Abstract

Residual connections are pivotal for deep neural networks, enabling greater depth by mitigating vanishing gradients. However, in standard residual updates, the module's output is directly added to the input stream. This can lead to updates that predominantly reinforce or modulate the existing stream direction, potentially underutilizing the module's capacity for learning entirely novel features. In this work, we introduce Orthogonal Residual Update: we decompose the module's output relative to the input stream and add only the component orthogonal to this stream. This design aims to guide modules to contribute primarily new representational directions, fostering richer feature learning while promoting more efficient training. We demonstrate that our orthogonal update strategy improves generalization accuracy and training stability across diverse architectures (ResNetV2, Vision Transformers) and datasets (CIFARs, TinyImageNet, ImageNet-1k), achieving, for instance, a +4.3%p top-1 accuracy gain for ViT-B on ImageNet-1k.

Method Overview

Our core idea is to modify the standard residual update $x_{n+1} = x_n + f(\sigma(x_n))$ by projecting out the component of $f(\sigma(x_n))$ that is parallel to $x_n$. The update then becomes $x_{n+1} = x_n + f_{\perp}(x_n)$, where $f_{\perp}(x_n) = f(\sigma(x_n)) - \frac{\langle f(\sigma(x_n)), x_n \rangle}{\lVert x_n \rVert^2} x_n$ is the component of $f(\sigma(x_n))$ orthogonal to $x_n$.
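
In code, the update can be sketched as follows (a minimal per-token PyTorch illustration, not the exact implementation released with the paper; the eps clamp is added here for numerical stability):

import torch

def orthogonal_residual_update(x: torch.Tensor, f_out: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Add only the component of f_out that is orthogonal to the stream x.

    x, f_out: tensors of shape (..., d); the projection is taken over the last dimension.
    """
    # Parallel component: (<f, x> / <x, x>) * x, computed independently per token.
    coeff = (f_out * x).sum(dim=-1, keepdim=True) / (x * x).sum(dim=-1, keepdim=True).clamp_min(eps)
    f_parallel = coeff * x
    # Discard the direction already present in the stream; keep only the novel one.
    f_orthogonal = f_out - f_parallel
    return x + f_orthogonal

# Usage inside a block: x = orthogonal_residual_update(x, module(norm(x)))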

Figure 1: (Left) Standard residual update. (Right) Our Orthogonal Residual Update, which discards the parallel component $f_{||}$ and adds only the orthogonal component $f_{\perp}$.

This approach aims to ensure that each module primarily contributes new information to the residual stream, enhancing representational diversity and mitigating potential interference from updates that merely rescale or oppose the existing stream.

Key Results: Stable and Efficient Learning

Our Orthogonal Residual Update strategy leads to more stable training dynamics and improved learning efficiency. For example, models trained with our method often exhibit faster convergence to better generalization performance, as illustrated by comparative training curves.

Figure 2: Example comparison (e.g., ViT-B on ImageNet-1k) showing Orthogonal Residual Update (blue) reaching lower training loss and higher validation accuracy in less wall-clock time than linear residual updates (red).

Model Sources

Paper: https://arxiv.org/abs/2505.11881

Evaluation

import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForImageClassification
from tqdm import tqdm
import argparse
from typing import Tuple, List

def accuracy_counts(
    logits: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1, 5),
) -> List[int]:
    """
    Given model outputs and targets, return a list of correct-counts
    for each k in topk.
    """
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.item())
    return res

def evaluate_model(args):
    # Use CUDA when available unless evaluation is forced onto the CPU via --cpu.
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    model = AutoModelForImageClassification.from_pretrained(
        "BootsofLagrangian/ortho-vit-b-imagenet1k-hf",
        trust_remote_code=True
    )
    model.to(device)
    model.eval()

    img_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform_eval = transforms.Compose([
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_dataset = load_dataset("timm/imagenet-1k-wds", split="validation")

    def collate_fn(batch):
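        # Each sample from timm/imagenet-1k-wds exposes the decoded image under 'jpg'
        # and the integer class label under 'cls'.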
        images = torch.stack([transform_eval(item['jpg']) for item in batch])
        labels = torch.tensor([item['cls'] for item in batch])
        return images, labels

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )
    total_samples, correct_top1, correct_top5 = 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values=images)
            logits = outputs.logits

            counts = accuracy_counts(logits, labels, topk=(1, 5))
            correct_top1 += counts[0]
            correct_top5 += counts[1]
            total_samples += images.size(0)

    top1_accuracy = (correct_top1 / total_samples) * 100
    top5_accuracy = (correct_top5 / total_samples) * 100

    print("\n--- Evaluation Results ---")
    print(f"Total samples evaluated: {total_samples}")
    print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
    print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")

Citation

@article{oh2025revisitingresidualconnectionsorthogonal,
      title={Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks}, 
      author={Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Younjae Yu},
      year={2025},
      journal={arXiv preprint arXiv:2505.11881},
      eprint={2505.11881},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2505.11881}
}