Paras Shah
commited on
Commit
·
0466118
1
Parent(s):
0d17b56
Add cache optimization
Browse files
app.py
CHANGED
|
@@ -14,11 +14,17 @@ from SingleTreePointCloudLoader import SingleTreePointCloudLoader
|
|
| 14 |
gc.enable()
|
| 15 |
|
| 16 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
side_bg = "static/sidebar.png"
|
| 24 |
side_bg_ext = "png"
|
|
@@ -160,6 +166,7 @@ if uploaded_file:
|
|
| 160 |
proceed = st.button("Run model")
|
| 161 |
except Exception as e:
|
| 162 |
st.error(f"An error occured: {str(e)}")
|
|
|
|
| 163 |
|
| 164 |
if proceed:
|
| 165 |
try:
|
|
@@ -259,6 +266,7 @@ if proceed:
|
|
| 259 |
st.write(f"**Height of tree: {height:.2f}m**")
|
| 260 |
st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
|
| 261 |
st.write(f"**DBH: {dbh:.2f}m**")
|
|
|
|
| 262 |
|
| 263 |
except Exception as e:
|
| 264 |
st.error(f"An error occured: {str(e)}")
|
|
|
|
| 14 |
gc.enable()
|
| 15 |
|
| 16 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
+
|
| 18 |
+
@st.cache_resource
|
| 19 |
+
def load_model():
|
| 20 |
+
with st.spinner("Loading PointNet++ model..."):
|
| 21 |
+
checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device(device))
|
| 22 |
+
classifier = pn2.get_model(num_class=4, normal_channel=False)
|
| 23 |
+
classifier.load_state_dict(checkpoint['model_state_dict'])
|
| 24 |
+
classifier.eval()
|
| 25 |
+
return classifier
|
| 26 |
+
|
| 27 |
+
classifier = load_model()
|
| 28 |
|
| 29 |
side_bg = "static/sidebar.png"
|
| 30 |
side_bg_ext = "png"
|
|
|
|
| 166 |
proceed = st.button("Run model")
|
| 167 |
except Exception as e:
|
| 168 |
st.error(f"An error occured: {str(e)}")
|
| 169 |
+
gc.collect() # Optimize after file is loaded
|
| 170 |
|
| 171 |
if proceed:
|
| 172 |
try:
|
|
|
|
| 266 |
st.write(f"**Height of tree: {height:.2f}m**")
|
| 267 |
st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
|
| 268 |
st.write(f"**DBH: {dbh:.2f}m**")
|
| 269 |
+
gc.collect() # Optimize after inference is done
|
| 270 |
|
| 271 |
except Exception as e:
|
| 272 |
st.error(f"An error occured: {str(e)}")
|