qhuang20 and ryankeivanfar committed
Commit 3203cde · verified · 1 Parent(s): 279b7c3

create scripts/pl_models.py (#3)


- create scripts/pl_models.py (7fcfad978dbd4198fe8e424df30b21f27fce070e)


Co-authored-by: Ryan Keivanfar <[email protected]>

Files changed (1)
  1. scripts/pl_models.py +483 -0
scripts/pl_models.py ADDED
@@ -0,0 +1,483 @@
# pl_models.py
"""
This module defines PyTorch Lightning modules for the Tahoeformer project.
It includes a base model class (`LitBaseModel`) and the main experimental model
(`LitEnformerSMILES`) which combines an Enformer-based DNA sequence model with
drug information (SMILES string processed into Morgan Fingerprints) and dose information
to predict gene expression.

Key components:
- masked_mse: A utility loss function for Mean Squared Error that handles NaN targets.
- LitBaseModel: A base LightningModule providing common training, validation, test steps,
  optimizer configuration, and basic metric logging hooks.
- LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes,
  using Enformer for DNA and Morgan fingerprints for drugs.
- MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions.
"""

import pandas as pd
import os
import torch
import torch.nn as nn
import lightning.pytorch as pl
from enformer_pytorch.finetune import HeadAdapterWrapper
from enformer_pytorch import Enformer
from torchmetrics.regression import PearsonCorrCoef, R2Score
from warnings import warn
import wandb
import numpy as np  # Added for MetricLogger consistency

# --- Utility Functions ---
def masked_mse(y_hat, y):
    """
    Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor.

    Args:
        y_hat (torch.Tensor): The predicted values.
        y (torch.Tensor): The target values, which may contain NaNs.

    Returns:
        torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if all targets are NaN.
    """
    mask = torch.isnan(y)
    if mask.all():  # Handle case where all targets in batch are NaN
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    mse = torch.mean((y[~mask] - y_hat[~mask]) ** 2)
    return mse

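# Illustrative example of masked_mse (toy values assumed, not part of the committed module):
# only the non-NaN target positions contribute to the mean.
#
#   y_hat = torch.tensor([1.0, 2.0, 3.0])
#   y     = torch.tensor([1.5, 2.0, float('nan')])
#   masked_mse(y_hat, y)   # mean of ((1.5 - 1.0)**2, (2.0 - 2.0)**2) == 0.125
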
# --- Base Lightning Module ---
class LitBaseModel(pl.LightningModule):
    """
    A base PyTorch Lightning module providing common boilerplate for training and evaluation.

    This class implements a generic training/validation/test step, loss calculation using
    `masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs
    for detailed metric logging via the `MetricLogger` callback.

    Derived classes are expected to implement the `forward` method.

    Hyperparameters:
        learning_rate (float): The learning rate for the optimizer.
        loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if
            additional loss terms were to be added.
        weight_decay (float, optional): Weight decay for the AdamW optimizer. If None,
            AdamW's internal default is used.
        eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes')
            and values are lists of gene IDs. Used by `MetricLogger` to compute metrics
            for specific gene subsets.
    """
    def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None,
                 eval_gene_sets=None):  # eval_gene_sets: dict {'train': [genes], 'valid': [genes], 'test': [genes]}
        """
        Initializes the LitBaseModel.

        Args:
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay for AdamW. If None, uses optimizer default.
                Defaults to None.
            eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation.
                Keys are names, values are lists of gene IDs. Defaults to None.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_alpha = loss_alpha  # alpha for mse vs. other terms (if any)
        self.weight_decay = weight_decay
        self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {}

        # Results accumulated per epoch for MetricLogger
        self.epoch_outputs = []

    def loss_fn(self, y_hat, y):
        """
        Calculates the loss for the model.

        Currently uses `masked_mse` scaled by `self.loss_alpha`.

        Args:
            y_hat (torch.Tensor): Predicted values from the model.
            y (torch.Tensor): Ground truth target values.

        Returns:
            torch.Tensor: The computed loss value.
        """
        mse_term = masked_mse(y_hat, y)
        # Potentially: add other loss terms here, weighted by (1 - loss_alpha) if desired
        return self.loss_alpha * mse_term

    def _common_step(self, batch, batch_idx, step_type):
        """
        A common step for training, validation, and testing.

        This method unpacks the batch, performs a forward pass, calculates the loss,
        logs the loss, and accumulates outputs for epoch-level metric calculation
        (for validation and test steps).

        Args:
            batch: The batch of data from the DataLoader. Expected to be a tuple containing
                DNA sequence, Morgan fingerprints, dose, target expression,
                and metadata (gene_id, drug_id, cell_line).
            batch_idx (int): The index of the current batch.
            step_type (str): A string indicating the type of step ('train', 'val', or 'test').

        Returns:
            torch.Tensor: The loss for the current batch.
        """
        # Batch structure will change after dataset modification:
        # (dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line)
        dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch

        y_hat = self(dna_seq, morgan_fingerprints, dose)  # Call forward method of derived class

        loss = self.loss_fn(y_hat, target_expression)
        # Log the loss once per epoch; show it in the progress bar for val/test only.
        self.log(f'{step_type}_loss', loss, batch_size=target_expression.shape[0],
                 on_step=False, on_epoch=True, prog_bar=(step_type != 'train'))

        if step_type != 'train':
            # Prepare data for MetricLogger
            batch_size = target_expression.shape[0]
            for i in range(batch_size):
                item_data = {
                    'pred': y_hat[i].detach(),
                    'target': target_expression[i].detach(),
                    'gene_id': gene_id[i],
                    'drug_id': drug_id[i],
                    'cell_line': cell_line[i],
                    'rank': self.trainer.global_rank
                }
                self.epoch_outputs.append(item_data)
        return loss

    def training_step(self, batch, batch_idx):
        """PyTorch Lightning training step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning validation step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning test step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'test')

    def on_validation_epoch_start(self):
        """Clears accumulated outputs at the start of each validation epoch."""
        self.epoch_outputs = []

    def on_test_epoch_start(self):
        """Clears accumulated outputs at the start of each test epoch."""
        self.epoch_outputs = []

    def configure_optimizers(self):
        """
        Configures the optimizer for the model.

        Uses AdamW with the specified learning rate and weight decay.

        Returns:
            torch.optim.Optimizer: The configured AdamW optimizer.
        """
        if self.weight_decay is None:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer

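# Illustrative sketch of the batch tuple `_common_step` unpacks (assumed toy shapes; the
# actual Dataset/DataLoader is defined elsewhere in the project):
#
#   dna_seq             : (B, seq_len, 4) one-hot encoded DNA sequence
#   morgan_fingerprints : (B, morgan_fingerprint_dim) pre-computed fingerprint bits
#   dose                : (B,) or (B, 1) numeric dose
#   target_expression   : (B, final_output_tracks) expression target; may contain NaNs
#   gene_id, drug_id, cell_line : length-B sequences of string identifiers
#
#   batch = (dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line)
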
# --- Enformer + Morgan Fingerprints Model ---
class LitEnformerSMILES(LitBaseModel):  # Consider renaming to LitEnformerMorgan for clarity
    """
    A PyTorch Lightning module that combines genomic sequence information (via Enformer)
    with drug chemical structure (represented by Morgan fingerprints) and drug dose
    to predict gene expression changes.

    The model architecture consists of three main branches:
    1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract
       features from a one-hot encoded DNA sequence centered around a gene's TSS.
    2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation.
    3. Dose Module: Directly uses the numerical dose value.

    Features from these three branches are concatenated and passed through a multi-layer
    fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction
    of gene expression.

    Inherits common training and evaluation logic from `LitBaseModel`.
    """
    def __init__(self,
                 enformer_model_name: str = 'EleutherAI/enformer-official-rough',
                 enformer_target_length: int = -1,
                 num_output_tracks_enformer_head: int = 1,
                 morgan_fingerprint_dim: int = 2048,  # dim of the Morgan fingerprint vector
                 dose_input_dim: int = 1,
                 fusion_hidden_dim: int = 256,
                 final_output_tracks: int = 1,
                 learning_rate=5e-6,
                 loss_alpha=1.0,
                 weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitEnformerSMILES (or LitEnformerMorgan) model.

        Args:
            enformer_model_name (str, optional): Name or path of the pre-trained Enformer model.
            enformer_target_length (int, optional): Target length for Enformer's internal pooling.
            num_output_tracks_enformer_head (int, optional): Output features from Enformer head.
            morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector
                (e.g., 2048 for ECFP4). Defaults to 2048.
            dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1.
            fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256.
            final_output_tracks (int, optional): Number of final output values. Defaults to 1.
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay. Defaults to None.
            eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None.
        """
        super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets)
        self.save_hyperparameters(
            "enformer_model_name", "enformer_target_length",
            "num_output_tracks_enformer_head", "morgan_fingerprint_dim",
            "dose_input_dim", "fusion_hidden_dim", "final_output_tracks",
            "learning_rate", "loss_alpha", "weight_decay"
        )

        # 1. DNA Module (Enformer with HeadAdapter)
        enformer_pretrained = Enformer.from_pretrained(
            self.hparams.enformer_model_name,
            target_length=self.hparams.enformer_target_length
        )
        self.dna_module = HeadAdapterWrapper(
            enformer=enformer_pretrained,
            num_tracks=self.hparams.num_output_tracks_enformer_head,
            post_transformer_embed=False,
            output_activation=nn.Identity()
        )

        # 2. Drug Module (Morgan Fingerprints are provided as input directly)
        # No layers needed here as fingerprints are pre-computed.
        # The self.hparams.morgan_fingerprint_dim defines the expected input dimension.

        # 3. Fusion Head
        # Input dimension uses morgan_fingerprint_dim
        fusion_input_dim = self.hparams.num_output_tracks_enformer_head + self.hparams.morgan_fingerprint_dim + self.hparams.dose_input_dim
        self.fusion_head = nn.Sequential(
            nn.Linear(fusion_input_dim, self.hparams.fusion_hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim),
            nn.Dropout(0.25),
            nn.Linear(self.hparams.fusion_hidden_dim, self.hparams.fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.fusion_hidden_dim // 2, self.hparams.final_output_tracks)
        )

    def forward(self, dna_seq, morgan_fingerprints, dose):
        """
        Defines the forward pass of the LitEnformerSMILES model using Morgan Fingerprints.

        Args:
            dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences.
                Shape: (batch_size, sequence_length, 4).
            morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors.
                Shape: (batch_size, morgan_fingerprint_dim).
            dose (torch.Tensor): Batch of drug dose values.
                Shape: (batch_size, dose_input_dim).

        Returns:
            torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks).
        """
        # --- DNA Processing ---
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        center_seq_idx = dna_out_intermediate.shape[1] // 2
        dna_features = dna_out_intermediate[:, center_seq_idx, :]

        # --- Drug Processing (Morgan Fingerprints) ---
        # Morgan fingerprints are directly used as features.
        smiles_features = morgan_fingerprints  # Shape: (batch_size, morgan_fingerprint_dim)

        # --- Dose Processing ---
        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)

        # --- Feature Combination & Final Prediction ---
        combined_features = torch.cat([dna_features, smiles_features, dose], dim=1)
        prediction = self.fusion_head(combined_features)
        return prediction

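# Illustrative forward-pass sketch (assumptions: a toy batch of 2, the default ~196,608 bp
# Enformer input length, and that the pretrained Enformer weights can be downloaded):
#
#   model = LitEnformerSMILES(morgan_fingerprint_dim=2048, fusion_hidden_dim=256)
#   dna   = torch.randint(0, 2, (2, 196_608, 4)).float()   # stand-in for one-hot DNA
#   fp    = torch.randint(0, 2, (2, 2048)).float()         # stand-in Morgan fingerprints
#   dose  = torch.tensor([0.1, 1.0])
#   y_hat = model(dna, fp, dose)                            # -> (2, final_output_tracks)
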
# --- Metrics Logging Callback ---
class MetricLogger(pl.Callback):
    """
    A PyTorch Lightning Callback for comprehensive metric calculation and logging.

    This callback accumulates predictions and targets during validation and test epochs.
    At the end of these epochs, it:
    1. Processes the accumulated outputs into a pandas DataFrame.
    2. Saves the raw predictions and targets for the epoch to a CSV file.
    3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used.
    4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch.
    5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets.
    6. Calculates metrics per cell line if 'cell_line' information is available in the outputs.
    7. Logs all calculated metrics to the LightningModule's logger.

    Attributes:
        save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved.
        current_epoch_data (list): List to accumulate dictionaries of pred/target/metadata per item.
    """
    def __init__(self, save_dir_prefix="results"):
        """
        Initializes the MetricLogger callback.

        Args:
            save_dir_prefix (str, optional): Directory prefix for saving metrics files.
                Defaults to "results".
        """
        super().__init__()
        self.save_dir_prefix = save_dir_prefix
        self.current_epoch_data = []

    def _process_epoch_outputs(self, pl_module, stage):
        """
        Processes the raw outputs collected during an epoch into a pandas DataFrame.

        Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types.

        Args:
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").

        Returns:
            pd.DataFrame: A DataFrame containing the processed epoch outputs.
                Returns an empty DataFrame if no outputs were collected.
        """
        if not hasattr(pl_module, 'epoch_outputs') or not pl_module.epoch_outputs:
            warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.")
            return pd.DataFrame()

        df = pd.DataFrame(pl_module.epoch_outputs)

        for col in ['pred', 'target']:
            if col in df.columns and not df[col].empty:
                if isinstance(df[col].iloc[0], torch.Tensor):
                    df[col] = df[col].apply(lambda x: x.cpu().float().numpy().item() if x.numel() == 1 else x.cpu().float().numpy())
        return df

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the validation epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "validation")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "validation")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_validation_epoch_end.")

    def on_test_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the test epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "test")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "test")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_test_epoch_end.")

    def _log_and_save_metrics(self, trainer, pl_module, stage):
        """
        Calculates, logs, and saves metrics for the current stage and epoch.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").
        """
        epoch = trainer.current_epoch if trainer.current_epoch is not None else -1
        save_dir = getattr(pl_module.hparams, 'save_dir',
                           os.path.join(self.save_dir_prefix, f"run_{trainer.logger.version if trainer.logger else 'local'}"))
        os.makedirs(save_dir, exist_ok=True)

        raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}.csv")
        self.current_epoch_data.to_csv(raw_preds_path, index=False)

        if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run):
            try:
                trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=self.current_epoch_data.head(1000))})
            except Exception as e:
                warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}")

        overall_metrics = self._calculate_metrics_for_group(self.current_epoch_data, pl_module.device)
        if overall_metrics:
            pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True)

        if hasattr(pl_module, 'eval_gene_sets') and pl_module.eval_gene_sets and isinstance(pl_module.eval_gene_sets, dict) and 'gene_id' in self.current_epoch_data.columns:
            for split_name, gene_list in pl_module.eval_gene_sets.items():
                if not gene_list: continue
                split_df = self.current_epoch_data[self.current_epoch_data['gene_id'].isin(gene_list)]
                if not split_df.empty:
                    split_metrics = self._calculate_metrics_for_group(split_df, pl_module.device)
                    if split_metrics:
                        pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True)

        if 'cell_line' in self.current_epoch_data.columns:
            for cell_line, group_df in self.current_epoch_data.groupby('cell_line'):
                cl_metrics = self._calculate_metrics_for_group(group_df, pl_module.device)
                if cl_metrics:
                    pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k, v in cl_metrics.items()}, sync_dist=True)

    def _calculate_metrics_for_group(self, df_group, device):
        """
        Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions.

        Args:
            df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group.
            device (torch.device): The device to perform calculations on.

        Returns:
            dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient.
        """
        if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns:
            return {}

        preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32)
        targets_np = np.array(df_group['target'].tolist(), dtype=np.float32)

        preds = torch.tensor(preds_np).to(device)
        targets = torch.tensor(targets_np).to(device)
        # Collapse trailing singleton dimensions (e.g. (N, 1) -> (N,)) before metric computation.
        if preds.ndim > 1:
            preds = preds.squeeze()
            targets = targets.squeeze()

        if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape:
            warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}")
            return {}

        mse_val_tensor = masked_mse(preds.unsqueeze(-1) if preds.ndim == 1 else preds,
                                    targets.unsqueeze(-1) if targets.ndim == 1 else targets)
        calculated_metrics = {'mse': mse_val_tensor.item()}

        if preds.numel() < 2:
            warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.")
            return calculated_metrics

        preds_for_corr = preds.squeeze()
        targets_for_corr = targets.squeeze()

        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
            warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}")
            return calculated_metrics

        try:
            pearson_fn = PearsonCorrCoef().to(device)
            pearson_val = pearson_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['pearson'] = pearson_val.item()
        except Exception as e:
            warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        try:
            r2_fn = R2Score().to(device)
            r2_val = r2_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['r2'] = r2_val.item()
        except Exception as e:
            warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        return calculated_metrics
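
# Illustrative wiring sketch (assumptions: `train_dl` and `val_dl` are DataLoaders that
# yield the batch tuple described above; the epoch count is a placeholder):
#
#   model    = LitEnformerSMILES(learning_rate=5e-6)
#   callback = MetricLogger(save_dir_prefix="results")
#   trainer  = pl.Trainer(max_epochs=10, callbacks=[callback])
#   trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)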