"""Hugging Face pipeline implementation for MIRAI model."""

import math
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import Pipeline

from preprocessor import MiraiPreprocessor

# Five-year risk cut-offs, in percent, used by ``MiraiPipeline.postprocess``.
# The boundary between "Moderate Risk" and "High Risk" is the caller-supplied
# ``risk_threshold`` (default ``_DEFAULT_RISK_THRESHOLD``).
_LOW_RISK_CUTOFF = 1.67
_AVERAGE_RISK_CUTOFF = 3.0
_DEFAULT_RISK_THRESHOLD = 5.0

# Glob patterns tried, in order, to locate each standard view inside an exam
# directory. Whether matching is case-sensitive depends on ``Path.glob`` and
# the host platform.
_VIEW_FILE_PATTERNS = {
    'L-CC': ('*LCC*', '*L_CC*', '*left_cc*'),
    'L-MLO': ('*LMLO*', '*L_MLO*', '*left_mlo*'),
    'R-CC': ('*RCC*', '*R_CC*', '*right_cc*'),
    'R-MLO': ('*RMLO*', '*R_MLO*', '*right_mlo*'),
}

# Detailed recommendations returned by ``_generate_recommendations``.
# Any category not listed explicitly falls back to the high-risk list.
_HIGH_RISK_RECOMMENDATIONS = (
    "Strongly consider supplemental screening modalities",
    "Discuss chemoprevention options with oncologist",
    "Consider genetic testing for BRCA mutations",
    "Evaluate eligibility for high-risk screening programs",
    "Consider consultation with breast specialist",
)
_CATEGORY_RECOMMENDATIONS = {
    'Low Risk': (
        "Continue standard annual screening mammography",
        "Maintain healthy lifestyle habits",
        "Perform regular breast self-examinations",
    ),
    'Average Risk': (
        "Continue annual mammography screening",
        "Discuss family history with healthcare provider",
        "Consider lifestyle modifications to reduce risk",
    ),
    'Moderate Risk': (
        "Consider supplemental screening (MRI or ultrasound)",
        "Discuss risk reduction strategies with physician",
        "Consider genetic counseling if family history present",
        "Evaluate lifestyle factors that may be modified",
    ),
    'High Risk': _HIGH_RISK_RECOMMENDATIONS,
}


