Spaces:
Runtime error
Runtime error
from models.utils import calculate_metrics | |
import pytorch_lightning as pl | |
import torch | |
import torch.nn as nn | |
class TrainingEnvironment(pl.LightningModule):
    """Lightning wrapper that trains, validates and tests ``model`` with ``criterion``.

    Metrics come from ``calculate_metrics`` and are logged under the
    ``train/``, ``val/`` and ``test/`` prefixes. Optimisation uses Adam with a
    ``ReduceLROnPlateau`` scheduler that monitors the logged ``val/loss``.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate: float = 1e-4,
        log_spectrograms: bool = False,
        *args,
        **kwargs,
    ):
        """Store the model, loss and settings, and record hyperparameters.

        Args:
            model: Network, called as ``model(features)`` in every step.
            criterion: Loss, called as ``criterion(outputs, labels)``. Any loss
                whose class is not named ``CrossEntropyLoss`` makes metrics be
                computed in multi-label mode.
            config: Configuration dict. It is stored on ``self.config`` and
                saved with the hyperparameters.
            learning_rate: Learning rate for the Adam optimizer.
            log_spectrograms: If true, one input image from the batch is
                logged every 100 training batches.
            *args, **kwargs: Forwarded to ``pl.LightningModule.__init__``.
                ``kwargs`` are also saved as hyperparameters.
        """
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.log_spectrograms = log_spectrograms
        self.config = config
        # Matched by class name: only a loss literally named CrossEntropyLoss
        # selects single-label metrics.
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        """Compute and log the training loss and metrics for one batch.

        Args:
            batch: ``(features, labels)`` pair.
            batch_index: Index of the batch within the current epoch.

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        # Log the training loss too, mirroring ``val/loss`` in validation_step.
        metrics["train/loss"] = loss
        self.log_dict(metrics, prog_bar=True)

        if self.log_spectrograms and batch_index % 100 == 0:
            self._log_spectrogram(features, batch_index)
        return loss

    def _log_spectrogram(self, features: torch.Tensor, batch_index: int) -> None:
        """Log one randomly chosen element of ``features`` as a grayscale image."""
        if self.logger is None:
            # No logger is attached (e.g. Trainer(logger=False)), so there is
            # nowhere to write the image.
            return
        tensorboard = self.logger.experiment
        img_index = torch.randint(0, len(features), (1,)).item()
        # Assumes features are (batch, channel, H, W); TODO confirm. Channel 0
        # is logged as an HW image.
        img = features[img_index][0].detach().float()
        # Min-max normalise to [0, 1]. A constant input (e.g. silence) stays
        # all zeros instead of becoming NaN through 0 / 0.
        img = img - img.min()
        peak = img.max()
        if peak > 0:
            img = img / peak
        tensorboard.add_image(
            f"batch: {batch_index}, element: {img_index}",
            img,
            # Use the real step so images from different epochs are kept
            # apart instead of all being written at step 0.
            self.global_step,
            dataformats="HW",
        )

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
        """Log validation metrics and ``val/loss`` for one batch.

        ``val/loss`` is also the quantity monitored by the LR scheduler.
        """
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        """Log test metrics for one batch (the loss is not computed here)."""
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler.

        The scheduler runs in ``"min"`` mode on the logged ``val/loss``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }