Shriharsh commited on
Commit
4a739e5
ยท
verified ยท
1 Parent(s): 6a44a88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -46
app.py CHANGED
@@ -2,70 +2,84 @@
2
  import gradio as gr
3
  import os
4
  import torch
5
-
6
  from model import create_effnetb2_model
7
  from timeit import default_timer as timer
8
  from typing import Tuple, Dict
9
 
10
  # Setup class names
11
- with open("class_names.txt", "r") as f: # reading them in from class_names.txt
12
- class_names = [food_name.strip() for food_name in f.readlines()]
 
 
 
13
 
14
  ### 2. Model and transforms preparation ###
15
 
16
  # Create model
17
- effnetb2, effnetb2_transforms = create_effnetb2_model(
18
- num_classes=101, # could also use len(class_names)
19
- )
 
 
 
20
 
21
  # Load saved weights
22
- effnetb2.load_state_dict(
23
- torch.load(
24
- f="09_pretrained_effnetb2_feature_extractor_food101.pth",
25
- map_location=torch.device("cpu"), # load to CPU
 
 
26
  )
27
- )
 
 
 
28
 
29
  ### 3. Predict function ###
30
 
31
- # Create predict function
32
  def predict(img) -> Tuple[Dict, float]:
33
- """Transforms and performs a prediction on img and returns prediction and time taken.
34
- """
35
- # Start the timer
36
- start_time = timer()
37
-
38
- # Transform the target image and add a batch dimension
39
- img = effnetb2_transforms(img).unsqueeze(0)
40
-
41
- # Put model into evaluation mode and turn on inference mode
42
- effnetb2.eval()
43
- with torch.inference_mode():
44
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
45
- pred_probs = torch.softmax(effnetb2(img), dim=1)
46
-
47
- # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
48
- pred_labels_and_probs = {
49
- class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
50
- }
51
-
52
- # Calculate the prediction time
53
- pred_time = round(timer() - start_time, 5)
54
-
55
- # Return the prediction dictionary and prediction time
56
- return pred_labels_and_probs, pred_time
57
-
 
 
 
58
 
59
  ### 4. Gradio app ###
60
 
61
- # Create title, description and article strings
62
  title = "FoodVision 101 ๐Ÿ”๐Ÿ‘"
63
- description = "An EfficientNetB2 feature extractor computer vision model to classify images of food into [101 different classes]"
64
- #(https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/food101_class_names.txt)."
65
- #article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
66
 
67
  # Create examples list from "examples/" directory
68
- example_list = [["examples/" + example] for example in os.listdir("examples")]
 
 
 
 
69
 
70
  # Create Gradio interface
71
  demo = gr.Interface(
@@ -78,9 +92,7 @@ demo = gr.Interface(
78
  examples=example_list,
79
  title=title,
80
  description=description,
81
- #article=article,
82
  )
83
 
84
- # Launch the app!
85
- demo.launch()
86
-
 
2
  import gradio as gr
3
  import os
4
  import torch
 
5
  from model import create_effnetb2_model
6
  from timeit import default_timer as timer
7
  from typing import Tuple, Dict
8
 
9
  # Setup class names
10
+ try:
11
+ with open("class_names.txt", "r") as f: # reading them in from class_names.txt
12
+ class_names = [food_name.strip() for food_name in f.readlines()]
13
+ except FileNotFoundError:
14
+ raise FileNotFoundError("class_names.txt not found. Ensure it exists in the root directory.")
15
 
16
  ### 2. Model and transforms preparation ###
17
 
18
  # Create model
19
+ try:
20
+ effnetb2, effnetb2_transforms = create_effnetb2_model(
21
+ num_classes=101, # could also use len(class_names)
22
+ )
23
+ except Exception as e:
24
+ raise Exception(f"Error creating model: {str(e)}")
25
 
26
  # Load saved weights
27
+ try:
28
+ effnetb2.load_state_dict(
29
+ torch.load(
30
+ f="09_pretrained_effnetb2_feature_extractor_food101.pth",
31
+ map_location=torch.device("cpu"), # load to CPU
32
+ )
33
  )
34
+ except FileNotFoundError:
35
+ raise FileNotFoundError("Model weights file '09_pretrained_effnetb2_feature_extractor_food101.pth' not found.")
36
+ except Exception as e:
37
+ raise Exception(f"Error loading model weights: {str(e)}")
38
 
39
  ### 3. Predict function ###
40
 
 
41
  def predict(img) -> Tuple[Dict, float]:
42
+ """Transforms and performs a prediction on img and returns prediction and time taken."""
43
+ try:
44
+ # Start the timer
45
+ start_time = timer()
46
+
47
+ # Transform the target image and add a batch dimension
48
+ if img is None:
49
+ raise ValueError("Input image is None. Please provide a valid image.")
50
+ img = effnetb2_transforms(img).unsqueeze(0)
51
+
52
+ # Put model into evaluation mode and turn on inference mode
53
+ effnetb2.eval()
54
+ with torch.inference_mode():
55
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
56
+ pred_probs = torch.softmax(effnetb2(img), dim=1)
57
+
58
+ # Create a prediction label and prediction probability dictionary for each prediction class
59
+ pred_labels_and_probs = {
60
+ class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
61
+ }
62
+
63
+ # Calculate the prediction time
64
+ pred_time = round(timer() - start_time, 5)
65
+
66
+ # Return the prediction dictionary and prediction time
67
+ return pred_labels_and_probs, pred_time
68
+ except Exception as e:
69
+ return {"error": f"Prediction failed: {str(e)}"}, 0.0
70
 
71
  ### 4. Gradio app ###
72
 
73
+ # Create title, description
74
  title = "FoodVision 101 ๐Ÿ”๐Ÿ‘"
75
+ description = "An EfficientNetB2 feature extractor computer vision model to classify images of food into 101 different classes."
 
 
76
 
77
  # Create examples list from "examples/" directory
78
+ try:
79
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
80
+ except FileNotFoundError:
81
+ example_list = []
82
+ print("Warning: 'examples/' directory not found. No example images will be loaded.")
83
 
84
  # Create Gradio interface
85
  demo = gr.Interface(
 
92
  examples=example_list,
93
  title=title,
94
  description=description,
 
95
  )
96
 
97
+ # Launch the app with share=True for Hugging Face Spaces
98
+ demo.launch(share=True)