---
license: apache-2.0
language:
- en
library_name: pytorch
tags:
- autonomous-driving
- multi-modal
- transformer
- interfuser
---

# InterFuser: Multi-modal Transformer for Autonomous Driving

This Hugging Face repository hosts the InterFuser model for end-to-end autonomous driving. The model takes multi-view camera images and LiDAR data as input, predicts driving waypoints, and perceives the surrounding traffic environment.

This model was trained by [Your Name or Alias].

## Model Architecture

The model, `Interfuser`, is a Transformer-based architecture that:

- Extracts features with CNN backbones (ResNet-50 for RGB, ResNet-18 for LiDAR).
- Fuses these multi-modal features with a Transformer encoder.
- Uses a Transformer decoder to predict waypoints, junction status, and a Bird's-Eye-View (BEV) map of traffic.

## How to Use

First, make sure `torch`, `transformers`, and `timm` are installed: `pip install torch transformers timm`.

Load the model with `AutoModel` and pass `trust_remote_code=True`, because the model class is defined by custom code in this repository rather than in the `transformers` library itself.

```python
import torch
from transformers import AutoConfig, AutoModel

# === 1. Load the model from the Hugging Face Hub ===
model_id = "your-username/interfuser-driving-model"  # Replace with your REPO_ID
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, config=config, trust_remote_code=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference mode: disables dropout

# === 2. Prepare dummy input data ===
# Random tensors in the shapes used by this example; replace them with real
# sensor data for actual driving inputs.
batch_size = 1
dummy_inputs = {
    'rgb': torch.randn(batch_size, 3 * 4, 224, 224, device=device),
    'rgb_left': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_right': torch.randn(batch_size, 3, 224, 224, device=device),
    'rgb_center': torch.randn(batch_size, 3, 224, 224, device=device),
    'lidar': torch.randn(batch_size, 3, 112, 112, device=device),
    'measurements': torch.randn(batch_size, 10, device=device),
    'target_point': torch.randn(batch_size, 2, device=device),
}

# === 3. Run inference ===
with torch.no_grad():
    outputs = model(**dummy_inputs)

# === 4. Interpret the outputs ===
# The model returns a 6-tuple: traffic (BEV traffic map), waypoints,
# is_junc (junction status), light, stop, and a last element unused here.
traffic, waypoints, is_junc, light, stop, _ = outputs

print("Inference successful!")
print(f"Waypoints shape: {waypoints.shape}")
print(f"Traffic BEV map shape: {traffic.shape}")
```
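To make the architecture description concrete, the toy PyTorch sketch below mirrors its three stages: CNN feature maps are flattened into tokens, a Transformer encoder fuses the RGB and LiDAR tokens, and a Transformer decoder with learned queries feeds separate heads for the traffic BEV map, waypoints, and junction status. It is a hypothetical illustration, not this repository's `Interfuser` code: the class name `ToyInterFuser`, the strided convolutions standing in for the ResNet backbones, the query layout, and every size are assumptions, and it does not load the released weights.

```python
import torch
import torch.nn as nn


class ToyInterFuser(nn.Module):
    """Toy sketch of CNN -> Transformer encoder -> Transformer decoder.

    NOT the released model. Assumptions: tiny strided convs replace the
    ResNet-50 (RGB) and ResNet-18 (LiDAR) backbones, one RGB stem is shared
    across views, positional/sensor embeddings are omitted, and all sizes
    (d_model, n_waypoints, bev_size, traffic_dim) are illustrative.
    """

    def __init__(self, d_model=64, n_waypoints=10, bev_size=20, traffic_dim=7):
        super().__init__()
        # Stand-in backbones: (B, 3, H, W) -> (B, d_model, H/16, W/16).
        self.rgb_stem = nn.Conv2d(3, d_model, kernel_size=16, stride=16)
        self.lidar_stem = nn.Conv2d(3, d_model, kernel_size=16, stride=16)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=128,
                                       batch_first=True),
            num_layers=2,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead=4, dim_feedforward=128,
                                       batch_first=True),
            num_layers=2,
        )
        self.bev_size = bev_size
        self.n_bev = bev_size * bev_size
        self.n_waypoints = n_waypoints
        # Learned queries: one per BEV cell, one per waypoint, one for junction.
        self.queries = nn.Parameter(torch.randn(self.n_bev + n_waypoints + 1, d_model))
        self.traffic_head = nn.Linear(d_model, traffic_dim)  # per-cell BEV features
        self.waypoint_head = nn.Linear(d_model, 2)           # (x, y) per waypoint
        self.junction_head = nn.Linear(d_model, 2)           # junction / not-junction logits

    @staticmethod
    def to_tokens(feature_map):
        # (B, C, H, W) -> (B, H*W, C): each spatial cell becomes one token.
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, rgb_views, lidar):
        # Tokenize every RGB view and the LiDAR image, then concatenate them.
        tokens = [self.to_tokens(self.rgb_stem(view)) for view in rgb_views]
        tokens.append(self.to_tokens(self.lidar_stem(lidar)))
        memory = self.encoder(torch.cat(tokens, dim=1))  # fuse all modalities
        batch = lidar.shape[0]
        queries = self.queries.unsqueeze(0).expand(batch, -1, -1)
        decoded = self.decoder(queries, memory)
        bev, wps, junc = decoded.split([self.n_bev, self.n_waypoints, 1], dim=1)
        traffic = self.traffic_head(bev).view(batch, self.bev_size, self.bev_size, -1)
        waypoints = self.waypoint_head(wps)           # (B, n_waypoints, 2)
        is_junction = self.junction_head(junc[:, 0])  # (B, 2)
        return traffic, waypoints, is_junction


# Usage: three RGB views plus one LiDAR image, random data.
toy = ToyInterFuser().eval()
views = [torch.randn(1, 3, 224, 224) for _ in range(3)]
with torch.no_grad():
    traffic, waypoints, is_junction = toy(views, torch.randn(1, 3, 112, 112))
print(traffic.shape, waypoints.shape, is_junction.shape)
```

Giving each output slot its own learned query keeps the decoder output easy to split into task-specific heads; the actual model may organize its queries, embeddings, and heads differently.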