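# torch.hub entrypoint module (the hubconf.py convention: a module-level
# `dependencies` list plus callable entrypoints). A minimal usage sketch,
# assuming this file sits at the root of the published RADIO repository
# (the repo and version strings below are illustrative):
#
#   model = torch.hub.load("NVlabs/RADIO", "radio_model", version="radio_v2.5-l")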
dependencies = ["torch", "timm", "einops"] |
|
|
|
import os |
|
from typing import Dict, Any, Optional, Union, List |
|
import warnings |
|
|
|
import torch |
|
from torch.hub import load_state_dict_from_url |
|
|
|
from timm.models import clean_state_dict |
|
|
|
from .radio.adaptor_registry import adaptor_registry |
|
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP |
|
from .radio.enable_damp import configure_damp_from_args |
|
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args |
|
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer |
|
from .radio.radio_model import RADIOModel, create_model_from_args |
|
from .radio.input_conditioner import get_default_conditioner |
|
from .radio.vitdet import apply_vitdet_arch, VitDetArgs |
|
|
|
|
|


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool = False,
    **kwargs,
) -> Union[RADIOModel, Tuple[RADIOModel, Dict[str, Any]]]:
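    """Build a RADIO model from a pretrained checkpoint.

    Args:
        version: Name of a pretrained release (a key of RESOURCE_MAP), or a path
            to a local checkpoint file. Defaults to DEFAULT_VERSION when empty.
        progress: Display a progress bar while downloading the checkpoint.
        adaptor_names: Name(s) of teacher adaptors to attach to the model.
        vitdet_window_size: If set, convert the backbone to ViTDet-style
            windowed attention with this window size.
        return_checkpoint: If True, also return the raw checkpoint dictionary.
        support_packing: Enable sequence-packing support in the backbone.

    Additional keyword arguments are accepted for hub compatibility and ignored.

    Returns:
        The assembled RADIOModel, or ``(model, checkpoint)`` when
        ``return_checkpoint`` is True.
    """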
    if not version:
        version = DEFAULT_VERSION

    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )
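
    # Prefer the EMA weights when the checkpoint provides them. This branch also
    # clears the spectral-reparam flag, so the parametrization setup below is
    # skipped for EMA weights.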
    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]
args = chk["args"] |
|
args.support_packing = support_packing |
|
mod = create_model_from_args(args) |
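
    # Backbone weights live under the "base_model." prefix. Spectral
    # reparametrization, if it was used during training, must be configured on
    # the module before the state dict is loaded so the parametrized parameter
    # names line up with the checkpoint.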
    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")

    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)

    # Strip training-wrapper prefixes (e.g. "module.") for the remaining consumers.
    state_dict = clean_state_dict(state_dict)

    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')

    if chk['args'].spectral_reparam:
        # Spectral reparametrization replaces certain weight tensors with a
        # parametrized form during training. For inference, fold the
        # parametrization back into plain weight tensors.
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False
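
    # The input conditioner holds the normalization applied to input images; its
    # parameters are restored from the checkpoint alongside the backbone.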
    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))

    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype
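
    # When each teacher was trained with its own CLS token, pick one summary
    # index per distinct teacher that contributes a summary; otherwise every
    # teacher shares summary index 0.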
    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)
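
    # Normalize adaptor_names to a list, then build one adaptor per requested
    # teacher head found in the checkpoint.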
    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]
teachers = chk["args"].teachers |
|
adaptors = dict() |
|
for adaptor_name in adaptor_names: |
|
for tidx, tconf in enumerate(teachers): |
|
if tconf["name"] == adaptor_name: |
|
break |
|
else: |
|
raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}') |
|
|
|
ttype = tconf["type"] |
|
|
|
pf_idx_head = f'_heads.{tidx}' |
|
pf_name_head = f'_heads.{adaptor_name}' |
|
pf_idx_feat = f'_feature_projections.{tidx}' |
|
pf_name_feat = f'_feature_projections.{adaptor_name}' |
|
|
|
adaptor_state = dict() |
|
for k, v in state_dict.items(): |
|
if k.startswith(pf_idx_head): |
|
adaptor_state['summary' + k[len(pf_idx_head):]] = v |
|
elif k.startswith(pf_name_head): |
|
adaptor_state['summary' + k[len(pf_name_head):]] = v |
|
elif k.startswith(pf_idx_feat): |
|
adaptor_state['feature' + k[len(pf_idx_feat):]] = v |
|
elif k.startswith(pf_name_feat): |
|
adaptor_state['feature' + k[len(pf_name_feat):]] = v |
|
|
|
adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state) |
|
adaptor.head_idx = tidx if cls_token_per_teacher else 0 |
|
adaptors[adaptor_name] = adaptor |
|
|
|
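
    # The (intermediate) feature normalizers are optional; build them only when
    # the checkpoint actually carries their state.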
    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype,
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )
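
    # Optionally rewrite the backbone to use ViTDet-style windowed attention.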
    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk
    return radio


def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
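    """Return the entries of `state_dict` whose keys start with `prefix`, with
    the prefix stripped from each key."""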
    mod_state_dict = {
        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    return mod_state_dict