AutoTune Models

Trained models from the AutoTune hyperparameter optimization study for astronomical transient classification.

Models

Model                             Architecture  Checkpoint                          Val AUC
autotune_btsbot_optuna_asha       DeiT3         DeiT3-epoch=14-val_auc=0.9999.ckpt  0.9999
autotune_btsbot_optuna_fifo       DeiT3         DeiT3-epoch=15-val_auc=0.9995.ckpt  0.9995
autotune_btsbot_optuna_hyperband  DeiT3         DeiT3-epoch=17-val_auc=0.9999.ckpt  0.9999
autotune_btsbot_optuna_median     DeiT3         DeiT3-epoch=19-val_auc=0.9996.ckpt  0.9996
autotune_btsbot_optuna_pb2        DeiT          DeiT-epoch=19-val_auc=0.9692.ckpt   0.9692
autotune_btsbot_optuna_pbt        CaiT          CaiT-epoch=19-val_auc=0.9954.ckpt   0.9954
autotune_btsbot_random_asha       DeiT          DeiT-epoch=14-val_auc=0.9995.ckpt   0.9995
autotune_btsbot_random_fifo       DeiT3         DeiT3-epoch=13-val_auc=0.9998.ckpt  0.9998
autotune_btsbot_random_hyperband  CaiT          CaiT-epoch=14-val_auc=0.9997.ckpt   0.9997
autotune_btsbot_random_median     DeiT          DeiT-epoch=14-val_auc=0.9963.ckpt   0.9963
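
To load any row of the table, match its Architecture column to a timm model name. A minimal sketch follows; the exact timm variants (base/patch16/224 for DeiT and DeiT3, s24/224 for CaiT) are assumptions and should be checked against each model's config:

import timm

# Hypothetical mapping from the Architecture column to timm model names
TIMM_NAMES = {
    "DeiT3": "deit3_base_patch16_224",
    "DeiT": "deit_base_patch16_224",
    "CaiT": "cait_s24_224",
}

# e.g., for autotune_btsbot_optuna_pbt (CaiT)
model = timm.create_model(TIMM_NAMES["CaiT"], pretrained=False, num_classes=2)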

Usage

Load from Hugging Face Hub

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm

# Download model weights
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load weights
state_dict = load_file(model_path)

# Create the model architecture (DeiT3 example; match the Architecture column above)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
# strict=False tolerates minor key-name mismatches (e.g., classifier head naming)
model.load_state_dict(state_dict, strict=False)
model.eval()
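
A quick forward pass with a random tensor verifies that the weights loaded and the two-class head is wired up (the 224x224 input size is inferred from the timm variant name):

import torch

# Dummy 3-channel 224x224 input; real inputs are image triplets (see below)
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 2])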

Google Cloud / Colab

!pip install huggingface_hub safetensors timm torch torchvision

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm
import torch
from torchvision import transforms
from PIL import Image

# Download model
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load model
state_dict = load_file(model_path)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()

# Inference
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load your triplet image (3-channel: science, reference, difference)
# image = Image.open("triplet.png").convert("RGB")
# input_tensor = transform(image).unsqueeze(0)
# with torch.no_grad():
#     output = model(input_tensor)
#     prediction = torch.softmax(output, dim=1)
#     print(f"Real probability: {prediction[0, 1]:.4f}")
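
If the cutouts live in separate FITS files rather than a composite PNG, the triplet tensor can be assembled directly. This is a sketch under assumed file names and a simple per-channel min-max scaling, not the pipeline's actual preprocessing; astropy is an extra dependency not in the pip line above:

import numpy as np
import torch
from astropy.io import fits  # extra dependency: pip install astropy

# Hypothetical file names for the science/reference/difference cutouts
planes = [fits.getdata(name).astype(np.float32)
          for name in ("science.fits", "reference.fits", "difference.fits")]

# Assumed preprocessing: per-channel min-max scaling to [0, 1];
# match this to whatever normalization was used in training
planes = [(p - p.min()) / (p.max() - p.min() + 1e-8) for p in planes]

input_tensor = torch.from_numpy(np.stack(planes)).unsqueeze(0)  # (1, 3, H, W)
with torch.no_grad():
    prediction = torch.softmax(model(input_tensor), dim=1)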

Load Lightning Checkpoint

import torch
import timm

# Lightning checkpoints are full pickled objects, not bare weights;
# weights_only=False is required on PyTorch >= 2.6 (only load checkpoints you trust)
checkpoint = torch.load("checkpoint.ckpt", map_location="cpu", weights_only=False)
state_dict = checkpoint["state_dict"]

# Strip the 'model.' prefix that Lightning adds to module keys
state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()

HPO Study Details

These models were trained with Ray Tune using 10 HPO strategies: 8 search algorithm + scheduler combinations, plus 2 population-based methods. A minimal configuration sketch follows the two lists below:

Search Algorithms + Schedulers:

  • Optuna + ASHA: Bayesian optimization with aggressive early stopping
  • Optuna + FIFO: Bayesian optimization, all trials run to completion
  • Optuna + HyperBand: Bayesian optimization with HyperBand scheduling
  • Optuna + Median: Bayesian optimization with median stopping rule
  • Random + ASHA: Random search with ASHA early stopping
  • Random + FIFO: Random search, all trials run to completion
  • Random + HyperBand: Random search with HyperBand scheduling
  • Random + Median: Random search with median stopping rule

Population-Based Methods:

  • PBT: Population Based Training
  • PB2: Population Based Bandits
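
As an illustration of how one of these strategies is wired together, here is a minimal Ray Tune sketch for Optuna + ASHA. The search space, trial budget, and dummy metric are assumptions for demonstration, not the study's actual configuration, and the metric-reporting API varies between Ray versions:

import random

from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

def train_fn(config):
    # Placeholder training loop: replace the dummy metric with a real
    # train-one-epoch / evaluate step that computes validation AUC
    for epoch in range(20):
        val_auc = random.random()
        tune.report({"val_auc": val_auc})

tuner = tune.Tuner(
    train_fn,
    param_space={
        "lr": tune.loguniform(1e-5, 1e-2),           # assumed search space
        "weight_decay": tune.loguniform(1e-6, 1e-2),
    },
    tune_config=tune.TuneConfig(
        metric="val_auc",
        mode="max",
        search_alg=OptunaSearch(),                   # Bayesian (TPE) suggestions
        scheduler=ASHAScheduler(grace_period=2),     # aggressive early stopping
        num_samples=25,                              # assumed trial budget
    ),
)
results = tuner.fit()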

Datasets

The models were trained on two astronomical transient classification datasets:

  • BTSBot: real/bogus classification of transients from the Zwicky Transient Facility (ZTF)
  • DES-SN: supernova classification from the Dark Energy Survey Supernova Program

Citation

If you use these models, please cite:
