|  | """Preprocessing utilities for MIRAI model.""" | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image | 
					
						
						|  | import cv2 | 
					
						
						|  | from typing import Dict, List, Optional, Tuple, Union | 
					
						
						|  | import pydicom | 
					
						
						|  | import warnings | 
					
						
						|  | from pathlib import Path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class MiraiPreprocessor:
    """
    Preprocessor for MIRAI mammography images.

    Handles loading, conversion, and normalization of mammogram images
    for the MIRAI model.
    """

    # The four standard screening views and their slot order in exam tensors.
    STANDARD_VIEWS = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
    VIEW_INDICES = {'L-CC': 0, 'L-MLO': 1, 'R-CC': 2, 'R-MLO': 3}
					
						
    def __init__(
        self,
        target_size: Tuple[int, int] = (1664, 2048),
        normalize: bool = True,
        device: str = 'cpu'
    ):
        """
        Initialize the preprocessor.

        Args:
            target_size: Target image size (height, width)
            normalize: Whether to apply normalization
            device: Device for tensor operations
        """
        self.target_size = target_size
        self.normalize = normalize
        self.device = device

        # ImageNet channel statistics, applied after grayscale images are
        # replicated to three channels (see preprocess_single_image).
        self.imagenet_mean = np.array([0.485, 0.456, 0.406])
        self.imagenet_std = np.array([0.229, 0.224, 0.225])

        # MIRAI's raw pixel-intensity statistics, kept for callers that
        # normalize unscaled DICOM intensities directly.
        self.mirai_mean = 7047.99
        self.mirai_std = 12005.5
					
						
    def load_dicom(self, dicom_path: str) -> Optional[np.ndarray]:
        """
        Load and convert a DICOM file to a numpy array.

        Args:
            dicom_path: Path to DICOM file

        Returns:
            Numpy array of the image scaled to [0, 1], or None on failure
        """
        try:
            dcm = pydicom.dcmread(dicom_path)
            img = dcm.pixel_array.astype(float)

            # Apply the rescale transform first: it maps stored values to
            # output units, and windowing is defined on the rescaled values.
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                img = img * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)

            if hasattr(dcm, 'WindowCenter') and hasattr(dcm, 'WindowWidth'):
                center = dcm.WindowCenter
                width = dcm.WindowWidth
                # WindowCenter/WindowWidth may be multi-valued; take the first.
                if isinstance(center, pydicom.multival.MultiValue):
                    center = center[0]
                if isinstance(width, pydicom.multival.MultiValue):
                    width = width[0]
                img = self._apply_windowing(img, float(center), float(width))

            # Scale to [0, 1]; the epsilon guards against constant images.
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)

            return img

        except Exception as e:
            warnings.warn(f"Could not load DICOM file {dicom_path}: {e}")
            return None
					
						
    def load_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Load image from file (PNG, JPEG, TIFF, or DICOM).

        Args:
            image_path: Path to image file

        Returns:
            Numpy array of the image scaled to [0, 1], or None on failure

        Raises:
            ValueError: If the file extension is not supported
        """
        path = Path(image_path)

        if path.suffix.lower() in ['.dcm', '.dicom']:
            return self.load_dicom(image_path)
        elif path.suffix.lower() in ['.png', '.jpg', '.jpeg', '.tif', '.tiff']:
            img = Image.open(image_path)

            # Mammograms are single-channel; collapse color images to grayscale.
            if img.mode != 'L':
                img = img.convert('L')

            img = np.array(img, dtype=np.float32)

            # Scale to [0, 1] by the image maximum (also handles 16-bit TIFFs).
            if img.max() > 1:
                img = img / img.max()

            return img
        else:
            raise ValueError(f"Unsupported file format: {path.suffix}")
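
    # Usage sketch (hypothetical path; dispatch is by file extension):
    #     pre = MiraiPreprocessor()
    #     img = pre.load_image('exam/l_cc.dcm')  # routed to load_dicom
    #     if img is not None:
    #         print(img.min(), img.max())  # values lie in [0, 1]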
					
						
    def _apply_windowing(self, img: np.ndarray, center: float, width: float) -> np.ndarray:
        """Apply DICOM windowing to image."""
        lower = center - width / 2
        upper = center + width / 2
        img = np.clip(img, lower, upper)
        img = (img - lower) / (upper - lower)
        return img
					
						
    def resize_image(self, img: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Resize image to target size.

        Args:
            img: Input image array
            target_size: Target size (height, width); defaults to self.target_size

        Returns:
            Resized image
        """
        if target_size is None:
            target_size = self.target_size

        # cv2.resize expects (width, height), hence the reversed tuple.
        resized = cv2.resize(img, (target_size[1], target_size[0]),
                             interpolation=cv2.INTER_LINEAR)

        return resized
					
						
    def preprocess_single_image(self, img: np.ndarray) -> torch.Tensor:
        """
        Preprocess a single mammogram image.

        Args:
            img: Input image as numpy array, scaled to [0, 1]

        Returns:
            Preprocessed tensor of shape (3, H, W)
        """
        if img.shape[:2] != self.target_size:
            img = self.resize_image(img, self.target_size)

        # Replicate the grayscale image across three channels for the
        # ImageNet-pretrained backbone.
        if len(img.shape) == 2:
            img = np.stack([img, img, img], axis=-1)

        img_tensor = torch.from_numpy(img).float()

        # HWC -> CHW
        img_tensor = img_tensor.permute(2, 0, 1)

        if self.normalize:
            # Per-channel ImageNet normalization.
            for c in range(3):
                img_tensor[c] = (img_tensor[c] - self.imagenet_mean[c]) / self.imagenet_std[c]

        return img_tensor
					
						
    def load_mammogram_exam(
        self,
        paths: Dict[str, str],
        return_missing_views: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
        """
        Load a complete mammogram exam with all 4 standard views.

        Missing or unreadable views are left as zero-filled slots.

        Args:
            paths: Dictionary mapping view names (e.g. 'L-CC') to file paths
            return_missing_views: Whether to also return the list of missing views

        Returns:
            Tensor of shape (3, 4, H, W), or a tuple of that tensor and the
            list of missing view names
        """
        # Channels x views x height x width; missing views stay all-zero.
        exam_tensor = torch.zeros(
            3, 4, self.target_size[0], self.target_size[1],
            dtype=torch.float32
        )

        missing_views = []
        loaded_views = []

        for view_name in self.STANDARD_VIEWS:
            if view_name in paths and paths[view_name]:
                try:
                    img = self.load_image(paths[view_name])
                    if img is not None:
                        img_tensor = self.preprocess_single_image(img)
                        view_idx = self.VIEW_INDICES[view_name]
                        exam_tensor[:, view_idx, :, :] = img_tensor
                        loaded_views.append(view_name)
                    else:
                        missing_views.append(view_name)
                except Exception as e:
                    warnings.warn(f"Could not load view {view_name}: {e}")
                    missing_views.append(view_name)
            else:
                missing_views.append(view_name)

        if missing_views:
            warnings.warn(
                f"Missing mammogram views: {missing_views}. "
                "Model performance may be degraded."
            )

        if return_missing_views:
            return exam_tensor, missing_views
        return exam_tensor
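
    # Example (hypothetical paths; a minimal sketch of exam loading):
    #     pre = MiraiPreprocessor()
    #     exam, missing = pre.load_mammogram_exam(
    #         {'L-CC': 'exam/l_cc.png', 'L-MLO': 'exam/l_mlo.png',
    #          'R-CC': 'exam/r_cc.png', 'R-MLO': 'exam/r_mlo.png'},
    #         return_missing_views=True,
    #     )
    #     # exam.shape == torch.Size([3, 4, 1664, 2048])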
					
						
    def prepare_risk_factors(
        self,
        risk_factor_dict: Optional[Dict[str, Union[float, int, bool]]] = None,
        use_defaults: bool = True
    ) -> torch.Tensor:
        """
        Prepare risk factor tensor from dictionary.

        Args:
            risk_factor_dict: Dictionary of risk factors
            use_defaults: Whether to use default values when no dictionary is given

        Returns:
            Risk factor tensor of shape (34,)
        """
        risk_factors = torch.zeros(34, dtype=torch.float32)

        if risk_factor_dict is None and use_defaults:
            # Fallback defaults: age (index 5) of 0.5, i.e. ~50 years under
            # the /100 scaling below, and density (index 0) of 0.25.
            risk_factors[5] = 0.5
            risk_factors[0] = 0.25
            return risk_factors

        if risk_factor_dict is not None:
            # Mapping from factor name to slot in the 34-dim vector; the
            # remaining indices are left at zero.
            factor_mapping = {
                'density': 0,
                'family_history': 1,
                'biopsy_benign': 2,
                'biopsy_lcis': 3,
                'biopsy_atypical': 4,
                'age': 5,
                'menarche_age': 6,
                'menopause_age': 7,
                'first_pregnancy_age': 8,
                'prior_hist': 9,
                'race': 10,
                'parous': 11,
                'menopausal_status': 12,
                'weight': 13,
                'height': 14,
            }

            for factor_name, idx in factor_mapping.items():
                if factor_name in risk_factor_dict:
                    value = risk_factor_dict[factor_name]

                    # Scale continuous factors to roughly [0, 1].
                    if factor_name == 'age':
                        value = value / 100.0
                    elif factor_name in ['weight', 'height']:
                        value = value / 200.0
                    elif isinstance(value, bool):
                        value = float(value)

                    risk_factors[idx] = float(value)

        return risk_factors
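
    # Example (illustrative values only):
    #     rf = pre.prepare_risk_factors({'age': 52, 'family_history': True})
    #     # rf[5] == 0.52 (age / 100), rf[1] == 1.0, remaining slots 0.0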
					
						
    def create_batch(
        self,
        exam_list: List[Dict[str, str]],
        risk_factors_list: Optional[List[Dict]] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Create a batch of mammogram exams.

        Args:
            exam_list: List of exam path dictionaries
            risk_factors_list: Optional list of risk factor dictionaries,
                parallel to exam_list

        Returns:
            Dictionary with 'images', 'risk_factors', and 'batch_metadata'
        """
        batch_images = []
        batch_risk_factors = []

        for i, exam_paths in enumerate(exam_list):
            exam_tensor = self.load_mammogram_exam(exam_paths)
            batch_images.append(exam_tensor.unsqueeze(0))

            rf_dict = risk_factors_list[i] if risk_factors_list else None
            rf_tensor = self.prepare_risk_factors(rf_dict)
            batch_risk_factors.append(rf_tensor.unsqueeze(0))

        images = torch.cat(batch_images, dim=0)
        risk_factors = torch.cat(batch_risk_factors, dim=0)

        # Per-exam sequence metadata: a single time point per view, one
        # index per view, and side flags (0 = left, 1 = right).
        batch_size = len(exam_list)
        batch_metadata = {
            'time_seq': torch.zeros(batch_size, 4).long(),
            'view_seq': torch.tensor([[0, 1, 2, 3]] * batch_size),
            'side_seq': torch.tensor([[0, 0, 1, 1]] * batch_size),
        }

        return {
            'images': images.to(self.device),
            'risk_factors': risk_factors.to(self.device),
            'batch_metadata': batch_metadata
        }
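
    # Example (hypothetical paths; see also the __main__ smoke test below):
    #     batch = pre.create_batch([paths_exam_1, paths_exam_2])
    #     # batch['images'].shape == torch.Size([2, 3, 4, 1664, 2048])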
					
						
    def augment_image(
        self,
        img: np.ndarray,
        flip_horizontal: bool = False,
        flip_vertical: bool = False,
        rotate_angle: float = 0,
        brightness_factor: float = 1.0,
        contrast_factor: float = 1.0
    ) -> np.ndarray:
        """
        Apply data augmentation to mammogram image.

        Args:
            img: Input image, scaled to [0, 1]
            flip_horizontal: Apply horizontal flip
            flip_vertical: Apply vertical flip
            rotate_angle: Rotation angle in degrees
            brightness_factor: Brightness adjustment factor (1.0 = unchanged)
            contrast_factor: Contrast adjustment factor (1.0 = unchanged)

        Returns:
            Augmented image, clipped to [0, 1]
        """
        img_aug = img.copy()

        if flip_horizontal:
            img_aug = np.fliplr(img_aug)
        if flip_vertical:
            img_aug = np.flipud(img_aug)

        if rotate_angle != 0:
            # np.fliplr/flipud return negative-stride views; make the array
            # contiguous before handing it to OpenCV.
            img_aug = np.ascontiguousarray(img_aug)
            h, w = img_aug.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            img_aug = cv2.warpAffine(img_aug, M, (w, h))

        # Multiplicative contrast, additive brightness, then clip to [0, 1].
        img_aug = img_aug * contrast_factor
        img_aug = img_aug + (brightness_factor - 1.0)
        img_aug = np.clip(img_aug, 0, 1)

        return img_aug
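

if __name__ == '__main__':
    # Minimal smoke test on synthetic data (no files required). This is an
    # illustrative sketch, not part of the MIRAI pipeline: it exercises
    # resizing, channel replication, normalization, and risk-factor encoding.
    pre = MiraiPreprocessor(target_size=(1664, 2048))

    # Arbitrary raw resolution; preprocess_single_image resizes it.
    fake_view = np.random.rand(2294, 1914).astype(np.float32)
    view_tensor = pre.preprocess_single_image(fake_view)
    print('single view:', tuple(view_tensor.shape))  # (3, 1664, 2048)

    rf = pre.prepare_risk_factors({'age': 52, 'family_history': True})
    print('risk factors:', tuple(rf.shape), 'age slot:', rf[5].item())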