AutoTune Models
Models trained in the AutoTune hyperparameter optimization (HPO) study for astronomical transient classification.
Models
| Model | Architecture | Checkpoint | Val AUC |
|---|---|---|---|
| autotune_btsbot_optuna_asha | DeiT3 | DeiT3-epoch=14-val_auc=0.9999.ckpt | 0.9999 |
| autotune_btsbot_optuna_fifo | DeiT3 | DeiT3-epoch=15-val_auc=0.9995.ckpt | 0.9995 |
| autotune_btsbot_optuna_hyperband | DeiT3 | DeiT3-epoch=17-val_auc=0.9999.ckpt | 0.9999 |
| autotune_btsbot_optuna_median | DeiT3 | DeiT3-epoch=19-val_auc=0.9996.ckpt | 0.9996 |
| autotune_btsbot_optuna_pb2 | DeiT | DeiT-epoch=19-val_auc=0.9692.ckpt | 0.9692 |
| autotune_btsbot_optuna_pbt | CaiT | CaiT-epoch=19-val_auc=0.9954.ckpt | 0.9954 |
| autotune_btsbot_random_asha | DeiT | DeiT-epoch=14-val_auc=0.9995.ckpt | 0.9995 |
| autotune_btsbot_random_fifo | DeiT3 | DeiT3-epoch=13-val_auc=0.9998.ckpt | 0.9998 |
| autotune_btsbot_random_hyperband | CaiT | CaiT-epoch=14-val_auc=0.9997.ckpt | 0.9997 |
| autotune_btsbot_random_median | DeiT | DeiT-epoch=14-val_auc=0.9963.ckpt | 0.9963 |
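Each checkpoint filename in the table appears to follow the pattern `<Architecture>-epoch=<N>-val_auc=<AUC>.ckpt`. The sketch below assumes that pattern holds for every file and parses the names so that checkpoints can be compared by validation AUC programmatically. `parse_checkpoint_name` and `CheckpointInfo` are hypothetical helpers and are not part of the AutoTune repository.

```python
import re
from typing import NamedTuple

# Assumed filename pattern, inferred from the table: <Arch>-epoch=<N>-val_auc=<AUC>.ckpt
CKPT_PATTERN = re.compile(
    r"^(?P<arch>\w+)-epoch=(?P<epoch>\d+)-val_auc=(?P<auc>\d+(?:\.\d+)?)\.ckpt$"
)


class CheckpointInfo(NamedTuple):
    architecture: str
    epoch: int
    val_auc: float


def parse_checkpoint_name(filename: str) -> CheckpointInfo:
    """Split a checkpoint filename into architecture, epoch and validation AUC."""
    match = CKPT_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"Unrecognized checkpoint name: {filename!r}")
    return CheckpointInfo(match["arch"], int(match["epoch"]), float(match["auc"]))


# A few rows copied from the table above
checkpoints = {
    "autotune_btsbot_optuna_asha": "DeiT3-epoch=14-val_auc=0.9999.ckpt",
    "autotune_btsbot_optuna_pb2": "DeiT-epoch=19-val_auc=0.9692.ckpt",
    "autotune_btsbot_random_hyperband": "CaiT-epoch=14-val_auc=0.9997.ckpt",
}
parsed = {name: parse_checkpoint_name(fname) for name, fname in checkpoints.items()}

# Pick the model whose checkpoint reports the highest validation AUC
best = max(parsed, key=lambda name: parsed[name].val_auc)
print(best, parsed[best])
```

When two checkpoints tie on AUC, `max` keeps the first one it encounters. Any other tie-breaking rule, such as preferring the earlier epoch, would have to be added explicitly.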
Usage
Load from Hugging Face Hub
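```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm

# Download model weights
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors",
)

# Load weights
state_dict = load_file(model_path)

# Create model architecture (DeiT3 example); it must match the
# Architecture column of the table for the model you download
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)

# strict=False silently skips mismatched keys, so report them explicitly
incompatible = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", incompatible.missing_keys)
print("Unexpected keys:", incompatible.unexpected_keys)
model.eval()
```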
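`strict=False` can hide an architecture mismatch, so it can help to inspect the downloaded weights before choosing a timm model. A safetensors file starts with an 8-byte little-endian header length. That is followed by a JSON header mapping each tensor name to its dtype, shape and byte offsets. Tensor names and shapes can therefore be read with the standard library alone. The sketch below assumes that layout. `read_safetensors_header` is a hypothetical helper, and the demo writes a tiny synthetic file rather than using a real AutoTune checkpoint.

```python
import json
import os
import struct
import tempfile


def read_safetensors_header(path: str) -> dict:
    """Return {tensor_name: {"dtype", "shape", "data_offsets"}} without loading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: unsigned 64-bit little-endian length of the JSON header
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header


# Demo: write a synthetic safetensors file holding one 2x3 float32 tensor of zeros
demo_header = {"head.weight": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]}}
header_bytes = json.dumps(demo_header).encode("utf-8")
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp:
    tmp.write(struct.pack("<Q", len(header_bytes)))
    tmp.write(header_bytes)
    tmp.write(bytes(24))  # 6 float32 zeros = 24 bytes of tensor data
    demo_path = tmp.name

tensors = read_safetensors_header(demo_path)
os.remove(demo_path)
for name, info in tensors.items():
    print(name, info["dtype"], info["shape"])
```

Applied to a downloaded `model.safetensors`, the shape of the classifier weights can indicate the embedding width. In many timm vision transformers the classifier is `head.weight`, with shape `[num_classes, embed_dim]`, and base-sized variants use a width of 768. This can help confirm which timm variant to pass to `create_model`.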
Google Cloud / Colab
```python
# Run in a Colab/Jupyter cell; the leading "!" executes a shell command
!pip install huggingface_hub safetensors timm torch torchvision

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm
import torch
from torchvision import transforms
from PIL import Image

# Download model
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors",
)

# Load model and report any keys that did not match the architecture
state_dict = load_file(model_path)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
incompatible = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", incompatible.missing_keys)
print("Unexpected keys:", incompatible.unexpected_keys)
model.eval()

# Preprocessing for inference
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your triplet image (3-channel: science, reference, difference)
# image = Image.open("triplet.png").convert("RGB")
# input_tensor = transform(image).unsqueeze(0)
# with torch.no_grad():
#     output = model(input_tensor)
#     prediction = torch.softmax(output, dim=1)
# print(f"Real probability: {prediction[0, 1]:.4f}")
```
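The commented inference lines convert the model's two output logits into class probabilities with a softmax. The standard-library sketch below spells out that arithmetic for one hypothetical pair of logits. Like the snippet above, it assumes that index 1 is the "real" class. The 0.5 decision threshold is an illustrative choice, not a value taken from the study.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one triplet, in assumed class order [bogus, real]
logits = [-1.2, 2.3]
probs = softmax(logits)
real_probability = probs[1]
is_real = real_probability >= 0.5  # illustrative threshold, not from the study
print(f"Real probability: {real_probability:.4f}, classified as real: {is_real}")
```

For two classes, the softmax probability of class 1 equals the logistic sigmoid of the logit difference, `1 / (1 + exp(-(z1 - z0)))`. This is a quick sanity check when comparing against `torch.softmax` output.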
Load Lightning Checkpoint
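```python
import torch

# weights_only=False may be required on PyTorch >= 2.6 for Lightning checkpoints
# that store non-tensor objects (e.g. hyperparameters); use it only with trusted files
checkpoint = torch.load("checkpoint.ckpt", map_location="cpu", weights_only=False)
state_dict = checkpoint["state_dict"]

# Remove the leading 'model.' prefix if present (only at the start of each key)
state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
```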
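Stripping the prefix with `str.replace` would also rewrite any key that contains `model.` later in its name. The sketch below works on plain dicts with hypothetical key names, and the values stand in for tensors. It strips only a leading prefix and raises an error if two keys would collide after stripping. `strip_prefix` is a hypothetical helper, not a Lightning or timm API.

```python
def strip_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    """Remove `prefix` from the start of each key, leaving other keys untouched."""
    stripped = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in stripped:
            raise KeyError(f"Stripping {prefix!r} makes {key!r} collide with an existing key")
        stripped[new_key] = value
    return stripped


# Hypothetical Lightning keys; string values stand in for tensors
keys = {
    "model.patch_embed.proj.weight": "w0",
    "model.blocks.0.attn.qkv.weight": "w1",
    "model.head.submodel.weight": "w2",  # contains "model." again after the prefix
}
print(strip_prefix(keys))
# A naive replace turns the last key into "head.subweight"
print({k.replace("model.", ""): v for k, v in keys.items()})
```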
HPO Study Details
These models were trained with Ray Tune using 10 HPO strategies, grouped as 8 search-algorithm/scheduler combinations (2 search algorithms × 4 schedulers) and 2 population-based methods:
Search Algorithms + Schedulers:
- Optuna + ASHA: Bayesian optimization with aggressive early stopping (Asynchronous Successive Halving)
- Optuna + FIFO: Bayesian optimization, all trials run to completion
- Optuna + HyperBand: Bayesian optimization with HyperBand scheduling
- Optuna + Median: Bayesian optimization with median stopping rule
- Random + ASHA: Random search with ASHA early stopping
- Random + FIFO: Random search, all trials run to completion
- Random + HyperBand: Random search with HyperBand scheduling
- Random + Median: Random search with median stopping rule
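To make the early-stopping schedulers concrete, the sketch below implements simplified, synchronous versions of two of the rules named above in plain Python. The first is a successive-halving rung, the building block that ASHA runs asynchronously and HyperBand runs across several brackets. The second is the median stopping rule. This is not Ray Tune's implementation. The reduction factor of 3 and the trial names and scores are illustrative assumptions, and higher scores (e.g. validation AUC) are treated as better.

```python
import statistics


def successive_halving_rung(scores: dict[str, float], eta: int = 3) -> list[str]:
    """Keep the top 1/eta of the trials that reached this rung (at least one)."""
    n_keep = max(1, len(scores) // eta)
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:n_keep]


def median_stop(trial_history: list[float], other_histories: list[list[float]]) -> bool:
    """Stop if this trial's best score so far is below the median of the other
    trials' running-average scores over the same number of steps."""
    step = len(trial_history)
    running_means = [statistics.mean(h[:step]) for h in other_histories if len(h) >= step]
    if not running_means:
        return False  # nothing to compare against yet
    return max(trial_history) < statistics.median(running_means)


# Illustrative validation AUCs at one rung for six hypothetical trials
rung_scores = {"t1": 0.91, "t2": 0.97, "t3": 0.88, "t4": 0.95, "t5": 0.80, "t6": 0.93}
print(successive_halving_rung(rung_scores))  # the two best trials advance

# Illustrative per-epoch AUC curves of other hypothetical trials
others = [[0.85, 0.90, 0.93], [0.80, 0.88, 0.92], [0.70, 0.75, 0.78]]
print(median_stop([0.60, 0.65], others))
print(median_stop([0.86, 0.91], others))
```

FIFO corresponds to never calling either rule: every trial trains for its full budget.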
Population-Based Methods:
- PBT: Population Based Training
- PB2: Population Based Bandits
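The exploit/explore loop behind PBT can be sketched as a simplified, single-process illustration under assumed settings. The bottom 25% of the population copies a random member of the top 25%, then scales each numeric hyperparameter by 0.8 or 1.2. Ray Tune's `PopulationBasedTraining` scheduler has its own configuration and also handles checkpointing and trial scheduling. PB2 keeps the same exploit step but selects new hyperparameters with a Gaussian-process bandit model instead of random perturbation, which this sketch does not implement. The member fields (`score`, `weights`, `hparams`) are hypothetical.

```python
import copy
import random


def pbt_exploit_explore(population: list[dict], quantile: float = 0.25,
                        factors: tuple[float, float] = (0.8, 1.2), rng=random) -> list[dict]:
    """One PBT step: bottom-quantile members clone a top-quantile member
    (weights and hyperparameters), then randomly perturb the copied hyperparameters."""
    ranked = sorted(population, key=lambda m: m["score"], reverse=True)
    n = max(1, int(len(ranked) * quantile))
    top, bottom = ranked[:n], ranked[-n:]
    for member in bottom:
        donor = rng.choice(top)
        # Exploit: take over the donor's checkpoint
        member["weights"] = copy.deepcopy(donor["weights"])
        # Explore: perturb each of the donor's hyperparameters
        member["hparams"] = {k: v * rng.choice(factors) for k, v in donor["hparams"].items()}
    return ranked


# Hypothetical population of four trials; "weights" strings stand in for checkpoints
rng = random.Random(0)
population = [
    {"id": i, "score": s, "weights": f"ckpt_{i}", "hparams": {"lr": 1e-4 * (i + 1)}}
    for i, s in enumerate([0.95, 0.90, 0.85, 0.70])
]
pbt_exploit_explore(population, rng=rng)
for m in population:
    print(m["id"], m["weights"], m["hparams"])
```

In a real run, each member would continue training after this step, and its `score` would be re-measured before the next exploit/explore round.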
Datasets
Models are trained on two astronomical transient classification datasets:
- BTSBot: Real/bogus classification of transients from the Zwicky Transient Facility (ZTF)
- DES-SN: Supernova classification from the Dark Energy Survey Supernova Program
Citation
If you use these models, please cite:
- AutoTune Repository: https://github.com/parlange/autotune
- BTSBot Dataset: MultimodalUniverse/btsbot
- DES-SN Dataset: parlange/dark-energy-survey-supernova