# models/training_environment.py
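"""PyTorch Lightning training wrapper for the dance classifier models."""
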
import pytorch_lightning as pl
import torch
import torch.nn as nn

from models.utils import calculate_metrics


class TrainingEnvironment(pl.LightningModule):
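    """Generic Lightning wrapper that trains an arbitrary model.

    Handles the train/validation/test steps, metric logging via
    `calculate_metrics`, optional spectrogram logging to TensorBoard,
    and optimizer/scheduler configuration.
    """
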
def __init__(
self,
model: nn.Module,
criterion: nn.Module,
config: dict,
        learning_rate: float = 1e-4,
        log_spectrograms: bool = False,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.model = model
self.criterion = criterion
self.learning_rate = learning_rate
self.log_spectrograms = log_spectrograms
self.config = config
        # CrossEntropyLoss implies single-label classification; any other
        # criterion is treated as producing multi-label predictions.
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
self.save_hyperparameters(
{
"model": type(model).__name__,
"loss": type(criterion).__name__,
"config": config,
**kwargs,
}
)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
features, labels = batch
outputs = self.model(features)
loss = self.criterion(outputs, labels)
metrics = calculate_metrics(
outputs,
labels,
prefix="train/",
multi_label=self.has_multi_label_predictions,
)
self.log_dict(metrics, prog_bar=True)
        # Periodically log an example spectrogram from the batch to TensorBoard
        if self.log_spectrograms and batch_index % 100 == 0:
            tensorboard = self.logger.experiment
            img_index = torch.randint(0, len(features), (1,)).item()
            img = features[img_index][0]
            # Min-max normalize to [0, 1] for display
            img = (img - img.min()) / (img.max() - img.min())
tensorboard.add_image(
f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
)
return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
x, y = batch
preds = self.model(x)
metrics = calculate_metrics(
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
)
metrics["val/loss"] = self.criterion(preds, y)
self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
x, y = batch
preds = self.model(x)
self.log_dict(
calculate_metrics(
preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
),
prog_bar=True,
)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # Reduce the learning rate when the monitored validation loss plateaus
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": "val/loss",
}
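
# Example usage (illustrative sketch only, not part of the original module):
# the model, criterion, config, and dataloader names below are placeholders.
#
#   import pytorch_lightning as pl
#   import torch.nn as nn
#
#   model = SomeSpectrogramClassifier()  # hypothetical nn.Module producing class logits
#   criterion = nn.CrossEntropyLoss()    # single-label training
#   env = TrainingEnvironment(model, criterion, config={"lr": 1e-4})
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(env, train_dataloaders=train_loader, val_dataloaders=val_loader)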