File size: 612 Bytes
44d3940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

import torch
import torchvision

from torch import nn

def create_vit_b16(num_classes: int):
  """
  Creates a frozen ViT-B/16 feature extractor with a new classification head.

  Loads torchvision's ViT-B/16 with its default pretrained weights, freezes
  every backbone parameter, and replaces the classification head with a single
  trainable Linear layer sized for `num_classes`.

  Args:
    num_classes: Number of output classes for the new classification head.

  Returns:
    A tuple `(model, transforms)` where `model` is the modified ViT-B/16
    (only the new head is trainable) and `transforms` is the preprocessing
    pipeline matching the pretrained weights.
  """
  vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

  vit_b16_model = torchvision.models.vit_b_16(weights=vit_weights)

  # Transforms must come from the same weights object so preprocessing
  # (resize, normalization) matches what the backbone was trained with.
  vit_transforms = vit_weights.transforms()

  # Freeze the whole backbone; only the replacement head will train.
  for param in vit_b16_model.parameters():
    param.requires_grad = False

  # Read the embedding width from the existing head instead of hard-coding
  # 768, so this keeps working if the backbone variant ever changes.
  in_features = vit_b16_model.heads.head.in_features

  # Newly constructed layers default to requires_grad=True, so the head
  # stays trainable despite the freeze above.
  vit_b16_model.heads = nn.Sequential(
      nn.Linear(in_features=in_features, out_features=num_classes)
  )

  return vit_b16_model, vit_transforms