alignedthreeattn / utils /
huzey's picture
""" OpenAI pretrained model functions
Adapted from the OpenAI CLIP repository (https://github.com/openai/CLIP). Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from utils.model import build_model_from_openai_state_dict, get_cast_dtype
from utils.pretrained import *
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
    """Return the names of the pretrained models registered under the 'openai' tag."""
    openai_model_names = list_pretrained_models_by_tag('openai')
    return openai_model_names
def load_openai_model(
    name: str,
    precision: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    cache_dir: Optional[str] = None,
) -> torch.nn.Module:
    """Load an OpenAI CLIP checkpoint and rebuild it as a non-JIT model.

    Parameters
    ----------
    name : str
        A model name listed by `list_openai_models()`, or the path to a model
        checkpoint containing the state_dict
    precision : Optional[str]
        Model precision, if None defaults to 'fp32' when the device is a CPU
        device, else 'fp16'.
    device : Optional[Union[str, torch.device]]
        The device to put the loaded model on. If None, "cuda" is used when
        available, otherwise "cpu".
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model, with `visual.image_mean` / `visual.image_std` set to
        the OpenAI dataset statistics.

    Raises
    ------
    RuntimeError
        If `name` is neither a known OpenAI model name nor an existing file.
    KeyError
        If the checkpoint's keys cannot be used to build the model and it does
        not contain a nested "state_dict" entry to fall back on.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        # Compare on the device *type*: `torch.device('cpu') == 'cpu'` is False,
        # so a plain string comparison would pick fp16 for a CPU device object.
        precision = 'fp32' if torch.device(device).type == 'cpu' else 'fp16'

    # Resolve `name` to a local checkpoint path (download it if it is a known model).
    pretrained_url = get_pretrained_url(name, 'openai')
    if pretrained_url:
        model_path = download_pretrained_from_url(pretrained_url, cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")

    # Build a non-jit model from the OpenAI jitted model state dict
    cast_dtype = get_cast_dtype(precision)
    source_state_dict = state_dict if state_dict is not None else model.state_dict()
    try:
        model = build_model_from_openai_state_dict(source_state_dict, cast_dtype=cast_dtype)
    except KeyError:
        # Fallback for checkpoints that nest the weights under "state_dict".
        # Without such an entry there is nothing to retry with, so surface the
        # original KeyError instead of masking it with a secondary error.
        if state_dict is None or "state_dict" not in state_dict:
            raise
        # Drops the first 7 characters of every key — presumably a "module."
        # wrapper prefix; confirm against the checkpoints this is used with.
        sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
        model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

    # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
    # NOTE(review): the `.to(device)` and `.float()` calls below were truncated in
    # the reviewed copy of this file and are restored from the surrounding
    # comments and the documented `device` parameter — confirm against history.
    model = model.to(device)
    # FIXME support pure fp16/bf16 precision modes
    if precision != 'fp16':
        model.float()
        if precision == 'bf16':
            # for bf16, convert back to low-precision
            # NOTE(review): `convert_weights_to_lp` is not among the explicit
            # imports of this file; confirm the star import brings it into scope.
            convert_weights_to_lp(model, dtype=torch.bfloat16)

    # add mean / std attributes for consistency with OpenCLIP models
    model.visual.image_mean = OPENAI_DATASET_MEAN
    model.visual.image_std = OPENAI_DATASET_STD
    return model