|
import torch |
|
import torchvision |
|
from torchvision.models import ResNet50_Weights |
|
import swanlab |
|
from torch.utils.data import DataLoader |
|
from load_datasets import DatasetLoader |
|
import os |
|
|
|
|
|
|
|
def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs=None):
    """Run one training epoch over ``train_dataloader``.

    For every batch: move inputs/labels to ``device``, run the forward pass,
    compute the loss, backpropagate, step the optimizer, then print progress
    and log the batch loss to swanlab under ``"train_loss"``.

    Args:
        model: the ``torch.nn.Module`` to train; it is put into training mode.
        device: target passed to ``Tensor.to`` for inputs and labels.
        train_dataloader: iterable of ``(inputs, labels)`` batches that also
            supports ``len()`` (e.g. a ``DataLoader``).
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        epoch: current epoch number, used only in the progress message.
        num_epochs: total epoch count shown in the progress message. When
            omitted, falls back to a module-level ``num_epochs`` (set by this
            script's ``__main__`` block), or ``"?"`` if none is defined.
    """
    if num_epochs is None:
        # Backward compatibility: the original read the script-level global.
        num_epochs = globals().get("num_epochs", "?")
    # Use the loader we were actually given, not a script-level global.
    num_iters = len(train_dataloader)

    model.train()
    for step, (inputs, labels) in enumerate(train_dataloader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once; .item() forces a device->host sync.
        loss_value = loss.item()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(
            epoch, num_epochs, step, num_iters, loss_value))
        swanlab.log({"train_loss": loss_value})
|
|
|
|
|
|
|
def test(model, device, test_dataloader, epoch):
    """Evaluate ``model`` on ``test_dataloader`` and report top-1 accuracy.

    Runs in eval mode with gradients disabled. Prints the accuracy and logs
    it to swanlab under ``"test_acc"``.

    Args:
        model: the ``torch.nn.Module`` to evaluate; it is put into eval mode.
        device: target passed to ``Tensor.to`` for inputs and labels.
        test_dataloader: iterable of ``(inputs, labels)`` batches; the model
            output is assumed to be ``(batch, num_classes)`` logits — TODO
            confirm for models other than the classifier built in this script.
        epoch: current epoch number (kept for interface compatibility; unused).

    Returns:
        Accuracy as a percentage (``correct / total * 100``).

    Raises:
        ValueError: if ``test_dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_dataloader yielded no samples; cannot compute accuracy")

    accuracy = correct / total * 100
    print('Accuracy: {:.2f}%'.format(accuracy))
    swanlab.log({"test_acc": accuracy})
    # The caller binds the result, so return it instead of an implicit None.
    return accuracy
|
|
|
|
|
if __name__ == "__main__":
    # Hyperparameters. They stay at module scope because train() falls back
    # to the module-level ``num_epochs`` for its progress message.
    num_epochs = 20
    lr = 1e-4
    batch_size = 16
    num_classes = 2
    eval_interval = 4  # run validation every N epochs (and after the last one)

    # Device selection: CUDA, then Apple MPS, then CPU. The AttributeError
    # guard covers torch builds without ``torch.backends.mps``.
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    swanlab.init(
        experiment_name="ResNet50",
        description="Train ResNet50 for cat and dog classification.",
        config={
            "model": "resnet50",
            "optim": "Adam",
            "lr": lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "num_class": num_classes,
            "device": device,
        },
    )

    TrainDataset = DatasetLoader("datasets/train.csv")
    ValDataset = DatasetLoader("datasets/val.csv")
    TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
    ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)

    # ImageNet-pretrained backbone; replace the classifier head so its output
    # size matches num_classes.
    model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)
    model.to(torch.device(device))

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        train(model, device, TrainDataLoader, optimizer, criterion, epoch)
        # Always evaluate after the final epoch, so the saved checkpoint has a
        # reported accuracy even when num_epochs is not a multiple of
        # eval_interval.
        if epoch % eval_interval == 0 or epoch == num_epochs:
            test(model, device, ValDataLoader, epoch)

    checkpoint_dir = "checkpoint"
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "latest_checkpoint.pth"))
    print("Training complete")