class MiraiPipeline(Pipeline):
    """
    Custom pipeline for MIRAI breast cancer risk prediction.

    This pipeline handles the complete inference workflow including image
    preprocessing, model inference, and result interpretation.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        """Initialize the MIRAI pipeline.

        Args:
            model: MIRAI model instance to run.
            tokenizer: Not used by this pipeline; accepted so the signature
                matches ``transformers.Pipeline``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Built after super().__init__ so it can use the resolved self.device.
        self.preprocessor = MiraiPreprocessor(device=self.device)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Recognised keys:
            preprocess:  ``target_size``, ``normalize``, ``risk_factors``
            forward:     ``return_all_years``
            postprocess: ``include_recommendations``, ``risk_threshold``

        Returns:
            Tuple ``(preprocess_params, forward_params, postprocess_params)``
            as required by ``transformers.Pipeline``.
        """
        def pick(keys):
            return {key: kwargs[key] for key in keys if key in kwargs}

        # ``risk_factors`` goes to ``preprocess`` rather than ``_forward``: it
        # must pass through ``prepare_risk_factors`` before reaching the model,
        # and ``_forward`` never read it (so the kwarg was silently ignored).
        preprocess_params = pick(('target_size', 'normalize', 'risk_factors'))
        forward_params = pick(('return_all_years',))
        postprocess_params = pick(('include_recommendations', 'risk_threshold'))
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, **preprocess_params):
        """
        Preprocess inputs for the model.

        Args:
            inputs: Can be:
                - Dict with 'images' key containing paths to mammogram views
                - Dict with 'exam_paths' and optional 'risk_factors'
                - Dict mapping view names directly to image paths (an optional
                  'risk_factors' entry is split off, not treated as a view)
                - Path to directory containing mammogram images
            **preprocess_params: Stage parameters from ``_sanitize_parameters``.
                ``risk_factors`` is used only when ``inputs`` does not carry
                its own ``'risk_factors'`` entry. ``target_size`` and
                ``normalize`` are accepted but are not currently forwarded to
                the preprocessor.

        Returns:
            Dict with ``'images'`` (batch of 1), ``'risk_factors'`` (batch of 1,
            or None) and ``'batch_metadata'`` tensors on the pipeline device.

        Raises:
            ValueError: If ``inputs`` is neither a ``str`` nor a ``dict``.
        """
        if isinstance(inputs, str):
            # Assume it's a directory path
            inputs = self._load_exam_from_directory(inputs)

        if not isinstance(inputs, dict):
            raise ValueError(f"Unsupported input type: {type(inputs)}")

        # Handle different input formats
        if 'images' in inputs:
            exam_paths = inputs['images']
        elif 'exam_paths' in inputs:
            exam_paths = inputs['exam_paths']
        else:
            # Raw view -> path mapping: keep 'risk_factors' out of the views.
            exam_paths = {
                view: path for view, path in inputs.items() if view != 'risk_factors'
            }

        image_tensor = self.preprocessor.load_mammogram_exam(exam_paths)

        # Risk factors: per-input entry wins over the call-level kwarg.
        risk_factors = None
        if 'risk_factors' in inputs:
            risk_factors = self.preprocessor.prepare_risk_factors(inputs['risk_factors'])
        elif preprocess_params.get('risk_factors') is not None:
            risk_factors = self.preprocessor.prepare_risk_factors(
                preprocess_params['risk_factors']
            )

        # Fixed metadata for a single four-view exam: views 0..3, sides 0,0,1,1.
        batch_metadata = {
            'time_seq': torch.zeros(1, 4).long().to(self.device),
            'view_seq': torch.tensor([[0, 1, 2, 3]]).to(self.device),
            'side_seq': torch.tensor([[0, 0, 1, 1]]).to(self.device),
        }

        return {
            'images': image_tensor.unsqueeze(0).to(self.device),
            'risk_factors': (
                risk_factors.unsqueeze(0).to(self.device)
                if risk_factors is not None
                else None
            ),
            'batch_metadata': batch_metadata,
        }

    def _forward(self, model_inputs, **forward_params):
        """
        Run model inference.

        Args:
            model_inputs: Dict produced by ``preprocess``.
            **forward_params: Stage parameters from ``_sanitize_parameters``.
                ``return_all_years`` is accepted but currently unused; all
                years are always returned.

        Returns:
            Dict with the model's ``'probabilities'``, ``'logits'`` and
            ``'risk_scores'`` outputs.
        """
        with torch.no_grad():
            outputs = self.model(
                images=model_inputs['images'],
                risk_factors=model_inputs['risk_factors'],
                batch_metadata=model_inputs['batch_metadata'],
                return_dict=True,
                output_hidden_states=False,
            )

        return {
            'probabilities': outputs.probabilities,
            'logits': outputs.logits,
            'risk_scores': outputs.risk_scores,
        }

    def postprocess(self, model_outputs, **postprocess_params):
        """
        Postprocess model outputs into interpretable results.

        Args:
            model_outputs: Dict returned by ``_forward``.
            **postprocess_params:
                include_recommendations (bool, default True): add a detailed
                    ``'recommendations'`` list when a 5-year risk exists.
                risk_threshold (float, percent, default 5.0): 5-year risk at or
                    above which the category is "High Risk". Values <= 3.0 mean
                    "Moderate Risk" can never be produced.

        Returns:
            Dict with ``'predictions'`` (per-year risk), ``'risk_assessment'``,
            ``'metadata'``, ``'confidence'`` and, optionally,
            ``'recommendations'``.
        """
        include_recommendations = postprocess_params.get('include_recommendations', True)
        risk_threshold = postprocess_params.get('risk_threshold', _DEFAULT_RISK_THRESHOLD)

        # First (only) item of the batch.
        probabilities = model_outputs['probabilities'][0].cpu().numpy()

        results = {
            'predictions': {},
            'risk_assessment': {},
            'metadata': {
                'model_version': '1.0.0',
                'prediction_type': 'breast_cancer_risk',
            },
        }

        # Year-by-year predictions (year index is 1-based).
        for year, probability in enumerate(probabilities, start=1):
            results['predictions'][f'year_{year}'] = {
                'risk_percentage': float(probability * 100),
                'risk_probability': float(probability),
            }

        # 5-year risk assessment (only when the model emits at least 5 years).
        if len(probabilities) >= 5:
            five_year_risk = float(probabilities[4] * 100)
            results['risk_assessment']['five_year_risk'] = five_year_risk

            if five_year_risk < _LOW_RISK_CUTOFF:
                category = 'Low Risk'
                recommendation = 'Standard annual screening'
            elif five_year_risk < _AVERAGE_RISK_CUTOFF:
                category = 'Average Risk'
                recommendation = 'Continue annual mammography'
            elif five_year_risk < risk_threshold:
                category = 'Moderate Risk'
                recommendation = 'Consider supplemental screening'
            else:
                category = 'High Risk'
                recommendation = 'Discuss risk reduction strategies with physician'

            results['risk_assessment']['category'] = category
            results['risk_assessment']['primary_recommendation'] = recommendation

            if include_recommendations:
                results['recommendations'] = self._generate_recommendations(
                    five_year_risk, category
                )

        results['confidence'] = self._calculate_confidence(model_outputs['logits'][0])

        return results

    def _generate_recommendations(self, five_year_risk: float, category: str) -> List[str]:
        """Return the detailed clinical recommendations for ``category``.

        Args:
            five_year_risk: 5-year risk in percent. Currently unused; kept for
                interface compatibility.
            category: Risk category label produced by ``postprocess``. Unknown
                labels receive the high-risk recommendations.

        Returns:
            A fresh list (callers may mutate it safely).
        """
        return list(_CATEGORY_RECOMMENDATIONS.get(category, _HIGH_RISK_RECOMMENDATIONS))

    def _calculate_confidence(self, logits: torch.Tensor) -> float:
        """Return an entropy-based confidence score in ``[0, 1]``.

        Each logit is treated as an independent Bernoulli output. The score is
        ``1 - H / H_max``, where ``H`` is the summed binary entropy of the
        sigmoid probabilities and ``H_max = n * ln 2`` is its maximum (every
        probability at 0.5). Fully saturated outputs give 1.0; all-0.5 outputs
        give 0.0. An empty ``logits`` tensor gives 0.0.
        """
        probs = torch.sigmoid(logits.detach().cpu().double())
        num_outputs = probs.numel()
        if num_outputs == 0:
            return 0.0

        # xlogy(x, x) is defined as 0 at x == 0, so saturated sigmoids
        # (exactly 0 or 1) do not produce NaN.
        entropy = -(torch.xlogy(probs, probs) + torch.xlogy(1 - probs, 1 - probs)).sum()
        max_entropy = num_outputs * math.log(2.0)
        confidence = 1.0 - entropy.item() / max_entropy
        return float(max(0.0, min(1.0, confidence)))

    def _load_exam_from_directory(self, directory_path: str) -> Dict[str, str]:
        """Map standard view names to image files found in ``directory_path``.

        For each view, the patterns in ``_VIEW_FILE_PATTERNS`` are tried in
        order. The first pattern with at least one matching regular file wins,
        and the lexicographically first match is used, so results are
        deterministic.

        Returns:
            Dict mapping any found views ('L-CC', 'L-MLO', 'R-CC', 'R-MLO') to
            file paths. Warns if nothing, or only some views, were found.
        """
        dir_path = Path(directory_path)
        exam_paths = {}

        for view, patterns in _VIEW_FILE_PATTERNS.items():
            for pattern in patterns:
                # Sort for determinism (glob order is unspecified); skip dirs.
                files = sorted(p for p in dir_path.glob(pattern) if p.is_file())
                if files:
                    exam_paths[view] = str(files[0])
                    break

        if not exam_paths:
            warnings.warn(f"No mammogram files found in {directory_path}")
        else:
            missing = [view for view in _VIEW_FILE_PATTERNS if view not in exam_paths]
            if missing:
                warnings.warn(
                    f"Missing mammogram views in {directory_path}: {', '.join(missing)}"
                )

        return exam_paths

    def __call__(self, inputs, **kwargs):
        """
        Run the complete pipeline.

        Args:
            inputs: Input data (images paths or processed tensors)
            **kwargs: Additional parameters for pipeline stages

        Returns:
            Risk prediction results
        """
        return super().__call__(inputs, **kwargs)


def create_mirai_pipeline(
    model_name_or_path: str = "Lab-Rasool/Mirai",
    device: Optional[str] = None,
    **kwargs
) -> MiraiPipeline:
    """
    Create a MIRAI pipeline for easy inference.

    Args:
        model_name_or_path: Model identifier or path
        device: Device for inference ('cuda', 'cpu', or None for auto)
        **kwargs: Additional pipeline parameters

    Returns:
        Configured MiraiPipeline instance. If the pretrained weights cannot be
        loaded, the model is built from its config with random initialization
        and a warning that includes the underlying error is emitted.

    Raises:
        Whatever ``MiraiConfig.from_pretrained`` raises if the config cannot be
        loaded either. The weight-loading error is chained as its context.
    """
    from modeling_mirai import MiraiModel
    from configuration_mirai import MiraiConfig

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    try:
        model = MiraiModel.from_pretrained(model_name_or_path)
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
        # Fallback to creating new model with config
        config = MiraiConfig.from_pretrained(model_name_or_path)
        model = MiraiModel(config)
        warnings.warn(
            "Could not load pretrained weights, using random initialization "
            f"({type(err).__name__}: {err})",
            stacklevel=2,
        )

    model = model.to(device)
    model.eval()

    mirai_pipeline = MiraiPipeline(
        model=model,
        device=device,
        **kwargs
    )

    return mirai_pipeline