Aakash-Tripathi committed on
Commit
aaaccd9
·
verified ·
1 Parent(s): 320b436

Update modeling_sybil_wrapper.py

Browse files
Files changed (1) hide show
  1. modeling_sybil_wrapper.py +253 -95
modeling_sybil_wrapper.py CHANGED
@@ -1,140 +1,298 @@
1
  """
2
- Simplified Hugging Face wrapper for original Sybil model
3
- This ensures full compatibility with the original implementation
4
  """
5
 
6
  import os
7
- import sys
8
  import json
 
9
  import torch
10
- import torch.nn as nn
11
- from typing import Optional, List, Dict
12
- from transformers import PreTrainedModel
13
  from dataclasses import dataclass
14
  from transformers.modeling_outputs import BaseModelOutput
 
15
 
16
- # Add original Sybil to path
17
- sys.path.append('/mnt/f/Projects/hfsybil/Sybil')
18
- from sybil import Sybil as OriginalSybil
19
- from sybil import Serie
20
 
21
  try:
22
  from .configuration_sybil import SybilConfig
 
 
23
  except ImportError:
24
  from configuration_sybil import SybilConfig
 
 
25
 
26
 
27
  @dataclass
28
  class SybilOutput(BaseModelOutput):
29
  """
30
- Output class for Sybil model.
 
 
 
 
31
  """
32
  risk_scores: torch.FloatTensor = None
33
  attentions: Optional[Dict] = None
34
 
35
 
36
- class SybilHFWrapper(PreTrainedModel):
37
  """
38
- Hugging Face wrapper around the original Sybil model.
39
- This ensures complete compatibility while providing HF interface.
40
  """
41
- config_class = SybilConfig
42
- base_model_prefix = "sybil"
43
-
44
- def __init__(self, config: SybilConfig):
45
- super().__init__(config)
46
- self.config = config
47
-
48
- # Load the original Sybil model with ensemble
49
- checkpoint_dir = "/mnt/f/Projects/hfsybil/checkpoints"
50
-
51
- # Copy checkpoints to ~/.sybil if needed
52
- cache_dir = os.path.expanduser("~/.sybil")
53
- os.makedirs(cache_dir, exist_ok=True)
54
-
55
- # Map of checkpoint files
56
- checkpoint_files = {
57
- "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1",
58
- "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2",
59
- "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3",
60
- "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4",
61
- "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5",
62
- "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator"
63
- }
64
-
65
- # Copy checkpoint files
66
- for filename in checkpoint_files.keys():
67
- src = os.path.join(checkpoint_dir, filename)
68
- dst = os.path.join(cache_dir, filename)
69
- if os.path.exists(src) and not os.path.exists(dst):
70
- import shutil
71
- shutil.copy2(src, dst)
72
-
73
- # Initialize the original model
74
- self.sybil_model = OriginalSybil("sybil_ensemble")
75
-
76
- def forward(
77
- self,
78
- pixel_values: torch.FloatTensor = None,
79
- dicom_paths: List[str] = None,
80
- return_attentions: bool = False,
81
- **kwargs
82
- ) -> SybilOutput:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  """
84
- Forward pass using original Sybil model.
85
 
86
  Args:
87
- pixel_values: Pre-processed tensor (not used directly, for compatibility)
88
- dicom_paths: List of DICOM file paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  return_attentions: Whether to return attention maps
90
 
91
  Returns:
92
- SybilOutput with risk scores and optional attentions
93
  """
 
 
94
 
95
- if dicom_paths is None:
96
- raise ValueError("dicom_paths must be provided")
 
97
 
98
- # Create Serie object
99
- serie = Serie(dicom_paths)
 
 
 
 
100
 
101
- # Run prediction
102
- prediction = self.sybil_model.predict([serie], return_attentions=return_attentions)
 
 
 
103
 
104
- # Convert to torch tensors
105
- risk_scores = torch.tensor(prediction.scores[0])
106
 
