Changed Code for CUDA
Browse files- Model_Class.py +8 -5
- Model_Seg.py +3 -5
- app.py +6 -3
Model_Class.py
CHANGED
|
@@ -59,14 +59,15 @@ val_transforms_416x628 = Compose(
|
|
| 59 |
]
|
| 60 |
)
|
| 61 |
|
| 62 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 63 |
checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
|
| 64 |
-
model = ResNet()
|
| 65 |
model.load_state_dict(checkpoint["state_dict"])
|
| 66 |
model.eval()
|
| 67 |
|
| 68 |
|
| 69 |
-
def load_and_classify_image(image_path):
|
|
|
|
|
|
|
| 70 |
image = val_transforms_416x628(image_path)
|
| 71 |
image = image.unsqueeze(0).to(device)
|
| 72 |
|
|
@@ -76,8 +77,10 @@ def load_and_classify_image(image_path):
|
|
| 76 |
return prediction.to('cpu'), image.to('cpu')
|
| 77 |
|
| 78 |
|
| 79 |
-
def make_GradCAM(image):
|
| 80 |
|
|
|
|
|
|
|
| 81 |
model.eval()
|
| 82 |
target_layers = [model.model.layer4[-1]]
|
| 83 |
|
|
@@ -90,7 +93,7 @@ def make_GradCAM(image):
|
|
| 90 |
aug_smooth=False,
|
| 91 |
eigen_smooth=True,
|
| 92 |
)
|
| 93 |
-
grayscale_cam = grayscale_cam.squeeze()
|
| 94 |
|
| 95 |
jet = plt.colormaps.get_cmap("inferno")
|
| 96 |
newcolors = jet(np.linspace(0, 1, 256))
|
|
|
|
| 59 |
]
|
| 60 |
)
|
| 61 |
|
|
|
|
| 62 |
checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
|
| 63 |
+
model = ResNet()
|
| 64 |
model.load_state_dict(checkpoint["state_dict"])
|
| 65 |
model.eval()
|
| 66 |
|
| 67 |
|
| 68 |
+
def load_and_classify_image(image_path, device):
|
| 69 |
+
|
| 70 |
+
model = model.to(device)
|
| 71 |
image = val_transforms_416x628(image_path)
|
| 72 |
image = image.unsqueeze(0).to(device)
|
| 73 |
|
|
|
|
| 77 |
return prediction.to('cpu'), image.to('cpu')
|
| 78 |
|
| 79 |
|
| 80 |
+
def make_GradCAM(image, device):
|
| 81 |
|
| 82 |
+
model = model.to(device)
|
| 83 |
+
image = image.to(device)
|
| 84 |
model.eval()
|
| 85 |
target_layers = [model.model.layer4[-1]]
|
| 86 |
|
|
|
|
| 93 |
aug_smooth=False,
|
| 94 |
eigen_smooth=True,
|
| 95 |
)
|
| 96 |
+
grayscale_cam = grayscale_cam.to('cpu').squeeze()
|
| 97 |
|
| 98 |
jet = plt.colormaps.get_cmap("inferno")
|
| 99 |
newcolors = jet(np.linspace(0, 1, 256))
|
Model_Seg.py
CHANGED
|
@@ -39,10 +39,8 @@ model = UNet(
|
|
| 39 |
num_res_units=3
|
| 40 |
)
|
| 41 |
|
| 42 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
-
|
| 44 |
checkpoint_path = 'segmentation_model.pt'
|
| 45 |
-
checkpoint = torch.load(checkpoint_path, map_location=
|
| 46 |
assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
|
| 47 |
|
| 48 |
model.load_state_dict(checkpoint['network'])
|
|
@@ -73,9 +71,9 @@ post_transforms = Compose([
|
|
| 73 |
|
| 74 |
|
| 75 |
|
| 76 |
-
def load_and_segment_image(input_image_path):
|
| 77 |
|
| 78 |
-
|
| 79 |
image_tensor = pre_transforms(input_image_path)
|
| 80 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
| 81 |
|
|
|
|
| 39 |
num_res_units=3
|
| 40 |
)
|
| 41 |
|
|
|
|
|
|
|
| 42 |
checkpoint_path = 'segmentation_model.pt'
|
| 43 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 44 |
assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
|
| 45 |
|
| 46 |
model.load_state_dict(checkpoint['network'])
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
|
| 74 |
+
def load_and_segment_image(input_image_path, device):
|
| 75 |
|
| 76 |
+
model = model.to(device)
|
| 77 |
image_tensor = pre_transforms(input_image_path)
|
| 78 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
| 79 |
|
app.py
CHANGED
|
@@ -7,6 +7,9 @@ import SimpleITK as sitk
|
|
| 7 |
import torch
|
| 8 |
from numpy import uint8
|
| 9 |
import spaces
|
|
|
|
|
|
|
|
|
|
| 10 |
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
|
| 11 |
article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
|
| 12 |
|
|
@@ -64,7 +67,7 @@ def predict_image(input_image, input_file):
|
|
| 64 |
else:
|
| 65 |
return None , None , "Please input an image before pressing run" , None , None
|
| 66 |
|
| 67 |
-
image_mask = Model_Seg.load_and_segment_image(image_path)
|
| 68 |
|
| 69 |
overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
|
| 70 |
|
|
@@ -75,10 +78,10 @@ def predict_image(input_image, input_file):
|
|
| 75 |
cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
|
| 76 |
cropped_boxed_array_disp = cropped_boxed_array.squeeze()
|
| 77 |
cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
|
| 78 |
-
prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor)
|
| 79 |
|
| 80 |
|
| 81 |
-
gradcam = Model_Class.make_GradCAM(image_transformed)
|
| 82 |
|
| 83 |
nr_axSpA_prob = float(prediction[0].item())
|
| 84 |
r_axSpA_prob = float(prediction[1].item())
|
|
|
|
| 7 |
import torch
|
| 8 |
from numpy import uint8
|
| 9 |
import spaces
|
| 10 |
+
|
| 11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
+
|
| 13 |
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
|
| 14 |
article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
|
| 15 |
|
|
|
|
| 67 |
else:
|
| 68 |
return None , None , "Please input an image before pressing run" , None , None
|
| 69 |
|
| 70 |
+
image_mask = Model_Seg.load_and_segment_image(image_path, device)
|
| 71 |
|
| 72 |
overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
|
| 73 |
|
|
|
|
| 78 |
cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
|
| 79 |
cropped_boxed_array_disp = cropped_boxed_array.squeeze()
|
| 80 |
cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
|
| 81 |
+
prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)
|
| 82 |
|
| 83 |
|
| 84 |
+
gradcam = Model_Class.make_GradCAM(image_transformed, device)
|
| 85 |
|
| 86 |
nr_axSpA_prob = float(prediction[0].item())
|
| 87 |
r_axSpA_prob = float(prediction[1].item())
|