obichimav commited on
Commit
aa30a8e
·
verified ·
1 Parent(s): e7f857c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -124,7 +124,6 @@
124
  # if __name__ == "__main__":
125
  # demo.launch(share=True, height=900)
126
 
127
-
128
  import gradio as gr
129
  import rasterio
130
  import os
@@ -135,10 +134,10 @@ from SemanticModel.model_core import SegmentationModel
135
  from SemanticModel.prediction import PredictionPipeline
136
 
137
  from segmentation_models_pytorch.decoders.unet.model import Unet
138
-
139
- # Add this before loading your model
140
  from segmentation_models_pytorch.encoders.timm_regnet import RegNetEncoder
141
- torch.serialization.add_safe_globals([Unet, RegNetEncoder])
 
 
142
 
143
  COLORS = [
144
  (0, 0, 0), # Background (black)
@@ -180,12 +179,14 @@ def predict_and_show(input_file):
180
  img = src.read()
181
  img = np.moveaxis(img, 0, 2)[:,:,:3]
182
  img_normalized = np.clip(img/img.max(), 0, 1)
183
- model = SegmentationModel(
184
- classes=['bg', 'cacao', 'matarraton', 'abarco'],
185
- architecture='unet',
186
- encoder='timm-regnety_120',
187
- weights='models/CacaoShadeClassification.pth'
188
- )
 
 
189
  predictor = PredictionPipeline(model)
190
  pred, _ = predictor.predict_raster(input_file.name, tile_size=1024)
191
  colored_pred = style_prediction(pred)
@@ -273,3 +274,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
273
  if __name__ == "__main__":
274
  demo.launch(share=True, height=900)
275
 
 
 
124
  # if __name__ == "__main__":
125
  # demo.launch(share=True, height=900)
126
 
 
127
  import gradio as gr
128
  import rasterio
129
  import os
 
134
  from SemanticModel.prediction import PredictionPipeline
135
 
136
  from segmentation_models_pytorch.decoders.unet.model import Unet
 
 
137
  from segmentation_models_pytorch.encoders.timm_regnet import RegNetEncoder
138
+
139
+ # Previously we added safe globals globally, now we'll use the context manager approach.
140
+ # torch.serialization.add_safe_globals([Unet, RegNetEncoder])
141
 
142
  COLORS = [
143
  (0, 0, 0), # Background (black)
 
179
  img = src.read()
180
  img = np.moveaxis(img, 0, 2)[:,:,:3]
181
  img_normalized = np.clip(img/img.max(), 0, 1)
182
+ # Use the safe_globals context manager when loading the model
183
+ with torch.serialization.safe_globals([RegNetEncoder]):
184
+ model = SegmentationModel(
185
+ classes=['bg', 'cacao', 'matarraton', 'abarco'],
186
+ architecture='unet',
187
+ encoder='timm-regnety_120',
188
+ weights='models/CacaoShadeClassification.pth'
189
+ )
190
  predictor = PredictionPipeline(model)
191
  pred, _ = predictor.predict_raster(input_file.name, tile_size=1024)
192
  colored_pred = style_prediction(pred)
 
274
  if __name__ == "__main__":
275
  demo.launch(share=True, height=900)
276
 
277
+