create scripts/pl_models.py (#3)
- create scripts/pl_models.py (7fcfad978dbd4198fe8e424df30b21f27fce070e)
Co-authored-by: Ryan Keivanfar <[email protected]>
- scripts/pl_models.py +483 -0
scripts/pl_models.py
ADDED
@@ -0,0 +1,483 @@
# pl_models.py
"""
This module defines PyTorch Lightning modules for the Tahoeformer project.
It includes a base model class (`LitBaseModel`) and the main experimental model
(`LitEnformerSMILES`), which combines an Enformer-based DNA sequence model with
drug information (a SMILES string processed into Morgan fingerprints) and dose
information to predict gene expression.

Key components:
- masked_mse: A utility loss function for Mean Squared Error that handles NaN targets.
- LitBaseModel: A base LightningModule providing common training, validation, and test steps,
  optimizer configuration, and basic metric logging hooks.
- LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes,
  using Enformer for DNA and Morgan fingerprints for drugs.
- MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions.
"""

import pandas as pd
import os
import torch
import torch.nn as nn
import lightning.pytorch as pl
from enformer_pytorch.finetune import HeadAdapterWrapper
from enformer_pytorch import Enformer
from torchmetrics.regression import PearsonCorrCoef, R2Score
from warnings import warn
import wandb
import numpy as np  # used by MetricLogger

# --- Utility Functions ---
def masked_mse(y_hat, y):
    """
    Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor.

    Args:
        y_hat (torch.Tensor): The predicted values.
        y (torch.Tensor): The target values, which may contain NaNs.

    Returns:
        torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if all targets are NaN.
    """
    mask = torch.isnan(y)
    if mask.all():  # Handle the case where all targets in the batch are NaN
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    mse = torch.mean((y[~mask] - y_hat[~mask]) ** 2)
    return mse
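
# Illustrative example (added for this write-up, not part of the original commit):
# `masked_mse` averages squared error only over positions whose target is not NaN,
# so partially labelled batches still yield a finite, differentiable loss.
#
#   >>> y_hat = torch.tensor([1.0, 2.0, 3.0])
#   >>> y = torch.tensor([1.0, float("nan"), 5.0])
#   >>> masked_mse(y_hat, y)   # mean of (1-1)^2 and (5-3)^2 -> tensor(2.)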

# --- Base Lightning Module ---
class LitBaseModel(pl.LightningModule):
    """
    A base PyTorch Lightning module providing common boilerplate for training and evaluation.

    This class implements a generic training/validation/test step, loss calculation using
    `masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs
    for detailed metric logging via the `MetricLogger` callback.

    Derived classes are expected to implement the `forward` method.

    Hyperparameters:
        learning_rate (float): The learning rate for the optimizer.
        loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if
                            additional loss terms were to be added.
        weight_decay (float, optional): Weight decay for the AdamW optimizer. If None,
                                        AdamW's internal default is used.
        eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes')
                                         and values are lists of gene IDs. Used by `MetricLogger`
                                         to compute metrics for specific gene subsets.
    """
    def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None,
                 eval_gene_sets=None):  # eval_gene_sets: dict {'train': [genes], 'valid': [genes], 'test': [genes]}
        """
        Initializes the LitBaseModel.

        Args:
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay for AdamW. If None, uses optimizer default.
                                            Defaults to None.
            eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation.
                                             Keys are names, values are lists of gene IDs.
                                             Defaults to None.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_alpha = loss_alpha  # alpha for mse vs. other terms (if any)
        self.weight_decay = weight_decay
        self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {}

        # Results accumulated per epoch for MetricLogger
        self.epoch_outputs = []

    def loss_fn(self, y_hat, y):
        """
        Calculates the loss for the model.

        Currently uses `masked_mse` scaled by `self.loss_alpha`.

        Args:
            y_hat (torch.Tensor): Predicted values from the model.
            y (torch.Tensor): Ground truth target values.

        Returns:
            torch.Tensor: The computed loss value.
        """
        mse_term = masked_mse(y_hat, y)
        # Potentially: add other loss terms here, weighted by (1 - loss_alpha) if desired
        return self.loss_alpha * mse_term

    def _common_step(self, batch, batch_idx, step_type):
        """
        A common step for training, validation, and testing.

        This method unpacks the batch, performs a forward pass, calculates the loss,
        logs the loss, and accumulates outputs for epoch-level metric calculation
        (for validation and test steps).

        Args:
            batch: The batch of data from the DataLoader. Expected to be a tuple containing
                   the DNA sequence, Morgan fingerprints, dose, target expression,
                   and metadata (gene_id, drug_id, cell_line).
            batch_idx (int): The index of the current batch.
            step_type (str): A string indicating the type of step ('train', 'val', or 'test').

        Returns:
            torch.Tensor: The loss for the current batch.
        """
        # Batch structure will change after dataset modification:
        # (dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line)
        dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch

        y_hat = self(dna_seq, morgan_fingerprints, dose)  # Call forward method of derived class

        loss = self.loss_fn(y_hat, target_expression)
        self.log(f'{step_type}_loss', loss, batch_size=target_expression.shape[0],
                 on_step=False, on_epoch=True, prog_bar=(step_type != 'train'))

        if step_type != 'train':
            # Prepare data for MetricLogger
            batch_size = target_expression.shape[0]
            for i in range(batch_size):
                item_data = {
                    'pred': y_hat[i].detach(),
                    'target': target_expression[i].detach(),
                    'gene_id': gene_id[i],
                    'drug_id': drug_id[i],
                    'cell_line': cell_line[i],
                    'rank': self.trainer.global_rank
                }
                self.epoch_outputs.append(item_data)
        return loss
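
    # Illustrative batch shapes (an assumption for this write-up, not part of the original
    # commit): with a batch size of 8 and a 196,608 bp input window, `dna_seq` would be
    # (8, 196608, 4), `morgan_fingerprints` (8, 2048), `dose` (8,) or (8, 1),
    # `target_expression` (8, 1), and gene_id/drug_id/cell_line length-8 sequences of strings.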

    def training_step(self, batch, batch_idx):
        """PyTorch Lightning training step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning validation step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning test step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'test')

    def on_validation_epoch_start(self):
        """Clears accumulated outputs at the start of each validation epoch."""
        self.epoch_outputs = []

    def on_test_epoch_start(self):
        """Clears accumulated outputs at the start of each test epoch."""
        self.epoch_outputs = []

    def configure_optimizers(self):
        """
        Configures the optimizer for the model.

        Uses AdamW with the specified learning rate and weight decay.

        Returns:
            torch.optim.Optimizer: The configured AdamW optimizer.
        """
        if self.weight_decay is None:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate,
                                          weight_decay=self.weight_decay)
        return optimizer
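
# A minimal subclass sketch (illustrative only, not part of the original commit; the class
# name and layer are hypothetical). It shows that a derived model only needs to implement
# `forward`; the masked loss, step logic, output accumulation, and AdamW configuration are
# inherited from LitBaseModel.
#
#   class LitFingerprintOnly(LitBaseModel):
#       def __init__(self, morgan_fingerprint_dim=2048, **kwargs):
#           super().__init__(**kwargs)
#           self.head = nn.Linear(morgan_fingerprint_dim, 1)
#
#       def forward(self, dna_seq, morgan_fingerprints, dose):
#           # Ignores DNA and dose; predicts expression from the fingerprint alone.
#           return self.head(morgan_fingerprints)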

# --- Enformer + Morgan Fingerprints Model ---
class LitEnformerSMILES(LitBaseModel):  # Consider renaming to LitEnformerMorgan for clarity
    """
    A PyTorch Lightning module that combines genomic sequence information (via Enformer)
    with drug chemical structure (represented by Morgan fingerprints) and drug dose
    to predict gene expression changes.

    The model architecture consists of three main branches:
    1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract
       features from a one-hot encoded DNA sequence centered around a gene's TSS.
    2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation.
    3. Dose Module: Directly uses the numerical dose value.

    Features from these three branches are concatenated and passed through a multi-layer
    fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction
    of gene expression.

    Inherits common training and evaluation logic from `LitBaseModel`.
    """
    def __init__(self,
                 enformer_model_name: str = 'EleutherAI/enformer-official-rough',
                 enformer_target_length: int = -1,
                 num_output_tracks_enformer_head: int = 1,
                 morgan_fingerprint_dim: int = 2048,  # dim of the Morgan fingerprint vector
                 dose_input_dim: int = 1,
                 fusion_hidden_dim: int = 256,
                 final_output_tracks: int = 1,
                 learning_rate=5e-6,
                 loss_alpha=1.0,
                 weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitEnformerSMILES (or LitEnformerMorgan) model.

        Args:
            enformer_model_name (str, optional): Name or path of the pre-trained Enformer model.
            enformer_target_length (int, optional): Target length for Enformer's internal pooling.
            num_output_tracks_enformer_head (int, optional): Output features from the Enformer head.
            morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector
                                                    (e.g., 2048 for ECFP4). Defaults to 2048.
            dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1.
            fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256.
            final_output_tracks (int, optional): Number of final output values. Defaults to 1.
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay. Defaults to None.
            eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None.
        """
        super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets)
        self.save_hyperparameters(
            "enformer_model_name", "enformer_target_length",
            "num_output_tracks_enformer_head", "morgan_fingerprint_dim",
            "dose_input_dim", "fusion_hidden_dim", "final_output_tracks",
            "learning_rate", "loss_alpha", "weight_decay"
        )

        # 1. DNA Module (Enformer with HeadAdapter)
        enformer_pretrained = Enformer.from_pretrained(
            self.hparams.enformer_model_name,
            target_length=self.hparams.enformer_target_length
        )
        self.dna_module = HeadAdapterWrapper(
            enformer=enformer_pretrained,
            num_tracks=self.hparams.num_output_tracks_enformer_head,
            post_transformer_embed=False,
            output_activation=nn.Identity()
        )

        # 2. Drug Module (Morgan fingerprints are provided as input directly)
        # No layers needed here as fingerprints are pre-computed.
        # self.hparams.morgan_fingerprint_dim defines the expected input dimension.

        # 3. Fusion Head
        # Input dimension uses morgan_fingerprint_dim
        fusion_input_dim = (self.hparams.num_output_tracks_enformer_head
                            + self.hparams.morgan_fingerprint_dim
                            + self.hparams.dose_input_dim)
        self.fusion_head = nn.Sequential(
            nn.Linear(fusion_input_dim, self.hparams.fusion_hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim),
            nn.Dropout(0.25),
            nn.Linear(self.hparams.fusion_hidden_dim, self.hparams.fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.fusion_hidden_dim // 2, self.hparams.final_output_tracks)
        )
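
        # Worked dimension check (illustrative, based on the defaults above): with
        # num_output_tracks_enformer_head=1, morgan_fingerprint_dim=2048 and dose_input_dim=1,
        # fusion_input_dim = 1 + 2048 + 1 = 2050, and the head maps 2050 -> 256 -> 128 -> 1.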

    def forward(self, dna_seq, morgan_fingerprints, dose):
        """
        Defines the forward pass of the LitEnformerSMILES model using Morgan Fingerprints.

        Args:
            dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences.
                                    Shape: (batch_size, sequence_length, 4).
            morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors.
                                                Shape: (batch_size, morgan_fingerprint_dim).
            dose (torch.Tensor): Batch of drug dose values.
                                 Shape: (batch_size, dose_input_dim).

        Returns:
            torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks).
        """
        # --- DNA Processing ---
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        center_seq_idx = dna_out_intermediate.shape[1] // 2
        dna_features = dna_out_intermediate[:, center_seq_idx, :]
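        # The adapted Enformer head yields (batch, num_bins, num_output_tracks_enformer_head);
        # indexing the central bin (the input window is centered on the TSS) leaves
        # `dna_features` with shape (batch_size, num_output_tracks_enformer_head).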

        # --- Drug Processing (Morgan Fingerprints) ---
        # Morgan fingerprints are directly used as features.
        smiles_features = morgan_fingerprints  # Shape: (batch_size, morgan_fingerprint_dim)

        # --- Dose Processing ---
        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)

        # --- Feature Combination & Final Prediction ---
        combined_features = torch.cat([dna_features, smiles_features, dose], dim=1)
        prediction = self.fusion_head(combined_features)
        return prediction
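
# Illustrative usage sketch (not part of the original commit; shapes and dose values are
# assumptions). Instantiating the model downloads pre-trained Enformer weights.
#
#   model = LitEnformerSMILES(morgan_fingerprint_dim=2048, fusion_hidden_dim=256)
#   dna = torch.randn(2, 196_608, 4)             # one-hot DNA would be 0/1; random floats for shape only
#   fp = torch.randint(0, 2, (2, 2048)).float()  # Morgan fingerprint bit vectors
#   dose = torch.tensor([0.1, 1.0])              # one dose per sample; unsqueezed to (2, 1) in forward
#   y_hat = model(dna, fp, dose)                 # -> shape (2, 1)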

# --- Metrics Logging Callback ---
class MetricLogger(pl.Callback):
    """
    A PyTorch Lightning Callback for comprehensive metric calculation and logging.

    This callback accumulates predictions and targets during validation and test epochs.
    At the end of these epochs, it:
    1. Processes the accumulated outputs into a pandas DataFrame.
    2. Saves the raw predictions and targets for the epoch to a CSV file.
    3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used.
    4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch.
    5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets.
    6. Calculates metrics per cell line if 'cell_line' information is available in the outputs.
    7. Logs all calculated metrics to the LightningModule's logger.

    Attributes:
        save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved.
        current_epoch_data (list | pd.DataFrame): Starts as an empty list; replaced each
                                                  validation/test epoch by the processed DataFrame
                                                  of predictions, targets, and metadata.
    """
    def __init__(self, save_dir_prefix="results"):
        """
        Initializes the MetricLogger callback.

        Args:
            save_dir_prefix (str, optional): Directory prefix for saving metrics files.
                                             Defaults to "results".
        """
        super().__init__()
        self.save_dir_prefix = save_dir_prefix
        self.current_epoch_data = []

    def _process_epoch_outputs(self, pl_module, stage):
        """
        Processes the raw outputs collected during an epoch into a pandas DataFrame.

        Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types.

        Args:
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").

        Returns:
            pd.DataFrame: A DataFrame containing the processed epoch outputs.
                          Returns an empty DataFrame if no outputs were collected.
        """
        if not hasattr(pl_module, 'epoch_outputs') or not pl_module.epoch_outputs:
            warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.")
            return pd.DataFrame()

        df = pd.DataFrame(pl_module.epoch_outputs)

        for col in ['pred', 'target']:
            if col in df.columns and not df[col].empty:
                if isinstance(df[col].iloc[0], torch.Tensor):
                    df[col] = df[col].apply(lambda x: x.cpu().float().numpy().item() if x.numel() == 1 else x.cpu().float().numpy())
        return df

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the validation epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "validation")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "validation")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_validation_epoch_end.")

    def on_test_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the test epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "test")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "test")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_test_epoch_end.")

    def _log_and_save_metrics(self, trainer, pl_module, stage):
        """
        Calculates, logs, and saves metrics for the current stage and epoch.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").
        """
        epoch = trainer.current_epoch if trainer.current_epoch is not None else -1
        save_dir = getattr(pl_module.hparams, 'save_dir',
                           os.path.join(self.save_dir_prefix, f"run_{trainer.logger.version if trainer.logger else 'local'}"))
        os.makedirs(save_dir, exist_ok=True)

        raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}.csv")
        self.current_epoch_data.to_csv(raw_preds_path, index=False)

        if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run):
            try:
                trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=self.current_epoch_data.head(1000))})
            except Exception as e:
                warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}")

        overall_metrics = self._calculate_metrics_for_group(self.current_epoch_data, pl_module.device)
        if overall_metrics:
            pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True)

        if hasattr(pl_module, 'eval_gene_sets') and pl_module.eval_gene_sets and isinstance(pl_module.eval_gene_sets, dict) and 'gene_id' in self.current_epoch_data.columns:
            for split_name, gene_list in pl_module.eval_gene_sets.items():
                if not gene_list:
                    continue
                split_df = self.current_epoch_data[self.current_epoch_data['gene_id'].isin(gene_list)]
                if not split_df.empty:
                    split_metrics = self._calculate_metrics_for_group(split_df, pl_module.device)
                    if split_metrics:
                        pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True)

        if 'cell_line' in self.current_epoch_data.columns:
            for cell_line, group_df in self.current_epoch_data.groupby('cell_line'):
                cl_metrics = self._calculate_metrics_for_group(group_df, pl_module.device)
                if cl_metrics:
                    pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k, v in cl_metrics.items()}, sync_dist=True)

    def _calculate_metrics_for_group(self, df_group, device):
        """
        Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions.

        Args:
            df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group.
            device (torch.device): The device to perform calculations on.

        Returns:
            dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient.
        """
        if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns:
            return {}

        preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32)
        targets_np = np.array(df_group['target'].tolist(), dtype=np.float32)

        preds = torch.tensor(preds_np).to(device)
        targets = torch.tensor(targets_np).to(device)

        if preds.ndim == 1:
            preds = preds.squeeze()
            targets = targets.squeeze()

        if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape:
            warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}")
            return {}

        mse_val_tensor = masked_mse(preds.unsqueeze(-1) if preds.ndim == 1 else preds,
                                    targets.unsqueeze(-1) if targets.ndim == 1 else targets)
        calculated_metrics = {'mse': mse_val_tensor.item()}

        if preds.numel() < 2:
            warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.")
            return calculated_metrics

        preds_for_corr = preds.squeeze()
        targets_for_corr = targets.squeeze()

        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
            warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}")
            return calculated_metrics

        try:
            pearson_fn = PearsonCorrCoef().to(device)
            pearson_val = pearson_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['pearson'] = pearson_val.item()
        except Exception as e:
            warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        try:
            r2_fn = R2Score().to(device)
            r2_val = r2_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['r2'] = r2_val.item()
        except Exception as e:
            warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        return calculated_metrics
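
# Illustrative wiring sketch (not part of the original commit; the dataloader objects and the
# gene ID below are hypothetical placeholders):
#
#   model = LitEnformerSMILES(eval_gene_sets={"oncogenes": ["ENSG00000133703"]})
#   metric_logger = MetricLogger(save_dir_prefix="results")
#   trainer = pl.Trainer(max_epochs=1, callbacks=[metric_logger])
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
#   trainer.test(model, dataloaders=test_loader)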