"""Image processor for Sybil CT scan preprocessing""" import cv2 import numpy as np import torch from typing import Dict, List, Optional, Union, Tuple from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.utils import TensorType import pydicom from PIL import Image import torchio as tio def order_slices(dicoms: List) -> List: """Order DICOM slices by their position""" # Sort by ImagePositionPatient if available try: dicoms = sorted(dicoms, key=lambda x: float(x.ImagePositionPatient[2])) except (AttributeError, TypeError): # Fall back to InstanceNumber if ImagePositionPatient not available try: dicoms = sorted(dicoms, key=lambda x: int(x.InstanceNumber)) except (AttributeError, TypeError): pass # Keep original order if neither attribute is available return dicoms class SybilImageProcessor(BaseImageProcessor): """ Constructs a Sybil image processor for preprocessing CT scans. Args: voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`): Target voxel spacing for resampling (row, column, slice thickness). img_size (`List[int]`, *optional*, defaults to `[512, 512]`): Target image size after resizing. num_images (`int`, *optional*, defaults to `208`): Number of slices to use from the CT scan. windowing (`Dict[str, float]`, *optional*): Windowing parameters for CT scan visualization. Default uses lung window: center=-600, width=1500. normalize (`bool`, *optional*, defaults to `True`): Whether to normalize pixel values to [0, 1]. **kwargs: Additional keyword arguments passed to the parent class. """ model_input_names = ["pixel_values"] def __init__( self, voxel_spacing: List[float] = None, img_size: List[int] = None, num_images: int = 208, windowing: Dict[str, float] = None, normalize: bool = True, **kwargs ): super().__init__(**kwargs) self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5] self.img_size = img_size if img_size is not None else [512, 512] self.num_images = num_images # Default lung window settings self.windowing = windowing if windowing is not None else { "center": -600, "width": 1500 } self.normalize = normalize # TorchIO transforms for standardization self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing) # Note: Original Sybil uses 200 depth, 256x256 images self.default_depth = 200 self.default_size = [256, 256] # TorchIO uses (H, W, D) ordering for target_shape, matching original Sybil self.padding_transform = tio.transforms.CropOrPad( target_shape=tuple(self.default_size + [self.default_depth]), # (256, 256, 200) padding_mode=0 ) def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]: """ Load a series of DICOM files. Args: paths: List of paths to DICOM files. Returns: Tuple of (volume array, metadata dict) """ dicoms = [] for path in paths: try: dcm = pydicom.dcmread(path, stop_before_pixels=False) dicoms.append(dcm) except Exception as e: print(f"Error reading DICOM file {path}: {e}") continue if not dicoms: raise ValueError("No valid DICOM files found") # Order slices by position dicoms = order_slices(dicoms) # Extract pixel arrays volume = np.stack([dcm.pixel_array.astype(np.float32) for dcm in dicoms]) # Extract metadata metadata = { "slice_thickness": float(dicoms[0].SliceThickness) if hasattr(dicoms[0], 'SliceThickness') else None, "pixel_spacing": list(map(float, dicoms[0].PixelSpacing)) if hasattr(dicoms[0], 'PixelSpacing') else None, "manufacturer": str(dicoms[0].Manufacturer) if hasattr(dicoms[0], 'Manufacturer') else None, "num_slices": len(dicoms) } # Apply rescale if present if hasattr(dicoms[0], 'RescaleSlope') and hasattr(dicoms[0], 'RescaleIntercept'): slope = float(dicoms[0].RescaleSlope) intercept = float(dicoms[0].RescaleIntercept) volume = volume * slope + intercept return volume, metadata def load_png_series(self, paths: List[str]) -> np.ndarray: """ Load a series of PNG files. Args: paths: List of paths to PNG files (must be in anatomical order). Returns: 3D volume array """ images = [] for path in paths: img = Image.open(path).convert('L') # Convert to grayscale images.append(np.array(img, dtype=np.float32)) return np.stack(images) def resize_slices(self, volume: np.ndarray, target_size: List[int] = None) -> np.ndarray: """ Resize each slice in the volume to target size using OpenCV bilinear interpolation. This exactly matches the original Sybil's per-slice 2D resize operation. Args: volume: 3D volume array (D, H, W). target_size: Target size [H, W]. Defaults to [256, 256]. Returns: Resized volume. """ if target_size is None: target_size = self.default_size # [256, 256] # Resize each slice using OpenCV (matching original Sybil exactly) resized_slices = [] for i in range(volume.shape[0]): slice_2d = volume[i] # Shape: (H, W) # cv2.resize expects dsize=(width, height), not (height, width)! resized = cv2.resize( slice_2d, dsize=(target_size[1], target_size[0]), # (W, H) interpolation=cv2.INTER_LINEAR ) resized_slices.append(resized) # Stack back into volume return np.stack(resized_slices, axis=0) def apply_windowing(self, volume: np.ndarray) -> np.ndarray: """ Apply DICOM-standard windowing to CT scan, matching the original Sybil implementation. This implements the same windowing as the original Sybil: - Uses DICOM standard formula with center-0.5 and width-1 adjustments - Outputs to 16-bit range [0, 65535] then divides by 256 for 8-bit parity - Results in [0, 255] range that will be normalized later Args: volume: 3D CT volume in Hounsfield Units. Returns: Windowed volume in [0, 255] range. """ center = self.windowing["center"] # -600 width = self.windowing["width"] # 1500 # DICOM standard windowing formula (matching original Sybil) bit_size = 16 y_min = 0 y_max = 2 ** bit_size - 1 # 65535 y_range = y_max - y_min # DICOM standard adjustments c = center - 0.5 # -600.5 w = width - 1 # 1499 # Calculate window boundaries lower_bound = c - w / 2 # -1350 upper_bound = c + w / 2 # 149.5 # Apply windowing with three regions below = volume <= lower_bound above = volume > upper_bound between = np.logical_and(~below, ~above) # Create output array windowed = np.zeros_like(volume, dtype=np.float32) # Apply windowing windowed[below] = y_min # Values <= -1350 -> 0 windowed[above] = y_max # Values > 149.5 -> 65535 if between.any(): # Linear interpolation for values in window windowed[between] = ((volume[between] - c) / w + 0.5) * y_range + y_min # Divide by 256 for 8-bit parity (matching original Sybil) # This gives range [0, 255] instead of [0, 65535] windowed = windowed // 256 return windowed def resample_volume( self, volume: torch.Tensor, original_spacing: Optional[List[float]] = None ) -> torch.Tensor: """ Resample volume to target voxel spacing. Uses affine matrix approach matching original Sybil exactly. Args: volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W). original_spacing: Original voxel spacing [H_spacing, W_spacing, D_spacing]. Returns: Resampled volume with same number of dimensions. """ # Handle both 3D (D, H, W) and 4D (C, D, H, W) volumes if len(volume.shape) == 3: # Single channel: (D, H, W) -> (1, D, H, W) volume_4d = volume.unsqueeze(0) squeeze_output = True elif len(volume.shape) == 4: # Multi-channel: (C, D, H, W) - already has channel dim volume_4d = volume squeeze_output = False else: raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") # Permute to TorchIO format: (C, D, H, W) -> (C, H, W, D) volume_tio = volume_4d.permute(0, 2, 3, 1) # Create affine matrix like original Sybil # Original uses torch.diag(voxel_spacing) where voxel_spacing has 4 elements if original_spacing is not None: # Add 1.0 as 4th element like original Sybil voxel_spacing_4d = torch.tensor(original_spacing + [1.0], dtype=torch.float32) affine = torch.diag(voxel_spacing_4d) else: affine = None # Create TorchIO subject with affine (not spacing!) subject = tio.Subject( image=tio.ScalarImage(tensor=volume_tio, affine=affine) ) # Apply resampling resampled = self.resample_transform(subject) # Permute back: (C, H, W, D) -> (C, D, H, W) result = resampled['image'].data.permute(0, 3, 1, 2) # Return with original number of dimensions if squeeze_output: return result.squeeze(0) else: return result def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor: """ Pad or crop volume to target shape. Args: volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W). Returns: Padded/cropped volume with same number of dimensions. """ # Handle both 3D (D, H, W) and 4D (C, D, H, W) volumes if len(volume.shape) == 3: # Single channel: (D, H, W) -> (1, D, H, W) volume_4d = volume.unsqueeze(0) squeeze_output = True elif len(volume.shape) == 4: # Multi-channel: (C, D, H, W) - already has channel dim volume_4d = volume squeeze_output = False else: raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") # Permute to TorchIO format: (C, D, H, W) -> (C, H, W, D) volume_tio = volume_4d.permute(0, 2, 3, 1) # Create TorchIO subject subject = tio.Subject( image=tio.ScalarImage(tensor=volume_tio) ) # Apply padding/cropping transformed = self.padding_transform(subject) # Permute back: (C, H, W, D) -> (C, D, H, W) result = transformed['image'].data.permute(0, 3, 1, 2) # Return with original number of dimensions if squeeze_output: return result.squeeze(0) else: return result def preprocess( self, images: Union[List[str], np.ndarray, torch.Tensor], file_type: str = "dicom", voxel_spacing: Optional[List[float]] = None, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs ) -> BatchFeature: """ Preprocess CT scan images. Args: images: Either list of file paths or numpy/torch array of images. file_type: Type of input files ("dicom" or "png"). voxel_spacing: Original voxel spacing (required for PNG files). return_tensors: The type of tensors to return. Returns: BatchFeature with preprocessed images. """ # Load images if paths are provided if isinstance(images, list) and isinstance(images[0], str): if file_type == "dicom": volume, metadata = self.load_dicom_series(images) if voxel_spacing is None and metadata["pixel_spacing"]: voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]] elif file_type == "png": if voxel_spacing is None: raise ValueError("voxel_spacing must be provided for PNG files") volume = self.load_png_series(images) else: raise ValueError(f"Unknown file type: {file_type}") elif isinstance(images, (np.ndarray, torch.Tensor)): volume = images else: raise ValueError("Images must be file paths, numpy array, or torch tensor") # Ensure volume is numpy array for initial processing if isinstance(volume, torch.Tensor): volume_np = volume.numpy() else: volume_np = volume # Apply windowing volume_np = self.apply_windowing(volume_np) # Resize each slice to 256x256 (matching original Sybil's per-slice resize) volume_np = self.resize_slices(volume_np, target_size=self.default_size) # NOTE: Original Sybil uses the ORIGINAL voxel spacing from DICOM metadata # even after resizing slices. This is physically incorrect (spacing should be # adjusted for the resize factor), but we match the original behavior here. # The voxel_spacing remains unchanged from DICOM metadata. # Convert to torch tensor for remaining operations volume = torch.from_numpy(volume_np).float() # Apply normalization BEFORE resampling (to match original Sybil) # Original Sybil normalizes each slice before assembly and 3D resampling # This ensures 3D interpolation happens on normalized values, not [0, 255] values # These values come from the original Sybil implementation's computed mean/std # on 8-bit windowed images [0, 255] img_mean = 128.1722 img_std = 87.1849 volume = (volume - img_mean) / img_std # Replicate to 3 channels BEFORE resampling (to match original Sybil) # Original Sybil replicates channels per-slice, then assembles 3-channel volume # Shape: (D, H, W) -> (3, D, H, W) volume = volume.unsqueeze(0).repeat(3, 1, 1, 1) # Now (3, D, H, W) # Resample if spacing is provided (3D resampling for voxel spacing adjustment) # This happens on 3-channel volume, matching original Sybil if voxel_spacing is not None: volume = self.resample_volume(volume, voxel_spacing) # Pad or crop to target shape (on 3-channel volume) volume = self.pad_or_crop_volume(volume) # Add batch dimension to match original Sybil output shape [1, C, D, H, W] volume = volume.unsqueeze(0) # Now (1, 3, D, H, W) # Prepare output data = {"pixel_values": volume} # Convert to requested tensor type if return_tensors == "pt": return BatchFeature(data=data, tensor_type=TensorType.PYTORCH) elif return_tensors == "np": data = {k: v.numpy() for k, v in data.items()} return BatchFeature(data=data, tensor_type=TensorType.NUMPY) else: return BatchFeature(data=data) def __call__( self, images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor], **kwargs ) -> BatchFeature: """ Main method to prepare images for the model. Args: images: Images to preprocess. Can be: - List of file paths for a single series - List of lists of file paths for multiple series - Numpy array or torch tensor Returns: BatchFeature with preprocessed images ready for model input. """ # Handle batch processing if isinstance(images, list) and images and isinstance(images[0], list): # Multiple series batch_volumes = [] for series_paths in images: result = self.preprocess(series_paths, **kwargs) batch_volumes.append(result["pixel_values"]) # Stack into batch (B, C, D, H, W) pixel_values = torch.stack(batch_volumes) return BatchFeature(data={"pixel_values": pixel_values}) else: # Single series return self.preprocess(images, **kwargs)