#!/usr/bin/env python3
"""
Example inference script for MIRAI breast cancer risk prediction model.

This script demonstrates how to use the MIRAI model from Hugging Face Hub
for breast cancer risk prediction using mammography images.

All examples feed randomly generated tensors in place of real mammograms,
so the printed risks are meaningless; they only illustrate the API.
"""

import functools
import warnings
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import torch

# Import model components
from modeling_mirai import MiraiModel
from configuration_mirai import MiraiConfig
from preprocessor import MiraiPreprocessor
from pipeline import create_mirai_pipeline


# Hugging Face Hub repository every example loads from.
MODEL_ID = "Lab-Rasool/Mirai"

# Shape of one dummy exam tensor, batch dimension excluded (copied from the
# original examples). Presumably (channels, views, height, width) -- TODO
# confirm against MiraiModel's expected input layout.
DUMMY_EXAM_SHAPE = (3, 4, 1664, 2048)

# Index of the 5-year horizon in a per-year probability vector
# (index 0 is year 1, matching the "Year N" labels printed below).
FIVE_YEAR_INDEX = 4

# Width, in characters, of the text bar chart in the risk report.
REPORT_BAR_WIDTH = 20


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the MIRAI model once in eval mode and reuse it across examples.

    Cached so running several examples does not reload the weights each time.
    A failed load raises and is not cached, so the next caller retries.
    """
    model = MiraiModel.from_pretrained(MODEL_ID)
    model.eval()
    return model


def _dummy_images(batch_size=1):
    """Return a random tensor standing in for ``batch_size`` mammogram exams."""
    return torch.randn(batch_size, *DUMMY_EXAM_SHAPE)


def _five_year_risk_category(five_year_risk_pct):
    """Map a 5-year risk percentage to this demo's coarse risk category.

    The thresholds are the illustrative values this example script uses;
    they are not presented here as a clinical guideline.
    """
    if five_year_risk_pct < 1.67:
        return "Low Risk"
    if five_year_risk_pct < 3.0:
        return "Average Risk"
    if five_year_risk_pct < 5.0:
        return "Moderate Risk"
    return "High Risk"


def example_basic_inference():
    """Basic example of running MIRAI model inference."""
    print("\n" + "=" * 80)
    print("EXAMPLE 1: Basic Inference")
    print("=" * 80)

    # Load model
    print("\nLoading MIRAI model...")
    model = _load_model()

    # Create preprocessor
    preprocessor = MiraiPreprocessor()

    # In practice, load actual mammogram images for the four views
    # 'L-CC', 'L-MLO', 'R-CC' and 'R-MLO'. For demonstration, a random
    # tensor stands in for the exam.
    print("\nCreating example mammogram data...")
    images = _dummy_images(1)

    # Optional: add risk factors; unsqueeze(0) adds the batch dimension.
    risk_factors = preprocessor.prepare_risk_factors({
        'age': 50,
        'density': 2,
        'family_history': False,
    }).unsqueeze(0)

    # Run inference
    print("\nRunning model inference...")
    with torch.no_grad():
        outputs = model(
            images=images,
            risk_factors=risk_factors,
            return_dict=True,
        )

    # Display results
    print("\nResults:")
    print("-" * 40)
    # .cpu() keeps .numpy() working if the model was moved to a GPU.
    probabilities = outputs.probabilities[0].cpu().numpy()
    for year, probability in enumerate(probabilities, 1):
        print(f"Year {year} risk: {probability * 100:.2f}%")

    if outputs.risk_scores:
        category = outputs.risk_scores.get('risk_category', 'Unknown')
        print(f"\n5-Year Risk Category: {category}")


def example_pipeline_inference():
    """Example using the high-level pipeline interface."""
    print("\n" + "=" * 80)
    print("EXAMPLE 2: Pipeline Interface")
    print("=" * 80)

    # Create pipeline
    print("\nCreating MIRAI pipeline...")
    mirai_pipeline = create_mirai_pipeline(MODEL_ID)

    # Prepare input. In practice, provide actual file paths.
    input_data = {
        'exam_paths': {
            'L-CC': 'path/to/left_cc.png',
            'L-MLO': 'path/to/left_mlo.png',
            'R-CC': 'path/to/right_cc.png',
            'R-MLO': 'path/to/right_mlo.png',
        },
        'risk_factors': {
            'age': 55,
            'density': 3,
            'family_history': True,
            'weight': 70,
            'height': 165,
        },
    }

    # For demonstration, every view is None; presumably the pipeline then
    # substitutes dummy data or fails -- both outcomes are reported below.
    print("\nRunning pipeline with example data...")
    dummy_input = {
        'images': {
            'L-CC': None,
            'L-MLO': None,
            'R-CC': None,
            'R-MLO': None,
        },
        'risk_factors': input_data['risk_factors'],
    }

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mirai_pipeline(dummy_input)
    except Exception as e:  # demo: report the failure instead of aborting
        print(f"Pipeline execution note: {e}")
        print("This is expected when running without actual image files.")
        return

    # Display results
    print("\nPipeline Results:")
    print("-" * 40)

    # Year-by-year predictions
    if 'predictions' in results:
        for year_key, year_data in results['predictions'].items():
            print(f"{year_key}: {year_data['risk_percentage']:.2f}%")

    # Risk assessment
    if 'risk_assessment' in results:
        assessment = results['risk_assessment']
        print(f"\n5-Year Risk: {assessment.get('five_year_risk', 0):.2f}%")
        print(f"Risk Category: {assessment.get('category', 'Unknown')}")
        print(f"Recommendation: {assessment.get('primary_recommendation', 'N/A')}")

    # Detailed recommendations
    if 'recommendations' in results:
        print("\nDetailed Recommendations:")
        for i, rec in enumerate(results['recommendations'], 1):
            print(f"  {i}. {rec}")


def example_batch_processing():
    """Example of processing multiple patients in batch."""
    print("\n" + "=" * 80)
    print("EXAMPLE 3: Batch Processing")
    print("=" * 80)

    # Load model
    model = _load_model()

    # Create preprocessor
    preprocessor = MiraiPreprocessor()

    # Different risk factors for each patient; the batch size follows this list.
    risk_factors_list = [
        {'age': 45, 'density': 2, 'family_history': False},
        {'age': 60, 'density': 3, 'family_history': True},
        {'age': 55, 'density': 1, 'family_history': False},
    ]
    batch_size = len(risk_factors_list)
    print(f"\nProcessing batch of {batch_size} patients...")

    # Create batch data (dummy for demonstration)
    batch_images = _dummy_images(batch_size)

    # Stack per-patient risk-factor tensors into one batch along a new dim 0.
    batch_risk_factors = torch.stack(
        [preprocessor.prepare_risk_factors(rf) for rf in risk_factors_list]
    )

    # Run batch inference
    print("\nRunning batch inference...")
    with torch.no_grad():
        outputs = model(
            images=batch_images,
            risk_factors=batch_risk_factors,
            return_dict=True,
        )

    # Display results for each patient
    probabilities = outputs.probabilities.cpu().numpy()
    for patient_number, patient_probs in enumerate(probabilities, 1):
        print(f"\nPatient {patient_number}:")
        print("-" * 20)

        # A missing 5-year horizon must not be reported as 0% / "Low Risk".
        if len(patient_probs) <= FIVE_YEAR_INDEX:
            print("  5-Year Risk: N/A (model returned fewer than 5 yearly horizons)")
            continue

        five_year_risk = patient_probs[FIVE_YEAR_INDEX] * 100
        print(f"  5-Year Risk: {five_year_risk:.2f}%")
        print(f"  Category: {_five_year_risk_category(five_year_risk)}")


def example_risk_report_generation():
    """Example of generating comprehensive risk reports."""
    print("\n" + "=" * 80)
    print("EXAMPLE 4: Risk Report Generation")
    print("=" * 80)

    # Load model
    model = _load_model()

    # Prepare patient data
    patient_id = "DEMO_001"
    print(f"\nGenerating risk report for patient: {patient_id}")

    # Create dummy mammogram data
    images = _dummy_images(1)

    # Patient risk factors, written out by hand.
    # NOTE(review): unlike the other examples this bypasses
    # MiraiPreprocessor.prepare_risk_factors; the ordering and normalisation
    # below are assumed to match what the model expects -- confirm.
    risk_factors = torch.tensor([
        0.25,  # density (BI-RADS 2)
        0.0,   # no family history
        0.0,   # no benign biopsy
        0.0,   # no LCIS
        0.0,   # no atypical hyperplasia
        0.55,  # age (55 years, normalized)
        0.12,  # menarche age
        0.0,   # menopause age (pre-menopausal)
        0.0,   # first pregnancy age
        0.0,   # no prior history
        1.0,   # race (categorical)
        1.0,   # parous
        0.0,   # menopausal status
        0.70,  # weight (normalized)
        0.65,  # height (normalized)
        0.0,   # no ovarian cancer
        0.0,   # ovarian cancer age
        0.0,   # not Ashkenazi
        0.0,   # no BRCA mutation
        0.0,   # mother BC history
        0.0,   # maternal aunt BC
        0.0,   # paternal aunt BC
        0.0,   # maternal grandmother BC
        0.0,   # paternal grandmother BC
        0.0,   # sister BC
        0.0,   # mother OC history
        0.0,   # maternal aunt OC
        0.0,   # paternal aunt OC
        0.0,   # maternal grandmother OC
        0.0,   # paternal grandmother OC
        0.0,   # sister OC
        0.0,   # HRT type
        0.0,   # HRT duration
        0.0,   # HRT years ago stopped
    ]).unsqueeze(0)

    # Generate report (inference only, so no autograd bookkeeping is needed).
    print("\nGenerating comprehensive risk assessment...")
    with torch.no_grad():
        report = model.generate_risk_report(
            images=images,
            risk_factors=risk_factors,
            patient_id=patient_id,
        )

    # Display report
    print("\n" + "=" * 60)
    print("         BREAST CANCER RISK ASSESSMENT REPORT")
    print("=" * 60)
    print(f"\nPatient ID: {report['patient_id']}")
    print("-" * 60)

    print("\nRisk Predictions by Year:")
    # Each bar character stands for this many percentage points (100% = full bar).
    percent_per_char = 100 / REPORT_BAR_WIDTH
    for year, risk in report['predictions'].items():
        year_num = year.split('_')[1]  # assumes keys like "year_1"
        # Clamp so out-of-range risks cannot overflow or underflow the bar.
        bar_length = max(0, min(REPORT_BAR_WIDTH, int(risk / percent_per_char)))
        bar = '█' * bar_length + '░' * (REPORT_BAR_WIDTH - bar_length)
        print(f"  Year {year_num}: [{bar}] {risk:.2f}%")

    print(f"\nOverall Risk Category: {report['risk_category']}")
    print(f"Model Confidence: {report['confidence']:.2%}")

    print("\nClinical Recommendations:")
    for i, rec in enumerate(report['recommendations'], 1):
        print(f"  {i}. {rec}")

    print("\n" + "=" * 60)


def main():
    """Run all examples, reporting (not raising) each example's failure."""
    print("\n" + "=" * 80)
    print(" MIRAI MODEL - HUGGING FACE INFERENCE EXAMPLES")
    print("=" * 80)
    print("\nThese examples demonstrate how to use the MIRAI breast cancer")
    print("risk prediction model from Hugging Face Hub.")
    print("\nNote: Using random data for demonstration purposes.")
    print("In practice, use actual mammogram images and patient data.")

    examples = (
        example_basic_inference,
        example_pipeline_inference,
        example_batch_processing,
        example_risk_report_generation,
    )
    for number, example in enumerate(examples, 1):
        try:
            example()
        except Exception as e:  # keep going so later examples still run
            print(f"\nExample {number} note: {e}")

    print("\n" + "=" * 80)
    print("Examples completed!")
    print("=" * 80)


if __name__ == "__main__":
    main()