---
license: apache-2.0
language:
- en
library_name: pytorch
tags:
- autonomous-driving
- multi-modal
- transformer
- interfuser
---

# InterFuser: Multi-modal Transformer for Autonomous Driving

This is a Hugging Face repository for the InterFuser model, designed for end-to-end autonomous driving. It processes multi-view camera images and LiDAR data to predict driving waypoints and to perceive the surrounding traffic environment.

This model was trained by [Your Name or Alias].

## Model Architecture

The model, `Interfuser`, is a Transformer-based encoder-decoder architecture:

- CNN backbones (ResNet-50 for RGB, ResNet-18 for LiDAR) extract features from each sensor stream.
- A Transformer encoder fuses these multi-modal features into a shared representation.
- A Transformer decoder predicts waypoints, junction status, and a Bird's-Eye-View (BEV) map of surrounding traffic.
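
This fusion pattern can be sketched in PyTorch. Everything below (layer sizes, query count, feature dimensions, class and head names) is an illustrative assumption, not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class FusionSketch(nn.Module):
    """Minimal sketch of the encoder-decoder fusion described above (hypothetical sizes)."""

    def __init__(self, d_model=256, n_queries=10):
        super().__init__()
        # Stand-ins for the CNN backbones: project per-sensor token features to d_model.
        self.rgb_proj = nn.Linear(512, d_model)
        self.lidar_proj = nn.Linear(128, d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
        dec_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=2)
        # Learned queries, one per predicted waypoint.
        self.queries = nn.Parameter(torch.randn(n_queries, d_model))
        self.waypoint_head = nn.Linear(d_model, 2)  # (x, y) offset per waypoint

    def forward(self, rgb_feats, lidar_feats):
        # rgb_feats: (B, N_rgb, 512); lidar_feats: (B, N_lidar, 128)
        tokens = torch.cat([self.rgb_proj(rgb_feats), self.lidar_proj(lidar_feats)], dim=1)
        memory = self.encoder(tokens)  # encoder fuses the multi-modal token sequence
        q = self.queries.unsqueeze(0).expand(rgb_feats.size(0), -1, -1)
        out = self.decoder(q, memory)  # queries cross-attend to the fused features
        return self.waypoint_head(out)  # (B, n_queries, 2)

model = FusionSketch()
wp = model(torch.randn(2, 49, 512), torch.randn(2, 16, 128))
print(wp.shape)  # torch.Size([2, 10, 2])
```

The key design point is that fusion happens in the encoder over a single concatenated token sequence, so any modality can attend to any other; the decoder then reads out task-specific predictions through learned queries.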

## How to Use

First, make sure you have `timm` installed: `pip install timm`.

You can then load the model with `AutoModel`. Because the architecture is defined by custom code in this repository, remember to pass `trust_remote_code=True`.

```python
import torch
from transformers import AutoModel, AutoConfig

# === 1. Load Model from Hugging Face Hub ===
model_id = "your-username/interfuser-driving-model"  # Replace with your REPO_ID
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# === 2. Prepare Dummy Input Data ===
batch_size = 1
dummy_inputs = {
    # 12-channel front input (3 * 4): multiple RGB frames/views stacked along the channel axis
    'rgb': torch.randn(batch_size, 3 * 4, 224, 224, device=device),
    'rgb_left': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_right': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_center': torch.randn(batch_size, 3, 224, 224, device=device),
    # LiDAR rasterized into a 3-channel pseudo-image
    'lidar': torch.randn(batch_size, 3, 112, 112, device=device),
    # Ego-vehicle measurement vector (e.g. speed, control state)
    'measurements': torch.randn(batch_size, 10, device=device),
    # Target point (x, y) in ego coordinates
    'target_point': torch.randn(batch_size, 2, device=device)
}

# === 3. Run Inference ===
with torch.no_grad():
    outputs = model(**dummy_inputs)

# === 4. Interpret the Outputs ===
# The model returns a tuple: the traffic BEV map, predicted waypoints, and
# junction / traffic-light / stop-sign predictions.
traffic, waypoints, is_junc, light, stop, _ = outputs
print("Inference successful!")
print(f"Waypoints shape: {waypoints.shape}")
print(f"Traffic BEV map shape: {traffic.shape}")
```
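
In a closed-loop setup, the predicted waypoints are typically consumed by a downstream controller. The sketch below is purely illustrative (it is not the controller used with InterFuser) and assumes waypoints have shape `(B, N, 2)` as `(x, y)` offsets in the ego frame with `x` pointing forward:

```python
import torch

def waypoints_to_control(waypoints, speed_gain=1.0):
    """Toy controller sketch: steer toward the mean of the first two waypoints.

    Hypothetical helper for illustration only; assumes (B, N, 2) waypoints
    as (x, y) ego-frame offsets with x pointing forward.
    """
    aim = waypoints[:, :2].mean(dim=1)         # (B, 2) aim point ahead of the ego vehicle
    angle = torch.atan2(aim[:, 1], aim[:, 0])  # heading error toward the aim point (radians)
    steer = torch.tanh(angle)                  # squash to [-1, 1]
    throttle = speed_gain * aim.norm(dim=1)    # crude speed target from aim-point distance
    return steer, throttle

# A straight-ahead path should produce (near-)zero steering.
wp = torch.zeros(1, 10, 2)
wp[0, :, 0] = torch.arange(1, 11, dtype=torch.float32)
steer, throttle = waypoints_to_control(wp)
print(steer.item())  # 0.0
```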
|