Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,9 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
|
|
| 25 |
|
| 26 |
pred = torch.sigmoid(preds)
|
| 27 |
mat = pred.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 28 |
mask = Image.fromarray(np.uint8(mat * 255), "L")
|
| 29 |
mask = mask.convert("RGB")
|
| 30 |
mask = mask.resize(image.size)
|
|
@@ -37,7 +40,6 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
|
|
| 37 |
|
| 38 |
# threshold the mask
|
| 39 |
bmask = mask > threshold
|
| 40 |
-
# zero out values below the threshold
|
| 41 |
mask[mask < threshold] = 0
|
| 42 |
|
| 43 |
fig, ax = plt.subplots()
|
|
@@ -74,6 +76,7 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
|
|
| 74 |
|
| 75 |
return fig, result_mask, result_output
|
| 76 |
|
|
|
|
| 77 |
# Existing process_image function, copy it here
|
| 78 |
# ...
|
| 79 |
|
|
|
|
| 25 |
|
| 26 |
pred = torch.sigmoid(preds)
|
| 27 |
mat = pred.cpu().numpy()
|
| 28 |
+
|
| 29 |
+
# Ensure we are working with a single-channel 2D mask
|
| 30 |
+
mat = np.squeeze(mat, axis=0) # Remove batch dimension if it exists
|
| 31 |
mask = Image.fromarray(np.uint8(mat * 255), "L")
|
| 32 |
mask = mask.convert("RGB")
|
| 33 |
mask = mask.resize(image.size)
|
|
|
|
| 40 |
|
| 41 |
# threshold the mask
|
| 42 |
bmask = mask > threshold
|
|
|
|
| 43 |
mask[mask < threshold] = 0
|
| 44 |
|
| 45 |
fig, ax = plt.subplots()
|
|
|
|
| 76 |
|
| 77 |
return fig, result_mask, result_output
|
| 78 |
|
| 79 |
+
|
| 80 |
# Existing process_image function, copy it here
|
| 81 |
# ...
|
| 82 |
|