# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

dependencies = ["torch", "timm", "einops"]

import os
from typing import Dict, Any, Optional, Tuple, Union, List
import warnings

import torch
from torch.hub import load_state_dict_from_url
from timm.models import clean_state_dict

from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool = False,
    **kwargs,
) -> Union[RADIOModel, Tuple[RADIOModel, Dict[str, Any]]]:
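    """Build a RADIO model from a named release or a local checkpoint file.

    Args:
        version: A key in RESOURCE_MAP, or a path to a local checkpoint file.
            Defaults to DEFAULT_VERSION when empty.
        progress: Display a progress bar while downloading remote weights.
        adaptor_names: Optional teacher adaptor name(s) to attach to the model.
        vitdet_window_size: If set, convert the backbone to ViTDet-style
            windowed attention with this window size.
        return_checkpoint: If True, also return the raw checkpoint dict.
        support_packing: Stored on the checkpoint args before model creation.
    """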
    if not version:
        version = DEFAULT_VERSION

    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )
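    # Prefer the EMA weights when the checkpoint provides them; in that case the
    # saved args are updated so spectral reparametrization is treated as already removed.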
if "state_dict_ema" in chk:
state_dict = chk["state_dict_ema"]
chk['args'].spectral_reparam = False
else:
state_dict = chk["state_dict"]
args = chk["args"]
args.support_packing = support_packing
mod = create_model_from_args(args)
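    # The backbone weights are stored under the "base_model." prefix in the checkpoint.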
    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")
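    # When trained with spectral reparametrization, re-create the parametrizations
    # so the checkpoint's parametrized weight keys match the module structure.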
    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)

    state_dict = clean_state_dict(state_dict)

    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')
    if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
        # the method is that instead of there being a `weight` tensor for certain Linear layers
        # in the model, we make it a dynamically computed function. During training, this
        # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
        # Disabling it in this context means that instead of recomputing `w' = f(w)` on every
        # forward pass, we compute `w'` once, during this function call, and replace the
        # parametrization with the realized weights. This makes the model run faster, and
        # also use less memory.
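        # A minimal sketch of the underlying mechanism (illustrative only; the real
        # work happens inside `disable_spectral_reparam`):
        #   import torch.nn.utils.parametrize as parametrize
        #   for m in mod.modules():
        #       if parametrize.is_parametrized(m, "weight"):
        #           parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)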
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False
    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))
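    # Cast the model and conditioner to the dtype recorded in the checkpoint args
    # (float32 for checkpoints that predate the dtype field).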
    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype
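    # When each teacher owns its own summary (CLS) token, select one summary index per
    # unique teacher name (first occurrence wins); otherwise all teachers share index 0.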
    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)
    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]

    teachers = chk["args"].teachers
    adaptors = dict()
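    # For each requested adaptor, find the matching teacher config and collect its
    # summary-head and feature-projection weights, which may be keyed in the
    # checkpoint either by teacher index or by teacher name.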
    for adaptor_name in adaptor_names:
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')

        ttype = tconf["type"]

        pf_idx_head = f'_heads.{tidx}'
        pf_name_head = f'_heads.{adaptor_name}'
        pf_idx_feat = f'_feature_projections.{tidx}'
        pf_name_feat = f'_feature_projections.{adaptor_name}'

        adaptor_state = dict()
        for k, v in state_dict.items():
            if k.startswith(pf_idx_head):
                adaptor_state['summary' + k[len(pf_idx_head):]] = v
            elif k.startswith(pf_name_head):
                adaptor_state['summary' + k[len(pf_name_head):]] = v
            elif k.startswith(pf_idx_feat):
                adaptor_state['feature' + k[len(pf_idx_feat):]] = v
            elif k.startswith(pf_name_feat):
                adaptor_state['feature' + k[len(pf_name_feat):]] = v

        adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
        adaptor.head_idx = tidx if cls_token_per_teacher else 0
        adaptors[adaptor_name] = adaptor
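    # The (intermediate) feature normalizers are optional and only present in some checkpoints.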
    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype,
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )
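    # Optionally rewrite the backbone for ViTDet-style windowed attention.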
    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk

    return radio


def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
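    """Return the sub-state-dict whose keys start with `prefix`, with the prefix stripped."""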
    mod_state_dict = {
        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    return mod_state_dict
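# Example usage (a sketch; assumes this file is exposed as a torch.hub entrypoint of the
# NVlabs/RADIO repository and that "radio_v2.5-l" names an entry in RESOURCE_MAP):
#   model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2.5-l', progress=True)
#   model.eval()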