from torch import Tensor, nn, optim
from torch.nn import functional as F

from .base_model.classification import LightningClassification
from .metrics.classification import classification_metrics
from .modules.sample_torch_module import UselessLayer


class UselessClassification(LightningClassification):
    """Sample classification model built from ``UselessLayer`` followed by GELU.

    The network is ``nn.Sequential(UselessLayer(), nn.GELU())``; the loss is
    mean-squared error and the optimizer is Adam.
    """

    def __init__(self, n_classes: int, lr: float, **kwargs) -> None:
        """Build the network and store the hyperparameters.

        Args:
            n_classes: Number of classes, forwarded to ``classification_metrics``
                as ``num_classes``.
            lr: Learning rate handed to the Adam optimizer.
            **kwargs: Extra keyword arguments; not used directly in this
                constructor.
        """
        # Zero-argument super() is required here: ``super(UselessClassification)``
        # yields an unbound super object, so calling ``__init__`` on it never
        # runs the parent initializer and the submodule assignment below fails.
        super().__init__()
        self.save_hyperparameters()

        self.n_classes = n_classes
        self.lr = lr
        self.main = nn.Sequential(UselessLayer(), nn.GELU())

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of ``self.main`` applied to ``x``."""
        return self.main(x)

    def loss(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the mean-squared error between ``input`` and ``target``."""
        return F.mse_loss(input=input, target=target)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = optim.Adam(params=self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Run one training batch and return its loss.

        ``batch`` is unpacked as ``(x, y)``. The loss and the classification
        metrics are appended to ``self.train_batch_output`` (presumably
        provided by the base class; verify there).
        """
        x, y = batch
        logits = self.forward(x)
        # The loss must be computed on the model output, not the raw input,
        # otherwise it does not depend on any trainable parameter.
        loss = self.loss(input=logits, target=y)
        metrics = classification_metrics(
            preds=logits, target=y, num_classes=self.n_classes
        )
        self.train_batch_output.append({'loss': loss, **metrics})
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return its loss.

        ``batch`` is unpacked as ``(x, y)``. The loss and the classification
        metrics are appended to ``self.validation_batch_output`` (presumably
        provided by the base class; verify there).
        """
        x, y = batch
        logits = self.forward(x)
        # Evaluate the loss on the model output so validation reflects the model.
        loss = self.loss(input=logits, target=y)
        metrics = classification_metrics(
            preds=logits, target=y, num_classes=self.n_classes
        )
        self.validation_batch_output.append({'loss': loss, **metrics})
        return loss