"""MIRAI Model with exact original architecture for perfect weight loading.""" import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput import warnings import os import math from typing import Optional, Tuple, Dict, Any from dataclasses import dataclass try: from configuration_mirai import MiraiConfig except ImportError: from .configuration_mirai import MiraiConfig @dataclass class MiraiOutput(BaseModelOutput): """Output type for MIRAI model.""" logits: Optional[torch.FloatTensor] = None probabilities: Optional[torch.FloatTensor] = None hidden_states: Optional[torch.FloatTensor] = None class BasicBlock(nn.Module): """Basic ResNet block matching original MIRAI.""" expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super().__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Downsampler(nn.Module): """Downsampling layers for ResNet exactly matching original.""" def __init__(self, inplanes, num_chan=3): super().__init__() self.conv1 = nn.Conv2d(num_chan, inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) return x class MiraiEncoder(nn.Module): """Image encoder matching exact original MIRAI architecture.""" def __init__(self, config): super().__init__() self.config = config # Get config values num_risk_factors = config.risk_factors.get('num_risk_factors', 34) if hasattr(config, 'risk_factors') else 34 inplanes = 64 # Downsampler module - exactly as in original self.downsampler = Downsampler(inplanes, num_chan=3) # Layer 1 - using underscore notation as in original self.layer1_0 = BasicBlock(inplanes, 64) self.layer1_1 = BasicBlock(64, 64) # Layer 2 with downsample downsample2 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False), nn.BatchNorm2d(128) ) self.layer2_0 = BasicBlock(64, 128, stride=2, downsample=downsample2) self.layer2_1 = BasicBlock(128, 128) # Layer 3 with downsample downsample3 = nn.Sequential( nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False), nn.BatchNorm2d(256) ) self.layer3_0 = BasicBlock(128, 256, stride=2, downsample=downsample3) self.layer3_1 = BasicBlock(256, 256) # Layer 4 with downsample downsample4 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False), nn.BatchNorm2d(512) ) self.layer4_0 = BasicBlock(256, 512, stride=2, downsample=downsample4) self.layer4_1 = BasicBlock(512, 512) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Risk factor pool with all the individual FCs self.pool = nn.Module() # All risk factor FCs with correct dimensions from saved weights risk_factor_dims = { 'density': (4, 512), 'binary_family_history': (1, 512), 'binary_biopsy_benign': (1, 512), 'binary_biopsy_LCIS': (1, 512), 'binary_biopsy_atypical_hyperplasia': (1, 512), 'age': (6, 512), 'menarche_age': (5, 512), 'menopause_age': (5, 512), 'first_pregnancy_age': (6, 512), 'prior_hist': (1, 512), 'race': (13, 512), 'parous': (1, 512), 'menopausal_status': (4, 512), 'weight': (7, 512), 'height': (7, 512), 'ovarian_cancer': (1, 512), 'ovarian_cancer_age': (6, 512), 'ashkenazi': (1, 512), 'brca': (4, 512), 'mom_bc_cancer_history': (1, 512), 'm_aunt_bc_cancer_history': (1, 512), 'p_aunt_bc_cancer_history': (1, 512), 'm_grandmother_bc_cancer_history': (1, 512), 'p_grantmother_bc_cancer_history': (1, 512), 'sister_bc_cancer_history': (1, 512), 'mom_oc_cancer_history': (1, 512), 'm_aunt_oc_cancer_history': (1, 512), 'p_aunt_oc_cancer_history': (1, 512), 'm_grandmother_oc_cancer_history': (1, 512), 'p_grantmother_oc_cancer_history': (1, 512), 'sister_oc_cancer_history': (1, 512), 'hrt_type': (3, 512), 'hrt_duration': (5, 512), 'hrt_years_ago_stopped': (5, 512) } for rf_name, (out_dim, in_dim) in risk_factor_dims.items(): setattr(self.pool, f'{rf_name}_fc', nn.Linear(in_dim, out_dim)) # Probability of failure layer self.prob_of_failure_layer = nn.Module() self.prob_of_failure_layer.hazard_fc = nn.Linear(612, 5) # 5 time points self.prob_of_failure_layer.base_hazard_fc = nn.Linear(612, 1) # Register upper_triagular_mask as buffer, not parameter upper_triagular_mask = torch.ones(5, 5).triu() self.prob_of_failure_layer.register_buffer('upper_triagular_mask', upper_triagular_mask) # Final FC self.fc = nn.Linear(612, 2) # Input is 512 image + 100 risk factors def forward(self, x): # Downsampling x = self.downsampler(x) # ResNet layers x = self.layer1_0(x) x = self.layer1_1(x) x = self.layer2_0(x) x = self.layer2_1(x) x = self.layer3_0(x) x = self.layer3_1(x) x = self.layer4_0(x) x = self.layer4_1(x) # Global pooling x = self.avgpool(x) x = torch.flatten(x, 1) return x class MultiHead_Attention(nn.Module): """Original MIRAI MultiHead Attention.""" def __init__(self, hidden_dim, num_heads=8, dropout=0.1): super().__init__() self.hidden_dim = hidden_dim self.num_heads = num_heads assert hidden_dim % num_heads == 0 self.query = nn.Linear(hidden_dim, hidden_dim) self.value = nn.Linear(hidden_dim, hidden_dim) self.key = nn.Linear(hidden_dim, hidden_dim) self.dropout = nn.Dropout(p=dropout) self.dim_per_head = hidden_dim // num_heads self.aggregate_fc = nn.Linear(hidden_dim, hidden_dim) def forward(self, x): B, N, D = x.size() H = self.num_heads D_h = self.dim_per_head q = self.query(x).view(B, N, H, D_h).transpose(1, 2) # B x H x N x D_h k = self.key(x).view(B, N, H, D_h).transpose(1, 2) # B x H x N x D_h v = self.value(x).view(B, N, H, D_h).transpose(1, 2) # B x H x N x D_h # Attention scores scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D_h) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) # Apply attention context = torch.matmul(attn, v) # B x H x N x D_h context = context.transpose(1, 2).contiguous().view(B, N, D) output = self.aggregate_fc(context) return output class TransformerLayer(nn.Module): """Original MIRAI Transformer Layer.""" def __init__(self, hidden_dim, num_heads=8, dropout=0.1): super().__init__() self.multihead_attention = MultiHead_Attention(hidden_dim, num_heads, dropout) self.layernorm_attn = nn.LayerNorm(hidden_dim) self.fc1 = nn.Linear(hidden_dim, hidden_dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.layernorm_fc = nn.LayerNorm(hidden_dim) def forward(self, x): h = self.multihead_attention(x) x = self.layernorm_attn(h + x) h = self.fc2(self.relu(self.fc1(x))) x = self.layernorm_fc(h + x) return x class MiraiTransformer(nn.Module): """Transformer module matching exact original MIRAI architecture.""" def __init__(self, config): super().__init__() self.config = config # Get config values hidden_dim = config.transformer_config.get('hidden_dim', 512) if hasattr(config, 'transformer_config') else 512 num_heads = config.transformer_config.get('num_heads', 8) if hasattr(config, 'transformer_config') else 8 dropout = config.transformer_config.get('dropout', 0.1) if hasattr(config, 'transformer_config') else 0.1 num_classes = config.num_classes if hasattr(config, 'num_classes') else 5 # Kept images vector kept_images_vec = torch.ones([1, 4, 1]) self.register_buffer('kept_images_vec', kept_images_vec) # Layers matching original self.projection_layer = nn.Linear(hidden_dim, hidden_dim) self.mask_embedding = nn.Embedding(2, hidden_dim) # Transformer embeddings - exactly as original EMBEDDING_DIM = 96 self.transformer = nn.Module() self.transformer.time_embed = nn.Embedding(11, EMBEDDING_DIM // 3) self.transformer.view_embed = nn.Embedding(3, EMBEDDING_DIM // 3) # 0=CC, 1=MLO, 2=PAD self.transformer.side_embed = nn.Embedding(3, EMBEDDING_DIM // 3) # 0=L, 1=R, 2=PAD self.transformer.embed_add_fc = nn.Linear(EMBEDDING_DIM, hidden_dim) self.transformer.embed_scale_fc = nn.Linear(EMBEDDING_DIM, hidden_dim) # Transformer layer 0 - exactly as original self.transformer.transformer_layer_0 = TransformerLayer(hidden_dim, num_heads, dropout) self.pred_masked_img_fc = nn.Linear(hidden_dim, hidden_dim) # Risk factor pool self.pool = nn.Module() self.pool.internal_pool = nn.Module() self.pool.internal_pool.attention_fc = nn.Linear(hidden_dim, 1) # All risk factor FCs with correct dimensions risk_factor_dims = { 'density': (4, 512), 'binary_family_history': (1, 512), 'binary_biopsy_benign': (1, 512), 'binary_biopsy_LCIS': (1, 512), 'binary_biopsy_atypical_hyperplasia': (1, 512), 'age': (6, 512), 'menarche_age': (5, 512), 'menopause_age': (5, 512), 'first_pregnancy_age': (6, 512), 'prior_hist': (1, 512), 'race': (13, 512), 'parous': (1, 512), 'menopausal_status': (4, 512), 'weight': (7, 512), 'height': (7, 512), 'ovarian_cancer': (1, 512), 'ovarian_cancer_age': (6, 512), 'ashkenazi': (1, 512), 'brca': (4, 512), 'mom_bc_cancer_history': (1, 512), 'm_aunt_bc_cancer_history': (1, 512), 'p_aunt_bc_cancer_history': (1, 512), 'm_grandmother_bc_cancer_history': (1, 512), 'p_grantmother_bc_cancer_history': (1, 512), 'sister_bc_cancer_history': (1, 512), 'mom_oc_cancer_history': (1, 512), 'm_aunt_oc_cancer_history': (1, 512), 'p_aunt_oc_cancer_history': (1, 512), 'm_grandmother_oc_cancer_history': (1, 512), 'p_grantmother_oc_cancer_history': (1, 512), 'sister_oc_cancer_history': (1, 512), 'hrt_type': (3, 512), 'hrt_duration': (5, 512), 'hrt_years_ago_stopped': (5, 512) } for rf_name, (out_dim, in_dim) in risk_factor_dims.items(): setattr(self.pool, f'{rf_name}_fc', nn.Linear(in_dim, out_dim)) self.fc = nn.Linear(612, 2) # Takes concatenated features + risk factors # Probability of failure layer self.prob_of_failure_layer = nn.Module() self.prob_of_failure_layer.hazard_fc = nn.Linear(612, 5) # 5 time points self.prob_of_failure_layer.base_hazard_fc = nn.Linear(612, 1) # Register as buffer, not parameter upper_triagular_mask = torch.ones(5, 5).triu() self.prob_of_failure_layer.register_buffer('upper_triagular_mask', upper_triagular_mask) def forward(self, hidden_states, batch_metadata=None, risk_factors=None): B, N, D = hidden_states.size() # Project hidden_states = self.projection_layer(hidden_states) # Get metadata if batch_metadata is not None: time_seq = batch_metadata.get('time_seq', torch.zeros(B, N).long()) view_seq = batch_metadata.get('view_seq', torch.zeros(B, N).long()) side_seq = batch_metadata.get('side_seq', torch.zeros(B, N).long()) # Add embeddings view = self.transformer.view_embed(view_seq) time = self.transformer.time_embed(time_seq) side = self.transformer.side_embed(side_seq) pos_embed = torch.cat([view, time, side], dim=-1) # Condition on positional embeddings hidden_states = (self.transformer.embed_scale_fc(pos_embed) * hidden_states + self.transformer.embed_add_fc(pos_embed)) # Transformer layer hidden_states = self.transformer.transformer_layer_0(hidden_states) # Pool pooled = hidden_states.mean(dim=1) # Concatenate risk factors if provided, otherwise use zeros # The FC layer expects 612 dims total: 512 image + 100 risk factors risk_factor_dims = 100 if risk_factors is not None: # Pad risk factors to 100 dimensions if needed current_rf_dims = risk_factors.shape[-1] if current_rf_dims < risk_factor_dims: # Pad with zeros to reach 100 dimensions padding = torch.zeros(B, risk_factor_dims - current_rf_dims, device=risk_factors.device, dtype=risk_factors.dtype) risk_factors = torch.cat([risk_factors, padding], dim=-1) elif current_rf_dims > risk_factor_dims: # Truncate if somehow we have more than 100 risk_factors = risk_factors[:, :risk_factor_dims] pooled = torch.cat([pooled, risk_factors], dim=-1) else: # Use zeros for risk factors if not provided (100 dims) dummy_risk_factors = torch.zeros(B, risk_factor_dims, device=pooled.device, dtype=pooled.dtype) pooled = torch.cat([pooled, dummy_risk_factors], dim=-1) # Final FC - now pooled is 612-dimensional (512 + 100) logits = self.fc(pooled) return logits, pooled class MiraiModel(PreTrainedModel): """Main MIRAI model with exact original architecture.""" config_class = MiraiConfig base_model_prefix = "mirai" def __init__(self, config): super().__init__(config) self.config = config self.encoder = MiraiEncoder(config) self.transformer = MiraiTransformer(config) # Initialize weights self.post_init() def forward( self, images: torch.FloatTensor, risk_factors: Optional[torch.FloatTensor] = None, batch_metadata: Optional[Dict[str, torch.Tensor]] = None, return_dict: bool = True ): batch_size, num_views, channels, height, width = images.shape # Reshape to process all views at once images_flat = images.view(batch_size * num_views, channels, height, width) # Encode all views encoded = self.encoder(images_flat) # Reshape back to [batch_size, num_views, hidden_size] hidden_states = encoded.view(batch_size, num_views, -1) # Apply transformer - pass risk_factors through logits, pooled = self.transformer(hidden_states, batch_metadata, risk_factors) # Apply sigmoid for probabilities probabilities = torch.sigmoid(logits) if return_dict: return MiraiOutput( logits=logits, probabilities=probabilities, hidden_states=pooled ) return (logits, probabilities, pooled) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): """Load pretrained weights.""" config = kwargs.pop('config', None) if config is None: config = MiraiConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) # Initialize model model = cls(config) # Try to load weights try: # Load pytorch_model.bin model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") if os.path.exists(model_path): state_dict = torch.load(model_path, map_location='cpu') # The weights should now match exactly missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) if not missing_keys and not unexpected_keys: print(f"Successfully loaded all {len(state_dict)} weights from {pretrained_model_name_or_path}") else: loaded = len(state_dict) - len(unexpected_keys) if missing_keys: print(f"Missing {len(missing_keys)} keys") print(f"First few missing: {missing_keys[:5]}") if unexpected_keys: print(f"Unexpected {len(unexpected_keys)} keys") print(f"First few unexpected: {unexpected_keys[:5]}") print(f"Loaded {loaded}/{len(model.state_dict())} weights") else: warnings.warn(f"No weights found at {pretrained_model_name_or_path}. Using random initialization.") except Exception as e: warnings.warn(f"Error loading weights: {e}. Using random initialization.") return model def save_pretrained(self, save_directory: str, **kwargs): """Save the model to a directory.""" super().save_pretrained(save_directory, **kwargs)