|
"""PyTorch Sybil model for lung cancer risk prediction""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torchvision |
|
from transformers import PreTrainedModel |
|
from transformers.modeling_outputs import BaseModelOutput |
|
from typing import Optional, Dict, List, Tuple |
|
import numpy as np |
|
from dataclasses import dataclass |
|
|
|
try: |
|
from .configuration_sybil import SybilConfig |
|
except ImportError: |
|
from configuration_sybil import SybilConfig |
|
|
|
|
|
@dataclass |
|
class SybilOutput(BaseModelOutput): |
|
""" |
|
Base class for Sybil model outputs. |
|
|
|
Args: |
|
risk_scores: (`torch.FloatTensor` of shape `(batch_size, max_followup)`): |
|
Predicted risk scores for each year up to max_followup. |
|
        image_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices, height, width)`, *optional*):
            Attention weights over spatial positions within each slice, at the
            encoder's downsampled resolution.
        volume_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices, height, width)`, *optional*):
            Attention weights over the whole CT volume (slices and spatial
            positions), at the encoder's downsampled resolution.
|
hidden_states: (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*): |
|
Hidden states from the pooling layer. |
|
""" |
|
risk_scores: torch.FloatTensor = None |
|
image_attention: Optional[torch.FloatTensor] = None |
|
volume_attention: Optional[torch.FloatTensor] = None |
|
hidden_states: Optional[torch.FloatTensor] = None |
|
|
|
|
|
class CumulativeProbabilityLayer(nn.Module): |
|
"""Cumulative probability layer for survival prediction""" |
|
|
|
def __init__(self, hidden_dim: int, max_followup: int = 6): |
|
super().__init__() |
|
self.max_followup = max_followup |
|
self.fc = nn.Linear(hidden_dim, max_followup) |
|
|
|
    def forward(self, x):
        # Per-year hazard probabilities.
        hazards = torch.sigmoid(self.fc(x))
        # The cumulative sum across follow-up years makes the scores
        # monotonically non-decreasing; dividing by max_followup keeps them
        # in [0, 1], so the output is already a probability-like score.
        return torch.cumsum(hazards, dim=-1) / self.max_followup
|
|
|
|
|
class MultiAttentionPool(nn.Module): |
|
"""Multi-attention pooling layer for CT scan aggregation""" |
|
|
|
def __init__(self, channels: int = 512): |
|
super().__init__() |
|
self.channels = channels |
|
|
|
|
|
self.volume_attention = nn.Sequential( |
|
nn.Conv3d(channels, 128, kernel_size=1), |
|
nn.ReLU(), |
|
nn.Conv3d(128, 1, kernel_size=1) |
|
) |
|
|
|
|
|
self.image_attention = nn.Sequential( |
|
nn.Conv3d(channels, 128, kernel_size=1), |
|
nn.ReLU(), |
|
nn.Conv3d(128, 1, kernel_size=1) |
|
) |
|
|
|
def forward(self, x): |
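        """Pool encoder features with volume- and image-level attention.

        Args:
            x: Feature tensor of shape `(batch, channels, depth, height, width)`.

        Returns:
            Dict with `'hidden'` of shape `(batch, channels)` plus attention maps
            `'volume_attention_1'` and `'image_attention_1'`, each of shape
            `(batch, depth, height, width)`.
        """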
|
batch_size = x.shape[0] |
|
|
|
|
|
        # Per-voxel attention scores, each of shape (batch, 1, depth, height, width).
        volume_att = self.volume_attention(x)
        image_att = self.image_attention(x)

        # Volume attention: softmax over every voxel of the volume.
        volume_att_flat = volume_att.view(batch_size, -1)
        volume_att_weights = torch.softmax(volume_att_flat, dim=-1).view_as(volume_att)

        # Image attention: softmax over the spatial positions of each slice,
        # done in one vectorized call instead of a per-slice loop with
        # in-place writes (which is slower and unfriendly to autograd).
        d, h, w = image_att.shape[2:]
        image_att_weights = torch.softmax(
            image_att.view(batch_size, 1, d, h * w), dim=-1
        ).view(batch_size, 1, d, h, w)

        # Weight the features by both attention maps, then pool to a vector.
        attended = x * volume_att_weights * image_att_weights
        hidden = attended.mean(dim=[2, 3, 4])

        return {
            'hidden': hidden,
            'volume_attention_1': volume_att_weights.squeeze(1),
            'image_attention_1': image_att_weights.squeeze(1)
        }
|
|
|
|
|
class SybilPreTrainedModel(PreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface |
|
for downloading and loading pretrained models. |
|
""" |
|
config_class = SybilConfig |
|
base_model_prefix = "sybil" |
|
supports_gradient_checkpointing = False |
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Conv3d): |
|
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
|
|
|
|
class SybilForRiskPrediction(SybilPreTrainedModel): |
|
""" |
|
Sybil model for lung cancer risk prediction from CT scans. |
|
|
|
This model takes 3D CT scan volumes as input and predicts cancer risk scores |
|
for multiple future time points (typically 1-6 years). |
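
    Example (a minimal sketch; the input shape is illustrative and the output
    shape assumes the config defaults `hidden_dim=512` and `max_followup=6`):

        >>> config = SybilConfig()
        >>> model = SybilForRiskPrediction(config).eval()
        >>> volume = torch.randn(1, 3, 32, 128, 128)  # (batch, channels, depth, height, width)
        >>> with torch.no_grad():
        ...     outputs = model(pixel_values=volume)
        >>> outputs.risk_scores.shape
        torch.Size([1, 6])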
|
""" |
|
|
|
def __init__(self, config: SybilConfig): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
|
|
        # 3D ResNet-18 video backbone (Kinetics-400 pretrained, legacy
        # `pretrained=True` API); drop the final avgpool and fc so the encoder
        # outputs features of shape (batch, 512, depth', height', width').
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])
|
|
|
|
|
self.pool = MultiAttentionPool(channels=512) |
|
|
|
|
|
self.relu = nn.ReLU(inplace=False) |
|
self.dropout = nn.Dropout(p=config.dropout) |
|
|
|
|
|
self.prob_of_failure_layer = CumulativeProbabilityLayer( |
|
config.hidden_dim, |
|
max_followup=config.max_followup |
|
) |
|
|
|
|
|
self.calibrator = None |
|
if config.calibrator_data: |
|
self.set_calibrator(config.calibrator_data) |
|
|
|
|
|
self.post_init() |
|
|
|
def set_calibrator(self, calibrator_data: Dict): |
|
"""Set calibration data for risk score adjustment""" |
|
self.calibrator = calibrator_data |
|
|
|
def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor: |
|
"""Apply calibration to raw risk scores""" |
|
if self.calibrator is None: |
|
return scores |
|
|
|
|
|
        # Calibration is an inference-time adjustment applied outside the
        # autograd graph, so gradients do not flow through the calibrated scores.
        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)
|
|
|
|
|
for year in range(scores_np.shape[1]): |
|
year_key = f"Year{year + 1}" |
|
if year_key in self.calibrator: |
|
|
|
calibrated[:, year] = self._apply_calibration( |
|
scores_np[:, year], |
|
self.calibrator[year_key] |
|
) |
|
else: |
|
calibrated[:, year] = scores_np[:, year] |
|
|
|
return torch.from_numpy(calibrated).to(scores.device) |
|
|
|
    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply a year-specific calibration transformation.

        Placeholder: the scores are currently returned unchanged; a fitted
        calibration mapping stored in `calibrator_params` would be applied here.
        """
        return scores
|
|
|
def forward( |
|
self, |
|
pixel_values: torch.FloatTensor, |
|
return_attentions: bool = False, |
|
return_dict: bool = True, |
|
) -> SybilOutput: |
|
""" |
|
Forward pass of the Sybil model. |
|
|
|
Args: |
|
pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`): |
|
Pixel values of CT scan volumes. |
|
return_attentions: (`bool`, *optional*, defaults to `False`): |
|
Whether to return attention weights. |
|
return_dict: (`bool`, *optional*, defaults to `True`): |
|
Whether to return a `SybilOutput` instead of a plain tuple. |
|
|
|
        Returns:
            `SybilOutput` if `return_dict=True`, otherwise a tuple containing the
            risk scores (plus the attention maps when `return_attentions=True`).
|
""" |
|
|
|
features = self.image_encoder(pixel_values) |
|
|
|
|
|
pool_output = self.pool(features) |
|
|
|
|
|
hidden = self.relu(pool_output['hidden']) |
|
hidden = self.dropout(hidden) |
|
|
|
|
|
        # The cumulative probability layer already returns scores in [0, 1],
        # so no additional sigmoid is applied here.
        risk_scores = self.prob_of_failure_layer(hidden)
|
|
|
|
|
risk_scores = self._calibrate_scores(risk_scores) |
|
|
|
if not return_dict: |
|
outputs = (risk_scores,) |
|
if return_attentions: |
|
outputs = outputs + (pool_output.get('image_attention_1'), |
|
pool_output.get('volume_attention_1')) |
|
return outputs |
|
|
|
return SybilOutput( |
|
risk_scores=risk_scores, |
|
image_attention=pool_output.get('image_attention_1') if return_attentions else None, |
|
volume_attention=pool_output.get('volume_attention_1') if return_attentions else None, |
|
hidden_states=hidden if return_attentions else None |
|
) |
|
|
|
@classmethod |
|
def from_pretrained_ensemble( |
|
cls, |
|
pretrained_model_name_or_path, |
|
checkpoint_paths: List[str], |
|
calibrator_path: Optional[str] = None, |
|
**kwargs |
|
): |
|
""" |
|
Load an ensemble of Sybil models from checkpoints. |
|
|
|
Args: |
|
pretrained_model_name_or_path: Path to the pretrained model or model identifier. |
|
checkpoint_paths: List of paths to individual model checkpoints. |
|
calibrator_path: Path to calibration data. |
|
**kwargs: Additional keyword arguments for model initialization. |
|
|
|
Returns: |
|
SybilEnsemble: An ensemble of Sybil models. |
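
        Example (the paths below are hypothetical placeholders):

            >>> ensemble = SybilForRiskPrediction.from_pretrained_ensemble(
            ...     "path/to/model_dir",
            ...     checkpoint_paths=["path/to/ckpt_0.ckpt", "path/to/ckpt_1.ckpt"],
            ...     calibrator_path="path/to/calibrator.json",
            ... )
            >>> ensemble.eval()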
|
""" |
|
config = kwargs.pop("config", None) |
|
if config is None: |
|
config = SybilConfig.from_pretrained(pretrained_model_name_or_path) |
|
|
|
|
|
calibrator_data = None |
|
if calibrator_path: |
|
import json |
|
with open(calibrator_path, 'r') as f: |
|
calibrator_data = json.load(f) |
|
config.calibrator_data = calibrator_data |
|
|
|
|
|
models = [] |
|
for checkpoint_path in checkpoint_paths: |
|
model = cls(config) |
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
|
|
|
            # Strip the leading "model." prefix so keys match this module's
            # attribute names.
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('model.'):
                    state_dict[k[len('model.'):]] = v
                else:
                    state_dict[k] = v
|
|
|
|
|
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            # strict=False tolerates keys that do not line up exactly between
            # the checkpoint and this module.
            model.load_state_dict(mapped_state_dict, strict=False)
|
models.append(model) |
|
|
|
return SybilEnsemble(models, config) |
|
|
|
    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Map original Sybil checkpoint weights to this module's structure.

        The submodule names already line up, so known keys are copied through
        unchanged and any remaining keys are dropped.
        """
        mapped = {}
        for k, v in state_dict.items():
            if k.startswith('image_encoder'):
                mapped[k] = v
            elif k.startswith('pool'):
                # Attention-pooling weights keep the same parameter names.
                mapped[k] = v
            elif k.startswith('prob_of_failure_layer'):
                # Cumulative probability head weights keep the same names.
                mapped[k] = v
        return mapped
|
|
|
|
|
class SybilEnsemble: |
|
"""Ensemble of Sybil models for improved predictions""" |
|
|
|
def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig): |
|
self.models = models |
|
self.config = config |
|
self.device = None |
|
|
|
def to(self, device): |
|
"""Move all models to device""" |
|
self.device = device |
|
for model in self.models: |
|
model.to(device) |
|
return self |
|
|
|
def eval(self): |
|
"""Set all models to evaluation mode""" |
|
for model in self.models: |
|
model.eval() |
|
|
|
def __call__( |
|
self, |
|
pixel_values: torch.FloatTensor, |
|
return_attentions: bool = False, |
|
) -> SybilOutput: |
|
""" |
|
Run inference with ensemble voting. |
|
|
|
Args: |
|
pixel_values: Input CT scan volumes. |
|
return_attentions: Whether to return attention maps. |
|
|
|
Returns: |
|
SybilOutput with averaged predictions from all models. |
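
        Example (a minimal sketch, continuing from an ensemble built with
        `from_pretrained_ensemble`; the input shape is illustrative):

            >>> volume = torch.randn(1, 3, 32, 128, 128)
            >>> output = ensemble(pixel_values=volume, return_attentions=True)
            >>> scores = output.risk_scores       # mean risk across members, (batch, max_followup)
            >>> attn = output.volume_attention    # mean attention map across members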
|
""" |
|
all_risk_scores = [] |
|
all_image_attentions = [] |
|
all_volume_attentions = [] |
|
|
|
with torch.no_grad(): |
|
for model in self.models: |
|
output = model( |
|
pixel_values=pixel_values, |
|
return_attentions=return_attentions |
|
) |
|
all_risk_scores.append(output.risk_scores) |
|
|
|
if return_attentions: |
|
all_image_attentions.append(output.image_attention) |
|
all_volume_attentions.append(output.volume_attention) |
|
|
|
|
|
        # Average predictions across the ensemble members.
        risk_scores = torch.stack(all_risk_scores).mean(dim=0)
|
|
|
|
|
image_attention = None |
|
volume_attention = None |
|
if return_attentions: |
|
image_attention = torch.stack(all_image_attentions).mean(dim=0) |
|
volume_attention = torch.stack(all_volume_attentions).mean(dim=0) |
|
|
|
return SybilOutput( |
|
risk_scores=risk_scores, |
|
image_attention=image_attention, |
|
volume_attention=volume_attention |
|
) |