fbrynpk committed
Commit b418fb1 · 1 Parent(s): 40201f8

Initial Commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pretrained_vit_foodvision.pth filter=lfs diff=lfs merge=lfs -text
__pycache__/model.cpython-311.pyc ADDED
Binary file (1.35 kB).
 
app.py ADDED
@@ -0,0 +1,87 @@
+ ### Imports and class names setup ---------------------------------------------------- ###
+ import os
+ import torch
+ import torchvision
+ import gradio as gr
+
+ from model import create_vit
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+
+ # Setup class names
+ class_names = ["pizza", "steak", "sushi"]
+
+ # Device-agnostic code
+ if torch.backends.mps.is_available():
+     device = "mps"
+ elif torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ ### Model and transforms preparation ---------------------------------------------------- ###
+ vit_model, vit_transforms = create_vit(
+     pretrained_weights=torchvision.models.ViT_B_16_Weights.DEFAULT,
+     model=torchvision.models.vit_b_16,
+     in_features=768,
+     out_features=3,
+     device=device,
+ )
+
+ # Load the saved fine-tuned weights, mapping them onto the available device
+ vit_model.load_state_dict(
+     torch.load(f="pretrained_vit_foodvision.pth", map_location=torch.device(device))
+ )
+
+
+ ### Predict function ---------------------------------------------------- ###
+ def predict(img) -> Tuple[Dict, float]:
+     # Start a timer
+     start_time = timer()
+
+     # Transform the input image for the ViT model and add a batch dimension
+     # on the 0th index: (3, 224, 224) -> (1, 3, 224, 224)
+     img = vit_transforms(img).unsqueeze(0).to(device)
+
+     # Put the model into eval mode and make a prediction
+     vit_model.eval()
+     with torch.inference_mode():
+         # Pass the transformed image through the model and turn the prediction logits into probabilities
+         pred_logits = vit_model(img)
+         pred_probs = torch.softmax(pred_logits, dim=1)
+
+     # Create a {class name: prediction probability} dictionary
+     pred_labels_and_probs = {
+         class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
+     }
+
+     # Calculate the prediction time
+     end_time = timer()
+     pred_time = round(end_time - start_time, 4)
+
+     # Return the prediction dict and prediction time
+     return pred_labels_and_probs, pred_time
+
+
+ ### Gradio interface and launch ------------------------------------------------------------------ ###
+
+ # Create title and description
+ title = "FoodVision Mini: ViT Model"
+ description = "A ViT model trained on 20% of the Food101 dataset to classify images of pizza, steak, or sushi."
+
+ # Create example list
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=[
+         gr.Label(num_top_classes=3, label="Predictions"),
+         gr.Number(label="Prediction time (s)"),
+     ],
+     title=title,
+     description=description,
+     examples=example_list,
+ )
+
+ # Guarding launch() keeps the module importable; Spaces still runs this
+ # file as a script, so the demo launches there as before
+ if __name__ == "__main__":
+     demo.launch(
+         debug=False,  # set True to print errors locally while developing
+         share=True,  # generate a publicly shareable URL
+     )
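
With demo.launch() behind a main guard, app.py stays importable, so predict() can be smoke-tested directly on one of the bundled example images. A minimal sketch, assuming it runs from the repo root with the LFS checkpoint pulled:

    from PIL import Image
    from app import predict

    # Any of the bundled example images works here
    labels_and_probs, pred_time = predict(Image.open("examples/3177743.jpg"))
    print(labels_and_probs)  # {"pizza": ..., "steak": ..., "sushi": ...}
    print(f"prediction took {pred_time}s")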
examples/3177743.jpg ADDED
examples/61656.jpg ADDED
examples/730464.jpg ADDED
model.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ import torchvision
+
+ from torch import nn
+ from typing import Callable
+
+
+ def create_vit(
+     pretrained_weights: torchvision.models.Weights,
+     model: Callable[..., nn.Module],
+     in_features: int,
+     out_features: int,
+     device: torch.device,
+ ):
+     """Creates a pretrained Vision Transformer (ViT) from torchvision,
+     freezes its feature extractor, and swaps in a new classification head.
+
+     Returns the model and its matching inference transforms.
+     """
+     # Build the pretrained ViT using the constructor passed in via `model`
+     model = model(weights=pretrained_weights).to(device)
+     transforms = pretrained_weights.transforms()
+
+     # Freeze the feature extractor
+     for param in model.parameters():
+         param.requires_grad = False
+
+     # Replace the head of the ViT; the fresh Linear layer stays trainable
+     model.heads = nn.Sequential(
+         nn.Linear(in_features=in_features, out_features=out_features)
+     ).to(device)
+
+     return model, transforms
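
Because create_vit builds whatever constructor is passed via `model`, the same helper covers other ViT variants. A minimal sketch, not part of this commit, swapping in ViT-B/32 (which shares the 768-dim hidden size, so in_features is unchanged):

    import torchvision
    from model import create_vit

    vit32, vit32_transforms = create_vit(
        pretrained_weights=torchvision.models.ViT_B_32_Weights.DEFAULT,
        model=torchvision.models.vit_b_32,
        in_features=768,  # hidden dim shared by the ViT-B variants
        out_features=3,
        device="cpu",
    )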
pretrained_vit_foodvision.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c9067188086ff537cdb76de31c205acb865cca5ee0a25ed2d7ffe6b05376fea
+ size 343264485
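
The checkpoint lives in Git LFS, so a clone without `git lfs pull` leaves only the small pointer file above in place, and torch.load then fails on it. A minimal sketch, not part of this commit, that verifies the real file arrived by checking it against the pointer's sha256:

    import hashlib

    # oid copied from the LFS pointer above
    EXPECTED_SHA256 = "0c9067188086ff537cdb76de31c205acb865cca5ee0a25ed2d7ffe6b05376fea"

    sha = hashlib.sha256()
    with open("pretrained_vit_foodvision.pth", "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            sha.update(chunk)

    assert sha.hexdigest() == EXPECTED_SHA256, "checkpoint missing or not pulled via git lfs"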
requirements.txt ADDED
@@ -0,0 +1,4 @@
+
+ torch==2.0.1
+ torchvision==0.15.2
+ gradio==3.23.0
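
A quick sanity check that the installed environment matches these pins before launching the demo; a minimal sketch, not part of this commit:

    import gradio, torch, torchvision

    # Expected pins from requirements.txt
    expected = {"torch": "2.0.1", "torchvision": "0.15.2", "gradio": "3.23.0"}
    for mod in (torch, torchvision, gradio):
        version = mod.__version__.split("+")[0]  # drop local build tags like "+cu117"
        assert version == expected[mod.__name__], f"{mod.__name__} is {version}"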