107
- return SybilOutput(
108
- risk_scores=risk_scores,
109
- attentions=prediction.attentions[0] if return_attentions else None
110
- )
111
 
112
- @classmethod
113
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  """
115
- Load the model. Since we're using the original Sybil,
116
- we just need to ensure the checkpoints are available.
 
 
 
 
 
 
 
117
  """
118
- config = kwargs.pop("config", None)
119
- if config is None:
120
- config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
 
 
 
121
 
122
- return cls(config)
 
 
 
 
 
123
 
124
- def save_pretrained(self, save_directory, **kwargs):
 
125
  """
126
- Save the model configuration.
127
- The actual model weights are handled by the original Sybil.
 
 
 
 
 
 
128
  """
129
- os.makedirs(save_directory, exist_ok=True)
130
- self.config.save_pretrained(save_directory)
131
-
132
- # Save info about checkpoint locations
133
- info = {
134
- "model_type": "sybil_wrapper",
135
- "checkpoint_dir": "/mnt/f/Projects/hfsybil/checkpoints",
136
- "note": "This model uses the original Sybil implementation"
137
- }
138
-
139
- with open(os.path.join(save_directory, "model_info.json"), "w") as f:
140
- json.dump(info, f, indent=2)
 
1
  """
2
+ Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model.
3
+ This version works directly from HF without requiring external Sybil package.
4
  """
5
 
6
  import os
 
7
  import json
8
+ import sys
9
  import torch
10
+ import numpy as np
11
+ from typing import List, Dict, Optional
 
12
  from dataclasses import dataclass
13
  from transformers.modeling_outputs import BaseModelOutput
14
+ from safetensors.torch import load_file
15
 
16
+ # Add model path to sys.path for imports
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ if current_dir not in sys.path:
19
+ sys.path.insert(0, current_dir)
20
 
21
  try:
22
  from .configuration_sybil import SybilConfig
23
+ from .modeling_sybil import SybilForRiskPrediction
24
+ from .image_processing_sybil import SybilImageProcessor
25
  except ImportError:
26
  from configuration_sybil import SybilConfig
27
+ from modeling_sybil import SybilForRiskPrediction
28
+ from image_processing_sybil import SybilImageProcessor
29
 
30
 
