"""Preprocessing utilities for MIRAI model."""

import torch
import numpy as np
from PIL import Image
import cv2
from typing import Dict, List, Optional, Tuple, Union
import pydicom
import warnings
from pathlib import Path


class MiraiPreprocessor:
    """
    Preprocessor for MIRAI mammography images.

    Handles loading, conversion, and normalization of mammogram images for the
    MIRAI model: each image is loaded as a single-channel float array scaled to
    [0, 1], resized to ``target_size``, replicated to 3 channels and optionally
    normalized with ImageNet statistics.
    """

    # Standard mammogram views, in the order they occupy the view axis of an
    # exam tensor.
    STANDARD_VIEWS = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
    VIEW_INDICES = {'L-CC': 0, 'L-MLO': 1, 'R-CC': 2, 'R-MLO': 3}

    # Length of the vector returned by prepare_risk_factors().
    NUM_RISK_FACTORS = 34

    # Simplified mapping from risk-factor name to vector index; names not
    # listed here are ignored. Indices 15-33 are currently never written.
    RISK_FACTOR_INDICES = {
        'density': 0,
        'family_history': 1,
        'biopsy_benign': 2,
        'biopsy_lcis': 3,
        'biopsy_atypical': 4,
        'age': 5,
        'menarche_age': 6,
        'menopause_age': 7,
        'first_pregnancy_age': 8,
        'prior_hist': 9,
        'race': 10,
        'parous': 11,
        'menopausal_status': 12,
        'weight': 13,
        'height': 14,
        # ... additional mappings
    }

    # Already-normalized values written for factors the caller did not
    # supply when use_defaults=True.
    RISK_FACTOR_DEFAULTS = {
        'density': 0.25,  # BI-RADS 2
        'age': 0.5,       # age / 100
    }

    def __init__(
        self,
        target_size: Tuple[int, int] = (1664, 2048),
        normalize: bool = True,
        device: str = 'cpu'
    ):
        """
        Initialize the preprocessor.

        Args:
            target_size: Target image size (height, width)
            normalize: Whether to apply ImageNet normalization
            device: Device that create_batch() moves its tensors to
        """
        # Stored as a tuple so it compares equal to ndarray.shape[:2] even
        # when the caller passes a list.
        self.target_size = tuple(target_size)
        self.normalize = normalize
        self.device = device

        # ImageNet normalization parameters (per RGB channel).
        self.imagenet_mean = np.array([0.485, 0.456, 0.406])
        self.imagenet_std = np.array([0.229, 0.224, 0.225])

        # MIRAI-specific normalization constants for raw pixel values.
        # NOTE(review): no method of this class uses them; preprocessing
        # below applies the ImageNet statistics instead.
        self.mirai_mean = 7047.99
        self.mirai_std = 12005.5

    @staticmethod
    def _first_float(value) -> float:
        """Return ``value`` as a float, taking the first entry of multi-valued elements."""
        try:
            return float(value)
        except TypeError:
            # pydicom MultiValue (e.g. several window settings in one file)
            # is not float-convertible; use its first element.
            return float(value[0])

    def load_dicom(self, dicom_path: str) -> Optional[np.ndarray]:
        """
        Load and convert DICOM file to numpy array.

        Stored pixel values go through the DICOM display pipeline in standard
        order: Rescale Slope/Intercept first, then Window Center/Width, then
        min-max scaling to [0, 1]. MONOCHROME1 images are inverted so that
        higher values are always brighter.

        Args:
            dicom_path: Path to DICOM file

        Returns:
            float32 array scaled to [0, 1], or None if the file could not be
            read or decoded (a warning is emitted in that case).
        """
        try:
            dcm = pydicom.dcmread(dicom_path)
            img = dcm.pixel_array.astype(np.float32)

            # Rescale must precede windowing: window center/width are given
            # in rescaled units.
            slope = getattr(dcm, 'RescaleSlope', None)
            intercept = getattr(dcm, 'RescaleIntercept', None)
            if slope is not None and intercept is not None:
                img = img * self._first_float(slope) + self._first_float(intercept)

            center = getattr(dcm, 'WindowCenter', None)
            width = getattr(dcm, 'WindowWidth', None)
            if center is not None and width is not None:
                img = self._apply_windowing(
                    img, self._first_float(center), self._first_float(width)
                )

            # Normalize to 0-1 range (epsilon keeps constant images at 0).
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)

            # MONOCHROME1 means the minimum value is displayed as white.
            photometric = str(getattr(dcm, 'PhotometricInterpretation', '')).strip().upper()
            if photometric == 'MONOCHROME1':
                img = 1.0 - img

            return img

        except Exception as e:
            # Deliberate best-effort: callers treat None as a missing view.
            warnings.warn(f"Could not load DICOM file {dicom_path}: {e}")
            return None

    def load_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Load image from file (PNG, JPEG, TIFF, or DICOM).

        Args:
            image_path: Path to image file

        Returns:
            Single-channel float32 array in [0, 1], or None when a DICOM
            file could not be read.

        Raises:
            ValueError: If the file extension is not supported.
        """
        path = Path(image_path)
        suffix = path.suffix.lower()

        if suffix in ('.dcm', '.dicom'):
            return self.load_dicom(image_path)

        if suffix not in ('.png', '.jpg', '.jpeg', '.tif', '.tiff'):
            raise ValueError(f"Unsupported file format: {path.suffix}")

        # The context manager closes the underlying file handle.
        with Image.open(image_path) as src:
            # Keep single-channel high-bit-depth modes ('I', 'I;16*', 'F')
            # as-is: converting them to 8-bit 'L' clips values above 255 and
            # destroys 16-bit mammograms. Other modes become 8-bit grayscale.
            mode = src.mode
            if mode == 'L' or mode == 'F' or mode.startswith('I'):
                gray = src
            else:
                gray = src.convert('L')
            img = np.array(gray, dtype=np.float32)

        # Normalize to 0-1 range.
        max_val = img.max()
        if max_val > 1:
            img = img / max_val

        return img

    def _apply_windowing(self, img: np.ndarray, center: float, width: float) -> np.ndarray:
        """
        Apply a linear DICOM window.

        The range [center - width/2, center + width/2] is mapped to [0, 1],
        with values outside it clipped. A non-positive width is invalid, so
        the image is returned unchanged instead of dividing by zero.
        """
        if width <= 0:
            return img
        lower = center - width / 2
        upper = center + width / 2
        return (np.clip(img, lower, upper) - lower) / (upper - lower)

    def resize_image(self, img: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Resize image to target size.

        Args:
            img: Input image array
            target_size: Target size (height, width); defaults to self.target_size

        Returns:
            Resized image
        """
        if target_size is None:
            target_size = self.target_size

        # cv2.resize expects dsize as (width, height).
        return cv2.resize(img, (target_size[1], target_size[0]), interpolation=cv2.INTER_LINEAR)

    def preprocess_single_image(self, img: np.ndarray) -> torch.Tensor:
        """
        Preprocess a single mammogram image.

        Args:
            img: Input image as numpy array, either (H, W), (H, W, 1) or (H, W, 3)

        Returns:
            float32 tensor of shape (3, target_height, target_width)
        """
        # Treat an explicit single channel axis like a plain 2-D image.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]

        # Resize if needed
        if img.shape[:2] != self.target_size:
            img = self.resize_image(img, self.target_size)

        # Convert to RGB (replicate grayscale to 3 channels)
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)

        # torch.from_numpy rejects negative strides (e.g. output of
        # np.fliplr); ascontiguousarray is a no-op for contiguous arrays.
        img_tensor = torch.from_numpy(np.ascontiguousarray(img)).float()

        # Rearrange dimensions: HWC -> CHW
        img_tensor = img_tensor.permute(2, 0, 1)

        if self.normalize:
            # Per-channel ImageNet normalization, broadcast over H and W.
            mean = torch.as_tensor(self.imagenet_mean, dtype=torch.float32).view(-1, 1, 1)
            std = torch.as_tensor(self.imagenet_std, dtype=torch.float32).view(-1, 1, 1)
            img_tensor = (img_tensor - mean) / std

        return img_tensor

    def load_mammogram_exam(
        self,
        paths: Dict[str, str],
        return_missing_views: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
        """
        Load a complete mammogram exam with all 4 standard views.

        Views that are absent, unreadable or fail preprocessing are left as
        all-zero slices and reported as missing.

        Args:
            paths: Dictionary mapping view names (see STANDARD_VIEWS) to file paths
            return_missing_views: Whether to return list of missing views

        Returns:
            Tensor of shape (3, 4, H, W) or tuple with tensor and missing views
        """
        exam_tensor = torch.zeros(
            3, 4, self.target_size[0], self.target_size[1],
            dtype=torch.float32
        )
        missing_views = []

        for view_name in self.STANDARD_VIEWS:
            view_path = paths.get(view_name)
            if not view_path:
                missing_views.append(view_name)
                continue
            try:
                img = self.load_image(view_path)
                if img is None:
                    missing_views.append(view_name)
                    continue
                view_idx = self.VIEW_INDICES[view_name]
                exam_tensor[:, view_idx, :, :] = self.preprocess_single_image(img)
            except Exception as e:
                # Best-effort: a bad view must not abort the whole exam.
                warnings.warn(f"Could not load view {view_name}: {e}")
                missing_views.append(view_name)

        if missing_views:
            warnings.warn(
                f"Missing mammogram views: {missing_views}. "
                "Model performance may be degraded."
            )

        if return_missing_views:
            return exam_tensor, missing_views
        return exam_tensor

    def prepare_risk_factors(
        self,
        risk_factor_dict: Optional[Dict[str, Union[float, int, bool]]] = None,
        use_defaults: bool = True
    ) -> torch.Tensor:
        """
        Prepare risk factor tensor from dictionary.

        Known factors (see RISK_FACTOR_INDICES) are written to their index;
        'age' is divided by 100 and 'weight'/'height' by 200. Factors that are
        absent or None count as missing.

        Args:
            risk_factor_dict: Dictionary of risk factors
            use_defaults: Whether to use RISK_FACTOR_DEFAULTS for missing
                factors (otherwise missing factors are 0)

        Returns:
            Risk factor tensor of shape (34,)
        """
        risk_factors = torch.zeros(self.NUM_RISK_FACTORS, dtype=torch.float32)

        # Seed defaults first so supplied values overwrite them; this also
        # fills gaps in a partially specified dict, as documented.
        if use_defaults:
            for factor_name, default in self.RISK_FACTOR_DEFAULTS.items():
                risk_factors[self.RISK_FACTOR_INDICES[factor_name]] = default

        if not risk_factor_dict:
            return risk_factors

        for factor_name, idx in self.RISK_FACTOR_INDICES.items():
            value = risk_factor_dict.get(factor_name)
            if value is None:
                continue
            if factor_name == 'age':
                value = value / 100.0  # Normalize age to 0-1
            elif factor_name in ('weight', 'height'):
                value = value / 200.0  # Normalize weight/height
            # float() also maps bool to 0.0 / 1.0.
            risk_factors[idx] = float(value)

        return risk_factors

    def create_batch(
        self,
        exam_list: List[Dict[str, str]],
        risk_factors_list: Optional[List[Dict]] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Create a batch of mammogram exams.

        Args:
            exam_list: List of exam path dictionaries
            risk_factors_list: Optional list of risk factor dictionaries,
                aligned with exam_list; None or empty means "no risk factors"

        Returns:
            Dictionary with 'images' (B, 3, 4, H, W), 'risk_factors' (B, 34)
            and 'batch_metadata' tensors, all on self.device

        Raises:
            ValueError: If risk_factors_list has fewer entries than exam_list.
        """
        # Fail fast instead of after loading several exams.
        if risk_factors_list and len(risk_factors_list) < len(exam_list):
            raise ValueError(
                f"risk_factors_list has {len(risk_factors_list)} entries but "
                f"exam_list has {len(exam_list)}"
            )

        batch_size = len(exam_list)
        exam_tensors = []
        rf_tensors = []
        for i, exam_paths in enumerate(exam_list):
            exam_tensors.append(self.load_mammogram_exam(exam_paths))
            rf_dict = risk_factors_list[i] if risk_factors_list else None
            rf_tensors.append(self.prepare_risk_factors(rf_dict))

        if exam_tensors:
            images = torch.stack(exam_tensors)
            risk_factors = torch.stack(rf_tensors)
        else:
            # torch.stack rejects an empty list; return correctly shaped
            # empty tensors instead.
            height, width = self.target_size
            images = torch.zeros(0, 3, 4, height, width, dtype=torch.float32)
            risk_factors = torch.zeros(0, self.NUM_RISK_FACTORS, dtype=torch.float32)

        # NOTE(review): view_seq numbers all four views distinctly while
        # side_seq already encodes laterality; confirm against the MIRAI model
        # whether CC/MLO should instead be encoded as [0, 1, 0, 1].
        batch_metadata = {
            'time_seq': torch.zeros(batch_size, 4, dtype=torch.long),
            'view_seq': torch.tensor([0, 1, 2, 3]).repeat(batch_size, 1),
            'side_seq': torch.tensor([0, 0, 1, 1]).repeat(batch_size, 1),
        }
        # Keep metadata on the same device as the images it describes.
        batch_metadata = {key: t.to(self.device) for key, t in batch_metadata.items()}

        return {
            'images': images.to(self.device),
            'risk_factors': risk_factors.to(self.device),
            'batch_metadata': batch_metadata
        }

    def augment_image(
        self,
        img: np.ndarray,
        flip_horizontal: bool = False,
        flip_vertical: bool = False,
        rotate_angle: float = 0,
        brightness_factor: float = 1.0,
        contrast_factor: float = 1.0
    ) -> np.ndarray:
        """
        Apply data augmentation to mammogram image.

        Args:
            img: Input image, assumed to be scaled to [0, 1]
            flip_horizontal: Apply horizontal flip
            flip_vertical: Apply vertical flip
            rotate_angle: Rotation angle in degrees about the image center
            brightness_factor: Additive offset of (brightness_factor - 1.0)
            contrast_factor: Multiplier applied to pixel values (about 0, not
                about the image mean)

        Returns:
            Augmented image, clipped to [0, 1] and C-contiguous
        """
        img_aug = img.copy()

        if flip_horizontal:
            img_aug = np.fliplr(img_aug)
        if flip_vertical:
            img_aug = np.flipud(img_aug)
        # Flips return negative-stride views, which OpenCV and
        # torch.from_numpy may reject; materialize a contiguous array.
        img_aug = np.ascontiguousarray(img_aug)

        if rotate_angle != 0:
            h, w = img_aug.shape[:2]
            center = (w // 2, h // 2)
            rotation = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            img_aug = cv2.warpAffine(img_aug, rotation, (w, h))

        img_aug = img_aug * contrast_factor
        img_aug = img_aug + (brightness_factor - 1.0)
        img_aug = np.clip(img_aug, 0, 1)

        return img_aug