"""
Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model.

This version works directly from HF without requiring external Sybil package.
"""

import os
import json
import sys
import torch
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from transformers.modeling_outputs import BaseModelOutput
from safetensors.torch import load_file

# Add model path to sys.path for imports, so the sibling modules can be
# imported by bare name when this file is not loaded as part of a package.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    from .configuration_sybil import SybilConfig
    from .modeling_sybil import SybilForRiskPrediction
    from .image_processing_sybil import SybilImageProcessor
except ImportError:
    # Fallback for flat (non-package) loading; relies on the sys.path entry above.
    from configuration_sybil import SybilConfig
    from modeling_sybil import SybilForRiskPrediction
    from image_processing_sybil import SybilImageProcessor


@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model predictions.

    Args:
        risk_scores: Risk scores for each year (1-6 years by default)
        attentions: Optional attention maps if requested
    """

    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None


class SybilHFWrapper:
    """
    Hugging Face wrapper for Sybil ensemble model.

    Provides a simple interface for lung cancer risk prediction from CT scans.
    All model files (weights, calibrator) are looked up relative to the
    directory that contains this source file.
    """

    def __init__(self, config: Optional[SybilConfig] = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (will use default if not provided)

        Raises:
            ValueError: If no ensemble member could be loaded.
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Get the directory where this file is located
        self.model_dir = os.path.dirname(os.path.abspath(__file__))

        # Initialize image processor
        self.image_processor = SybilImageProcessor()

        # Load calibrator
        self.calibrator = self._load_calibrator()

        # Load ensemble models
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """
        Load ensemble calibrator data.

        Returns:
            The parsed JSON content of the first calibrator file found, or an
            empty dict when neither candidate file exists.
        """
        # Candidates are tried in order; the first existing file wins.
        candidate_paths = (
            os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json"),
            os.path.join(self.model_dir, "calibrator_data.json"),
        )
        for calibrator_path in candidate_paths:
            if os.path.exists(calibrator_path):
                with open(calibrator_path, 'r', encoding="utf-8") as f:
                    return json.load(f)
        return {}

    def _locate_member_weights(self, index: int) -> Optional[str]:
        """Return the weights path for ensemble member `index`, or None if absent."""
        # Preferred location: sybil_<i>/model.safetensors
        weights_path = os.path.join(self.model_dir, f"sybil_{index}", "model.safetensors")
        if os.path.exists(weights_path):
            return weights_path

        # Try loading from checkpoints directory
        checkpoint_path = os.path.join(self.model_dir, "checkpoints", f"sybil_{index}.ckpt")
        if os.path.exists(checkpoint_path):
            return checkpoint_path
        return None

    @staticmethod
    def _read_state_dict(path: str) -> Dict:
        """Read a state dict from a safetensors file or a `.ckpt` checkpoint."""
        if path.endswith(".safetensors"):
            return load_file(path)

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location='cpu')

        # Extract state dict
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        # Remove 'model.' prefix if present
        prefix = 'model.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble.

        For each member, `sybil_<i>/model.safetensors` is preferred and
        `checkpoints/sybil_<i>.ckpt` is used as a fallback. A member whose
        weights are missing is skipped; a member whose weights fail to load is
        skipped with a warning, for both file formats.

        Returns:
            The loaded models, moved to `self.device` and set to eval mode.

        Raises:
            ValueError: If no model could be loaded.
        """
        models = []

        # Load each model in the ensemble (Sybil uses 5 models)
        for i in range(1, 6):
            weights_path = self._locate_member_weights(i)
            if weights_path is None:
                continue

            # Create model instance
            model = SybilForRiskPrediction(self.config)

            try:
                state_dict = self._read_state_dict(weights_path)
                # strict=False: missing or unexpected keys are tolerated.
                model.load_state_dict(state_dict, strict=False)
            except Exception as e:
                print(f"Warning: Could not load weights for sybil_{i}: {e}")
                continue

            model.to(self.device)
            model.eval()
            models.append(model)

        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")

        print(f"Loaded {len(models)} models in ensemble")
        return models

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply calibration to raw model outputs.

        For every column `y` of `scores`, the calibrator entry `Year<y+1>` is
        looked up. When it provides both "coef" and "intercept", the column is
        mapped through `sigmoid(score * coef + intercept)`; otherwise the
        column is passed through unchanged.

        Args:
            scores: Raw risk scores from the model, indexed as [sample, year]

        Returns:
            Calibrated risk scores (the input itself when no calibrator is loaded)
        """
        if not self.calibrator:
            return scores

        # Start from a copy so that uncalibrated years keep their raw scores.
        calibrated = scores.copy()

        for year in range(scores.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key not in self.calibrator:
                continue

            cal_data = self.calibrator[year_key]
            # An entry may be wrapped in a list; only its first element is used.
            if isinstance(cal_data, list) and len(cal_data) > 0:
                cal_data = cal_data[0]

            # Apply linear calibration if available
            if not (isinstance(cal_data, dict) and "coef" in cal_data and "intercept" in cal_data):
                continue

            raw_coef = cal_data["coef"]
            raw_intercept = cal_data["intercept"]
            # List-valued parameters are nested (coef: [[c]], intercept: [b]).
            coef = raw_coef[0][0] if isinstance(raw_coef, list) else raw_coef
            intercept = raw_intercept[0] if isinstance(raw_intercept, list) else raw_intercept

            # Apply calibration
            calibrated[:, year] = scores[:, year] * coef + intercept
            calibrated[:, year] = 1 / (1 + np.exp(-calibrated[:, year]))  # Sigmoid

        return calibrated

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files

        Returns:
            Preprocessed tensor ready for model input, placed on `self.device`
        """
        # Use the image processor to handle DICOM files
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]

        # Ensure we have 5D tensor (B, C, D, H, W)
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension

        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> SybilOutput:
        """
        Run prediction on a CT scan series.

        Every ensemble member is run on the same input; their predictions are
        averaged and then calibrated.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores (a float CPU tensor) and optional
            attention maps averaged over the ensemble
        """
        # Preprocess the DICOM files
        pixel_values = self.preprocess_dicom(dicom_paths)

        # Run inference with ensemble
        all_predictions = []
        all_attentions = []

        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )

                # Extract risk scores; fall back to tuple / raw tensor outputs.
                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output

                all_predictions.append(predictions.cpu().numpy())

                if return_attentions and hasattr(output, 'image_attention'):
                    all_attentions.append(output.image_attention)

        # Average ensemble predictions
        ensemble_pred = np.mean(all_predictions, axis=0)

        # Apply calibration
        calibrated_pred = self._apply_calibration(ensemble_pred)

        # Convert back to torch tensor
        risk_scores = torch.from_numpy(calibrated_pred).float()

        # Average attentions if requested
        attentions = None
        if return_attentions and all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}

        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(
        self,
        dicom_paths: Optional[List[str]] = None,
        dicom_series: Optional[List[List[str]]] = None,
        **kwargs
    ) -> SybilOutput:
        """
        Convenience method for prediction.

        When both arguments are given, `dicom_series` takes precedence.

        Args:
            dicom_paths: List of DICOM file paths for a single series
            dicom_series: List of lists of DICOM paths for batch processing
            **kwargs: Additional arguments passed to predict()

        Returns:
            SybilOutput with predictions. In batch mode only the stacked risk
            scores are returned; attention maps are not included.

        Raises:
            ValueError: If neither dicom_paths nor dicom_series is provided.
        """
        if dicom_series is not None:
            # Batch processing: each series is predicted independently.
            all_outputs = [self.predict(paths, **kwargs).risk_scores for paths in dicom_series]
            return SybilOutput(risk_scores=torch.stack(all_outputs))

        if dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)

        raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        Only the configuration is resolved from `pretrained_model_name_or_path`;
        weights and calibrator are read from the directory containing this
        file (see `__init__`).

        Args:
            pretrained_model_name_or_path: HF model ID or local path
            **kwargs: Additional configuration arguments. Only "config" is
                consumed; any other keyword is currently ignored.

        Returns:
            SybilHFWrapper instance
        """
        # Load configuration
        config = kwargs.pop("config", None)
        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception as e:
                # Best-effort: fall back to the default configuration, but say so
                # instead of hiding the failure.
                print(
                    f"Warning: Could not load config from {pretrained_model_name_or_path}: {e}. "
                    "Using default SybilConfig."
                )
                config = SybilConfig()

        return cls(config=config)