31
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model predictions.

    Instances are built by ``SybilHFWrapper.predict`` (single series) and
    ``SybilHFWrapper.__call__`` (batch of series).

    Args:
        risk_scores: Risk scores for each year (1-6 years by default).
            ``predict`` fills this with the calibrated mean of the ensemble
            members' scores, converted to a float tensor.
        attentions: Optional attention maps if requested. When populated by
            ``predict`` it is a dict with the single key ``"image_attention"``
            holding the ensemble-averaged attention; otherwise ``None``.
    """
    # Both fields default to None so the output can be built with keywords only.
    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None
42
 
43
 
44
class SybilHFWrapper:
    """
    Hugging Face wrapper for Sybil ensemble model.
    Provides a simple interface for lung cancer risk prediction from CT scans.

    All artifacts (member weights and the calibrator JSON) are looked up
    relative to the directory that contains this file:

    * ``sybil_<i>/model.safetensors`` (preferred) or
      ``checkpoints/sybil_<i>.ckpt`` (fallback) for ``i`` in 1..5
    * ``checkpoints/sybil_ensemble_simple_calibrator.json`` (preferred) or
      ``calibrator_data.json`` (fallback)

    Known limitations (kept for backward compatibility):

    * Batch mode (``dicom_series=...``) returns risk scores only; attention
      maps are not aggregated even when ``return_attentions=True``.
    * ``from_pretrained`` uses its path argument only to locate the
      configuration; weights are always read from this file's directory.
    """

    # Number of ensemble members the loader looks for (sybil_1 .. sybil_5).
    _ENSEMBLE_SIZE = 5

    def __init__(self, config: "SybilConfig" = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (will use default if not provided)

        Raises:
            ValueError: If no ensemble member could be loaded.
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Directory of this file; every artifact path is built from it.
        self.model_dir = os.path.dirname(os.path.abspath(__file__))

        self.image_processor = SybilImageProcessor()
        self.calibrator = self._load_calibrator()
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """
        Load ensemble calibrator data.

        Returns:
            The parsed JSON of the first calibrator file that exists, or an
            empty dict when neither candidate file is present (in which case
            ``_apply_calibration`` returns scores unchanged).
        """
        candidates = (
            os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json"),
            os.path.join(self.model_dir, "calibrator_data.json"),
        )
        for calibrator_path in candidates:
            if os.path.exists(calibrator_path):
                # JSON is UTF-8 by specification; do not depend on the locale.
                with open(calibrator_path, "r", encoding="utf-8") as f:
                    return json.load(f)
        return {}

    @staticmethod
    def _strip_model_prefix(state_dict: Dict) -> Dict:
        """Return a copy of ``state_dict`` with a leading 'model.' removed from keys."""
        prefix = "model."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @staticmethod
    def _read_checkpoint_state_dict(checkpoint_path: str) -> Dict:
        """Read a ``.ckpt`` file and return its state dict with 'model.' prefixes removed."""
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Some checkpoints wrap the weights under a 'state_dict' entry.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        return SybilHFWrapper._strip_model_prefix(state_dict)

    @staticmethod
    def _warn_on_key_mismatch(name: str, load_result) -> None:
        """Print a warning when a non-strict ``load_state_dict`` skipped or ignored keys."""
        missing = list(getattr(load_result, "missing_keys", None) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing or unexpected:
            print(
                f"Warning: {name} loaded with {len(missing)} missing and "
                f"{len(unexpected)} unexpected keys"
            )

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble.

        For each member the safetensors file is preferred; the ``.ckpt`` file
        under ``checkpoints/`` is used only when the safetensors file does not
        exist. A member whose weights cannot be read is skipped with a warning.

        Returns:
            The loaded members, moved to ``self.device`` and set to eval mode.

        Raises:
            ValueError: If no member could be loaded.
        """
        models = []

        for i in range(1, self._ENSEMBLE_SIZE + 1):
            weights_path = os.path.join(self.model_dir, f"sybil_{i}", "model.safetensors")
            checkpoint_path = os.path.join(self.model_dir, "checkpoints", f"sybil_{i}.ckpt")

            if os.path.exists(weights_path):
                read_state_dict, source_path = load_file, weights_path
            elif os.path.exists(checkpoint_path):
                read_state_dict, source_path = self._read_checkpoint_state_dict, checkpoint_path
            else:
                # Nothing on disk for this member; try the next one.
                continue

            model = SybilForRiskPrediction(self.config)
            try:
                state_dict = read_state_dict(source_path)
                load_result = model.load_state_dict(state_dict, strict=False)
            except Exception as e:
                # Best effort: one unreadable member must not abort the ensemble.
                print(f"Warning: Could not load weights for sybil_{i}: {e}")
                continue

            # strict=False hides key mismatches; surface them so a member left
            # with uninitialised weights does not go unnoticed.
            self._warn_on_key_mismatch(f"sybil_{i}", load_result)

            model.to(self.device)
            model.eval()
            models.append(model)

        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")

        print(f"Loaded {len(models)} models in ensemble")
        return models

    @staticmethod
    def _first_scalar(value) -> float:
        """Return the first element of a scalar or arbitrarily nested list as a Python float."""
        return float(np.asarray(value, dtype=float).ravel()[0])

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply calibration to raw model outputs.

        For every column ``y`` the calibrator entry ``"Year{y + 1}"`` is looked
        up. When the entry (or the first element of a list entry) is a dict
        with ``"coef"`` and ``"intercept"``, the column becomes
        ``sigmoid(score * coef + intercept)``; otherwise it is copied as is.

        NOTE(review): only ``coef``/``intercept`` are consumed here; confirm
        against the calibrator file format that no other fields are required.

        Args:
            scores: Raw risk scores from the model, indexed ``[sample, year]``

        Returns:
            Calibrated risk scores (the input object itself when no calibrator
            is loaded)
        """
        if not self.calibrator:
            return scores

        # Start from the raw scores; calibrated columns are overwritten below.
        calibrated = scores.copy()

        for year in range(scores.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key not in self.calibrator:
                continue

            cal_data = self.calibrator[year_key]
            if isinstance(cal_data, list) and len(cal_data) > 0:
                cal_data = cal_data[0]

            if not (isinstance(cal_data, dict) and "coef" in cal_data and "intercept" in cal_data):
                continue

            # Accept scalars, flat lists and nested lists alike.
            coef = self._first_scalar(cal_data["coef"])
            intercept = self._first_scalar(cal_data["intercept"])

            logits = scores[:, year] * coef + intercept
            calibrated[:, year] = 1 / (1 + np.exp(-logits))  # Sigmoid

        return calibrated

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files

        Returns:
            Preprocessed tensor on ``self.device``; a 4D result from the image
            processor gets a leading batch dimension added
        """
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]

        # Ensure we have 5D tensor (B, C, D, H, W)
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension

        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> "SybilOutput":
        """
        Run prediction on a CT scan series.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and optional attention maps

        Raises:
            ValueError: If ``dicom_paths`` is ``None`` or empty.
        """
        if dicom_paths is None or len(dicom_paths) == 0:
            raise ValueError("dicom_paths must be a non-empty list of DICOM file paths")

        pixel_values = self.preprocess_dicom(dicom_paths)

        all_predictions = []
        all_attentions = []

        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )

                # Members may return a structured output, a tuple or a tensor.
                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output

                all_predictions.append(predictions.cpu().numpy())

                if return_attentions and hasattr(output, 'image_attention'):
                    all_attentions.append(output.image_attention)

        # Average ensemble predictions, then calibrate the averaged scores.
        ensemble_pred = np.mean(all_predictions, axis=0)
        calibrated_pred = self._apply_calibration(ensemble_pred)
        risk_scores = torch.from_numpy(calibrated_pred).float()

        # Average attentions across members if any were collected.
        attentions = None
        if return_attentions and all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}

        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> "SybilOutput":
        """
        Convenience method for prediction.

        ``dicom_series`` takes precedence when both arguments are given.

        Args:
            dicom_paths: List of DICOM file paths for a single series
            dicom_series: List of lists of DICOM paths for batch processing
            **kwargs: Additional arguments passed to predict()

        Returns:
            SybilOutput with predictions. In batch mode only ``risk_scores``
            is populated (per-series scores stacked along a new first axis).

        Raises:
            ValueError: If neither argument is provided, or ``dicom_series``
                is empty.
        """
        if dicom_series is not None:
            if len(dicom_series) == 0:
                # torch.stack would fail on an empty list with an opaque error.
                raise ValueError("dicom_series must contain at least one series")

            all_outputs = [
                self.predict(paths, **kwargs).risk_scores for paths in dicom_series
            ]
            return SybilOutput(risk_scores=torch.stack(all_outputs))

        if dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)

        raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        The path is used only to resolve the configuration; model weights are
        read from the directory containing this file (see ``__init__``).

        Args:
            pretrained_model_name_or_path: HF model ID or local path
            **kwargs: Additional configuration arguments; ``config`` may carry
                a ready-made ``SybilConfig`` and skips the lookup

        Returns:
            SybilHFWrapper instance
        """
        config = kwargs.pop("config", None)
        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception as e:
                # Fall back to defaults, but say so instead of hiding the cause.
                # (A bare except here would also swallow KeyboardInterrupt.)
                print(
                    f"Warning: Could not load config from "
                    f"{pretrained_model_name_or_path}: {e}. Using default SybilConfig."
                )
                config = SybilConfig()

        return cls(config=config)