import torch
import torchvision
from torch import nn


def create_vit(
    pretrained_weights: torchvision.models.ViT_B_16_Weights,
    model: object,
    in_features: int,
    out_features: int,
    device: torch.device,
):
    """Create a frozen ViT-B/16 backbone with a new trainable classifier head.

    Builds ``torchvision.models.vit_b_16`` with ``pretrained_weights``, freezes
    every pretrained parameter, and replaces ``heads`` with a single
    ``nn.Linear(in_features, out_features)`` layer that remains trainable.

    Args:
        pretrained_weights: Weights enum member passed to ``vit_b_16``; its
            ``transforms()`` method supplies the matching preprocessing.
        model: Unused. The architecture is always ``vit_b_16``; this
            parameter is kept only so existing call sites keep working.
        in_features: Input size of the new head. Must equal the backbone's
            ``hidden_dim`` (768 for ViT-B/16).
        out_features: Number of outputs of the new head.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT on ``device`` and the
        object returned by ``pretrained_weights.transforms()``.

    Raises:
        ValueError: If ``in_features`` does not match the backbone's hidden size.
    """
    # The architecture is fixed to vit_b_16; the argument is intentionally ignored.
    del model

    vit = torchvision.models.vit_b_16(weights=pretrained_weights)
    transforms = pretrained_weights.transforms()

    # Fail fast: a mismatched head would otherwise only error on the first forward pass.
    if in_features != vit.hidden_dim:
        raise ValueError(
            f"in_features={in_features} does not match the ViT hidden size "
            f"{vit.hidden_dim}"
        )

    # Freeze the pretrained weights so only the replacement head is trained.
    for param in vit.parameters():
        param.requires_grad = False

    # The head is created after freezing, so its parameters stay trainable.
    # nn.Sequential is kept so state_dict keys remain "heads.0.*".
    vit.heads = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=out_features)
    )

    # Move the whole model (frozen backbone + new head) in a single transfer.
    return vit.to(device), transforms