InterFuser: Multi-modal Transformer for Autonomous Driving

This is a Hugging Face repository for the InterFuser model, designed for end-to-end autonomous driving tasks. It processes multi-view camera images and LiDAR data to predict driving waypoints and perceive the surrounding traffic environment.

This model was trained by [Your Name or Alias].

Model Architecture

InterFuser is a Transformer-based architecture with three stages:

  • CNN backbones (ResNet-50 for RGB, ResNet-18 for LiDAR) extract features from each input.
  • A Transformer encoder fuses these multi-modal features.
  • A Transformer decoder predicts waypoints, junction status, and a Bird's-Eye-View (BEV) map of traffic.
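The three stages above can be sketched in a few lines of PyTorch. This is only an illustration of the general pattern (feature maps flattened into tokens, fused by an encoder, then read out by learned decoder queries), not the code shipped in this repository. The class name `TinyFusionSketch`, the single-layer conv stems standing in for the ResNet backbones, the embedding size, the number of waypoints, and the linear waypoint head are all assumptions chosen to keep the example small; positional embeddings and the other prediction heads are omitted.

```python
import torch
import torch.nn as nn


class TinyFusionSketch(nn.Module):
    """Hypothetical toy model: conv stems stand in for the ResNet backbones."""

    def __init__(self, dim=64, num_waypoints=5):
        super().__init__()
        # Stand-ins for the RGB and LiDAR CNN backbones (assumption, not ResNets).
        self.rgb_stem = nn.Conv2d(3, dim, kernel_size=32, stride=32)
        self.lidar_stem = nn.Conv2d(3, dim, kernel_size=16, stride=16)
        # Encoder fuses the concatenated tokens from both modalities.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=4, dim_feedforward=128, batch_first=True),
            num_layers=1,
        )
        # Decoder attends from learned queries to the fused tokens.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(dim, nhead=4, dim_feedforward=128, batch_first=True),
            num_layers=1,
        )
        self.queries = nn.Parameter(torch.randn(num_waypoints, dim))
        self.waypoint_head = nn.Linear(dim, 2)  # one (x, y) pair per query

    @staticmethod
    def _to_tokens(feature_map):
        # (B, C, H, W) -> (B, H*W, C)
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, rgb, lidar):
        rgb_tokens = self._to_tokens(self.rgb_stem(rgb))
        lidar_tokens = self._to_tokens(self.lidar_stem(lidar))
        fused = self.encoder(torch.cat([rgb_tokens, lidar_tokens], dim=1))
        queries = self.queries.unsqueeze(0).expand(rgb.shape[0], -1, -1)
        return self.waypoint_head(self.decoder(queries, fused))


model = TinyFusionSketch().eval()
with torch.no_grad():
    sketch_waypoints = model(torch.randn(2, 3, 224, 224), torch.randn(2, 3, 112, 112))
print(sketch_waypoints.shape)
```

Concatenating the tokens before the encoder lets self-attention mix camera and LiDAR features directly, which is the sense in which the encoder "fuses" the modalities.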

How to Use

First, make sure you have timm installed: pip install timm.
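If you want to confirm the environment before loading the model, a small check like the one below may help. It is a sketch using only the standard library; the helper name `missing_packages` is made up for this example, and the package list simply mirrors the imports used in the snippet that follows plus timm.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(["torch", "transformers", "timm"])
if missing:
    print("Install before loading the model: pip install " + " ".join(missing))
```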

You can then use the model with AutoModel. Remember to pass trust_remote_code=True.

import torch
from transformers import AutoModel, AutoConfig

# === 1. Load Model from Hugging Face Hub ===
model_id = "your-username/interfuser-driving-model" # Replace with your REPO_ID
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# === 2. Prepare Dummy Input Data ===
batch_size = 1
dummy_inputs = {
    'rgb': torch.randn(batch_size, 3 * 4, 224, 224, device=device),
    'rgb_left': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_right': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_center': torch.randn(batch_size, 3, 224, 224, device=device),
    'lidar': torch.randn(batch_size, 3, 112, 112, device=device),
    'measurements': torch.randn(batch_size, 10, device=device),
    'target_point': torch.randn(batch_size, 2, device=device)
}

# === 3. Run Inference ===
with torch.no_grad():
    outputs = model(**dummy_inputs)

# === 4. Interpret the Outputs ===
traffic, waypoints, is_junc, light, stop, _ = outputs
print("Inference successful!")
print(f"Waypoints shape: {waypoints.shape}")
print(f"Traffic BEV map shape: {traffic.shape}")
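To inspect all of the returned values at once, a small helper along these lines can be used. It is a sketch: the names follow the unpacking order shown above, the sixth element is not named there so `extra` is a placeholder, and the shapes in the usage lines are fake values for demonstration only, not the model's actual output shapes.

```python
from types import SimpleNamespace

# Order taken from the unpacking above; "extra" is a placeholder for the unnamed sixth value.
OUTPUT_NAMES = ("traffic", "waypoints", "is_junc", "light", "stop", "extra")


def describe_outputs(outputs):
    """Map each output name to its shape (empty tuple if the value has no .shape)."""
    if len(outputs) != len(OUTPUT_NAMES):
        raise ValueError(f"expected {len(OUTPUT_NAMES)} outputs, got {len(outputs)}")
    return {
        name: tuple(getattr(value, "shape", ()))
        for name, value in zip(OUTPUT_NAMES, outputs)
    }


# Demonstration with stand-in objects; replace with the real `outputs` tuple.
fake_outputs = tuple(SimpleNamespace(shape=(1, i + 1)) for i in range(6))
for name, shape in describe_outputs(fake_outputs).items():
    print(f"{name}: {shape}")
```

With the real model, calling `describe_outputs(outputs)` after the inference step gives a quick overview of every head without unpacking by hand.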