Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -124,7 +124,6 @@
|
|
124 |
# if __name__ == "__main__":
|
125 |
# demo.launch(share=True, height=900)
|
126 |
|
127 |
-
|
128 |
import gradio as gr
|
129 |
import rasterio
|
130 |
import os
|
@@ -135,10 +134,10 @@ from SemanticModel.model_core import SegmentationModel
|
|
135 |
from SemanticModel.prediction import PredictionPipeline
|
136 |
|
137 |
from segmentation_models_pytorch.decoders.unet.model import Unet
|
138 |
-
|
139 |
-
# Add this before loading your model
|
140 |
from segmentation_models_pytorch.encoders.timm_regnet import RegNetEncoder
|
141 |
-
|
|
|
|
|
142 |
|
143 |
COLORS = [
|
144 |
(0, 0, 0), # Background (black)
|
@@ -180,12 +179,14 @@ def predict_and_show(input_file):
|
|
180 |
img = src.read()
|
181 |
img = np.moveaxis(img, 0, 2)[:,:,:3]
|
182 |
img_normalized = np.clip(img/img.max(), 0, 1)
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
189 |
predictor = PredictionPipeline(model)
|
190 |
pred, _ = predictor.predict_raster(input_file.name, tile_size=1024)
|
191 |
colored_pred = style_prediction(pred)
|
@@ -273,3 +274,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
273 |
if __name__ == "__main__":
|
274 |
demo.launch(share=True, height=900)
|
275 |
|
|
|
|
124 |
# if __name__ == "__main__":
|
125 |
# demo.launch(share=True, height=900)
|
126 |
|
|
|
127 |
import gradio as gr
|
128 |
import rasterio
|
129 |
import os
|
|
|
134 |
from SemanticModel.prediction import PredictionPipeline
|
135 |
|
136 |
from segmentation_models_pytorch.decoders.unet.model import Unet
|
|
|
|
|
137 |
from segmentation_models_pytorch.encoders.timm_regnet import RegNetEncoder
|
138 |
+
|
139 |
+
# Previously we added safe globals globally, now we'll use the context manager approach.
|
140 |
+
# torch.serialization.add_safe_globals([Unet, RegNetEncoder])
|
141 |
|
142 |
COLORS = [
|
143 |
(0, 0, 0), # Background (black)
|
|
|
179 |
img = src.read()
|
180 |
img = np.moveaxis(img, 0, 2)[:,:,:3]
|
181 |
img_normalized = np.clip(img/img.max(), 0, 1)
|
182 |
+
# Use the safe_globals context manager when loading the model
|
183 |
+
with torch.serialization.safe_globals([RegNetEncoder]):
|
184 |
+
model = SegmentationModel(
|
185 |
+
classes=['bg', 'cacao', 'matarraton', 'abarco'],
|
186 |
+
architecture='unet',
|
187 |
+
encoder='timm-regnety_120',
|
188 |
+
weights='models/CacaoShadeClassification.pth'
|
189 |
+
)
|
190 |
predictor = PredictionPipeline(model)
|
191 |
pred, _ = predictor.predict_raster(input_file.name, tile_size=1024)
|
192 |
colored_pred = style_prediction(pred)
|
|
|
274 |
if __name__ == "__main__":
|
275 |
demo.launch(share=True, height=900)
|
276 |
|
277 |
+
|