```python
import torch
import numpy as np
from monai.transforms import (
    Compose, LoadImage, EnsureChannelFirst, Lambda, Resize, NormalizeIntensity,
    GaussianSmooth, ScaleIntensity, AsDiscrete, KeepLargestConnectedComponent,
    Invert, Rotate90, SaveImage, Transform,
)
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet


class RgbaToGrayscale(Transform):
    def __call__(self, x):
        # Squeeze a trailing singleton dimension to ensure (C, H, W) format
        x = x.squeeze(-1)
        # Ensure the tensor is 3D (channels, height, width)
        if x.ndim != 3:
            raise ValueError(f"Input tensor must be 3D. Shape: {x.shape}")
        # Check the number of channels
        if x.shape[0] == 4:  # Assuming RGBA
            rgb_weights = torch.tensor([0.2989, 0.5870, 0.1140], device=x.device)
            # Apply weights to the RGB channels; keep a channel dimension on the output
            grayscale = torch.einsum('cwh,c->wh', x[:3, :, :], rgb_weights).unsqueeze(0)
        elif x.shape[0] == 3:  # Assuming RGB
            rgb_weights = torch.tensor([0.2989, 0.5870, 0.1140], device=x.device)
            grayscale = torch.einsum('cwh,c->wh', x, rgb_weights).unsqueeze(0)
        elif x.shape[0] == 1:  # Already grayscale
            grayscale = x
        else:
            raise ValueError(f"Unsupported channel number: {x.shape[0]}")
        return grayscale

    def inverse(self, x):
        # Grayscale conversion is not invertible; return the input unchanged
        return x


model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
    num_res_units=3,
)

checkpoint_path = 'segmentation_model.pt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
model.load_state_dict(checkpoint['network'])
model.eval()

# Define transforms for preprocessing
pre_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    RgbaToGrayscale(),  # Convert RGBA/RGB to single-channel grayscale
    Resize(spatial_size=(768, 768)),
    Lambda(func=lambda x: x.squeeze(-1)),  # Drop a trailing singleton dimension if the input image has one
    NormalizeIntensity(),
    GaussianSmooth(sigma=0.1),
    ScaleIntensity(minv=-1, maxv=1),
])

# Define transforms for postprocessing
post_transforms = Compose([
    AsDiscrete(argmax=True, to_onehot=4),
    KeepLargestConnectedComponent(),
    AsDiscrete(argmax=True),
    Invert(transform=pre_transforms),
    # SaveImage(output_dir='./', output_postfix='seg', output_ext='.nii', resample=False)
])


def load_and_segment_image(input_image_path, device):
    image_tensor = pre_transforms(input_image_path)
    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add a batch dimension
    # Inference using SlidingWindowInferer
    inferer = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=16, overlap=0.75)
    with torch.no_grad():
        outputs = inferer(image_tensor, model.to(device))
    outputs = outputs.squeeze(0)  # Remove the batch dimension
    processed_outputs = post_transforms(outputs)
    # Rotate the label map back to the original orientation
    rotate = Rotate90(spatial_axes=(0, 1), k=3)
    processed_outputs = rotate(processed_outputs).to('cpu')
    output_array = processed_outputs.squeeze().detach().numpy().astype(np.uint8)
    return output_array
```