Mirai / preprocessor.py
Aakash-Tripathi's picture
Upload preprocessor.py with huggingface_hub
46525fb verified
raw
history blame
12.3 kB
"""Preprocessing utilities for MIRAI model."""
import torch
import numpy as np
from PIL import Image
import cv2
from typing import Dict, List, Optional, Tuple, Union
import pydicom
import warnings
from pathlib import Path
class MiraiPreprocessor:
    """
    Preprocessor for MIRAI mammography images.

    Handles loading, conversion, and normalization of mammogram images
    for the MIRAI model. Images are read as single-channel arrays scaled to
    [0, 1], resized to ``target_size``, replicated to 3 channels and
    (optionally) normalized with ImageNet statistics. The class also builds
    risk-factor vectors and batched model inputs.
    """

    # Standard mammogram views; VIEW_INDICES gives each view's slot on the
    # view axis of the exam tensor built by load_mammogram_exam().
    STANDARD_VIEWS = ['L-CC', 'L-MLO', 'R-CC', 'R-MLO']
    VIEW_INDICES = {'L-CC': 0, 'L-MLO': 1, 'R-CC': 2, 'R-MLO': 3}

    # Length of the vector returned by prepare_risk_factors().
    NUM_RISK_FACTORS = 34

    # Simplified mapping from risk-factor name to its slot in the risk-factor
    # vector. Names not listed here are ignored by prepare_risk_factors().
    RISK_FACTOR_INDICES = {
        'density': 0,
        'family_history': 1,
        'biopsy_benign': 2,
        'biopsy_lcis': 3,
        'biopsy_atypical': 4,
        'age': 5,
        'menarche_age': 6,
        'menopause_age': 7,
        'first_pregnancy_age': 8,
        'prior_hist': 9,
        'race': 10,
        'parous': 11,
        'menopausal_status': 12,
        'weight': 13,
        'height': 14,
        # ... additional mappings
    }

    def __init__(
        self,
        target_size: Tuple[int, int] = (1664, 2048),
        normalize: bool = True,
        device: str = 'cpu'
    ):
        """
        Initialize the preprocessor.

        Args:
            target_size: Target image size (height, width). Stored as a tuple
                so shape comparisons work even when a list is passed.
            normalize: Whether to apply ImageNet normalization in
                preprocess_single_image().
            device: Device that create_batch() moves its tensors to.
        """
        self.target_size = tuple(target_size)
        self.normalize = normalize
        self.device = device
        # ImageNet normalization parameters (per RGB channel).
        self.imagenet_mean = np.array([0.485, 0.456, 0.406])
        self.imagenet_std = np.array([0.229, 0.224, 0.225])
        # MIRAI-specific normalization (for raw pixel values).
        # NOTE(review): no method in this class uses these two values; the
        # pipeline normalizes [0, 1] images with ImageNet statistics instead.
        # Confirm which normalization the model weights expect.
        self.mirai_mean = 7047.99
        self.mirai_std = 12005.5

    @staticmethod
    def _first_float(value: object) -> Optional[float]:
        """Return ``value`` as a float, using the first entry of a multi-valued attribute.

        DICOM attributes such as WindowCenter may carry several values, and
        ``float()`` raises on such a sequence. Returns None when the value
        is missing, empty or not numeric.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            pass
        try:
            return float(value[0])
        except (TypeError, ValueError, IndexError, KeyError):
            return None

    def dicom_dataset_to_array(self, dcm) -> np.ndarray:
        """
        Convert an already-read DICOM dataset to a float array in [0, 1].

        This follows the DICOM grayscale order of operations:
        1. Modality rescale (RescaleSlope/RescaleIntercept).
        2. VOI window (first WindowCenter/WindowWidth pair).
        3. Min-max scaling to [0, 1].
        4. Inversion for MONOCHROME1 images, so bright always means dense.

        Args:
            dcm: A pydicom Dataset, or any object exposing ``pixel_array``
                and the optional attributes above.

        Returns:
            float64 array scaled to [0, 1].
        """
        img = dcm.pixel_array.astype(np.float64)
        # Modality LUT first: the window is defined in rescaled units.
        slope = self._first_float(getattr(dcm, 'RescaleSlope', None))
        intercept = self._first_float(getattr(dcm, 'RescaleIntercept', None))
        if slope is not None or intercept is not None:
            img = img * (1.0 if slope is None else slope) + (0.0 if intercept is None else intercept)
        # VOI window second.
        center = self._first_float(getattr(dcm, 'WindowCenter', None))
        width = self._first_float(getattr(dcm, 'WindowWidth', None))
        if center is not None and width is not None:
            img = self._apply_windowing(img, center, width)
        # Normalize to 0-1 range.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        # MONOCHROME1 stores low values as white; flip to MONOCHROME2 convention.
        if str(getattr(dcm, 'PhotometricInterpretation', '')).strip().upper() == 'MONOCHROME1':
            img = 1.0 - img
        return img

    def load_dicom(self, dicom_path: str) -> Optional[np.ndarray]:
        """
        Load and convert a DICOM file to a numpy array in [0, 1].

        Args:
            dicom_path: Path to DICOM file.

        Returns:
            float64 array of the image, or None (with a warning) when the
            file cannot be read or decoded. Failure is deliberately
            best-effort so a single bad view does not abort an exam.
        """
        try:
            dcm = pydicom.dcmread(dicom_path)
            return self.dicom_dataset_to_array(dcm)
        except Exception as e:
            warnings.warn(f"Could not load DICOM file {dicom_path}: {e}", stacklevel=2)
            return None

    def load_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Load image from file (PNG, JPEG, TIFF, or DICOM).

        Args:
            image_path: Path to image file.

        Returns:
            Single-channel array scaled to [0, 1]. Returns None if a DICOM
            file fails to load.

        Raises:
            ValueError: If the file extension is not supported.
        """
        path = Path(image_path)
        suffix = path.suffix.lower()
        if suffix in ('.dcm', '.dicom'):
            return self.load_dicom(image_path)
        if suffix not in ('.png', '.jpg', '.jpeg', '.tif', '.tiff'):
            raise ValueError(f"Unsupported file format: {path.suffix}")
        # Context manager closes the underlying file handle.
        with Image.open(image_path) as pil_img:
            if pil_img.mode in ('I', 'F') or pil_img.mode.startswith('I;16'):
                # High-bit-depth images: convert('L') would saturate every
                # value above 255, so read the raw values instead.
                img = np.array(pil_img, dtype=np.float32)
            else:
                gray = pil_img if pil_img.mode == 'L' else pil_img.convert('L')
                img = np.array(gray, dtype=np.float32)
        # Normalize to 0-1 range.
        if img.max() > 1:
            img = img / img.max()
        return img

    def _apply_windowing(self, img: np.ndarray, center: float, width: float) -> np.ndarray:
        """
        Apply a linear DICOM window mapping [center - width/2, center + width/2] to [0, 1].

        Values outside the window are clipped. A non-positive width would
        divide by zero, so it degenerates to a threshold at ``center``.
        """
        lower = center - width / 2
        upper = center + width / 2
        if upper <= lower:
            return (np.asarray(img) > center).astype(np.float64)
        img = np.clip(img, lower, upper)
        return (img - lower) / (upper - lower)

    def resize_image(self, img: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Resize image to target size with bilinear interpolation.

        Args:
            img: Input image array.
            target_size: Target size (height, width); defaults to self.target_size.

        Returns:
            Resized image.
        """
        if target_size is None:
            target_size = self.target_size
        height, width = (int(s) for s in target_size)
        # cv2 expects dsize as (width, height), the reverse of numpy's shape.
        return cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)

    def preprocess_single_image(self, img: np.ndarray) -> torch.Tensor:
        """
        Preprocess a single mammogram image.

        Args:
            img: Input image as numpy array: (H, W), (H, W, 1) or (H, W, 3).

        Returns:
            float32 tensor of shape (3, target_height, target_width),
            ImageNet-normalized when ``self.normalize`` is set.
        """
        img = np.asarray(img)
        # Treat a trailing singleton channel like a plain 2-D image.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[..., 0]
        # Resize if needed.
        if tuple(img.shape[:2]) != tuple(self.target_size):
            img = self.resize_image(img, self.target_size)
        # Convert to RGB (replicate grayscale to 3 channels).
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        # torch.from_numpy rejects negative strides (e.g. views from np.fliplr).
        img_tensor = torch.from_numpy(np.ascontiguousarray(img)).float()
        # Rearrange dimensions: HWC -> CHW.
        img_tensor = img_tensor.permute(2, 0, 1)
        if self.normalize:
            mean = torch.as_tensor(self.imagenet_mean, dtype=torch.float32).view(3, 1, 1)
            std = torch.as_tensor(self.imagenet_std, dtype=torch.float32).view(3, 1, 1)
            img_tensor = (img_tensor - mean) / std
        return img_tensor

    def load_mammogram_exam(
        self,
        paths: Dict[str, str],
        return_missing_views: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
        """
        Load a complete mammogram exam with all 4 standard views.

        Args:
            paths: Dictionary mapping view names (see STANDARD_VIEWS) to file paths.
            return_missing_views: Whether to also return the list of missing views.

        Returns:
            Tensor of shape (3, 4, H, W) (channels, views, height, width), or
            a tuple of that tensor and the missing view names. Missing views
            are left as zeros. With normalize=True, zero corresponds to the
            ImageNet mean, not to a normalized black image.
        """
        exam_tensor = torch.zeros(
            3, 4, self.target_size[0], self.target_size[1],
            dtype=torch.float32
        )
        missing_views = []
        for view_name in self.STANDARD_VIEWS:
            view_path = paths.get(view_name)
            if not view_path:
                missing_views.append(view_name)
                continue
            try:
                img = self.load_image(view_path)
                if img is None:
                    missing_views.append(view_name)
                    continue
                img_tensor = self.preprocess_single_image(img)
                exam_tensor[:, self.VIEW_INDICES[view_name], :, :] = img_tensor
            except Exception as e:
                # Best-effort: one unreadable view should not abort the exam.
                warnings.warn(f"Could not load view {view_name}: {e}", stacklevel=2)
                missing_views.append(view_name)
        if missing_views:
            warnings.warn(
                f"Missing mammogram views: {missing_views}. "
                "Model performance may be degraded.",
                stacklevel=2,
            )
        if return_missing_views:
            return exam_tensor, missing_views
        return exam_tensor

    def prepare_risk_factors(
        self,
        risk_factor_dict: Optional[Dict[str, Union[float, int, bool]]] = None,
        use_defaults: bool = True
    ) -> torch.Tensor:
        """
        Prepare risk factor tensor from dictionary.

        Args:
            risk_factor_dict: Dictionary of risk factors keyed by the names in
                RISK_FACTOR_INDICES. ``None`` values are treated as missing.
                Unknown names are ignored with a warning.
            use_defaults: Whether to use default values for missing factors
                (normalized age 0.5 and density 0.25). Other factors default to 0.

        Returns:
            float32 risk factor tensor of shape (NUM_RISK_FACTORS,) == (34,).
        """
        risk_factors = torch.zeros(self.NUM_RISK_FACTORS, dtype=torch.float32)
        if use_defaults:
            risk_factors[self.RISK_FACTOR_INDICES['age']] = 0.5  # Age (normalized, i.e. 50/100)
            risk_factors[self.RISK_FACTOR_INDICES['density']] = 0.25  # Density (BI-RADS 2)
        if not risk_factor_dict:
            return risk_factors
        unknown = []
        for factor_name, value in risk_factor_dict.items():
            idx = self.RISK_FACTOR_INDICES.get(factor_name)
            if idx is None:
                unknown.append(factor_name)
                continue
            if value is None:
                continue  # treat as missing: keep default/zero
            # Normalize numerical values.
            if factor_name == 'age':
                value = value / 100.0
            elif factor_name in ('weight', 'height'):
                value = value / 200.0
            risk_factors[idx] = float(value)  # float() also maps bools to 0/1
        if unknown:
            warnings.warn(f"Ignoring unrecognized risk factors: {unknown}", stacklevel=2)
        return risk_factors

    def create_batch(
        self,
        exam_list: List[Dict[str, str]],
        risk_factors_list: Optional[List[Optional[Dict]]] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Create a batch of mammogram exams.

        Args:
            exam_list: Non-empty list of exam path dictionaries.
            risk_factors_list: Optional list of risk factor dictionaries
                (entries may be None). When given and non-empty, it must be
                the same length as exam_list.

        Returns:
            Dictionary containing:
            - 'images': tensor of shape (B, 3, 4, H, W).
            - 'risk_factors': tensor of shape (B, 34).
            - 'batch_metadata': dict of (B, 4) long tensors
              'time_seq', 'view_seq' and 'side_seq'.
            All tensors are on self.device.

        Raises:
            ValueError: If exam_list is empty or the list lengths differ.
        """
        if not exam_list:
            raise ValueError("exam_list must contain at least one exam")
        if risk_factors_list and len(risk_factors_list) != len(exam_list):
            raise ValueError(
                f"risk_factors_list has {len(risk_factors_list)} entries "
                f"but exam_list has {len(exam_list)}"
            )
        rf_dicts = risk_factors_list if risk_factors_list else [None] * len(exam_list)
        images = torch.stack([self.load_mammogram_exam(paths) for paths in exam_list])
        risk_factors = torch.stack([self.prepare_risk_factors(rf) for rf in rf_dicts])
        batch_size = len(exam_list)
        # Per-view metadata, in STANDARD_VIEWS order (L-CC, L-MLO, R-CC, R-MLO).
        # Moved to the same device as the images to avoid device mismatches.
        batch_metadata = {
            'time_seq': torch.zeros(batch_size, 4, dtype=torch.long, device=self.device),
            'view_seq': torch.tensor([[0, 1, 2, 3]] * batch_size, device=self.device),
            'side_seq': torch.tensor([[0, 0, 1, 1]] * batch_size, device=self.device),
        }
        return {
            'images': images.to(self.device),
            'risk_factors': risk_factors.to(self.device),
            'batch_metadata': batch_metadata
        }

    def augment_image(
        self,
        img: np.ndarray,
        flip_horizontal: bool = False,
        flip_vertical: bool = False,
        rotate_angle: float = 0,
        brightness_factor: float = 1.0,
        contrast_factor: float = 1.0
    ) -> np.ndarray:
        """
        Apply data augmentation to mammogram image.

        The result is always clipped to [0, 1], so the input is assumed to
        already be scaled to that range. The input array is not modified.

        Args:
            img: Input image.
            flip_horizontal: Apply horizontal flip.
            flip_vertical: Apply vertical flip.
            rotate_angle: Rotation angle in degrees about the image center
                (uncovered corners become 0).
            brightness_factor: Brightness adjustment; adds
                ``brightness_factor - 1.0`` to every pixel.
            contrast_factor: Contrast adjustment; multiplies every pixel
                (applied before the brightness shift).

        Returns:
            Augmented image.
        """
        img_aug = np.asarray(img)
        if flip_horizontal:
            img_aug = np.fliplr(img_aug)
        if flip_vertical:
            img_aug = np.flipud(img_aug)
        if rotate_angle != 0:
            h, w = img_aug.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            # Flipped views have negative strides; hand cv2 a contiguous copy.
            img_aug = cv2.warpAffine(np.ascontiguousarray(img_aug), M, (w, h))
        # Arithmetic creates a new array, so the caller's input is never mutated.
        img_aug = img_aug * contrast_factor + (brightness_factor - 1.0)
        return np.clip(img_aug, 0, 1)