ResNet18 Fine-Tuned on CIFAR-10

This model is ResNet18, pretrained on ImageNet and then fine-tuned on the CIFAR-10 dataset. It achieves the following result on the validation/test set:

  • Validation Accuracy: 88.60%

Model description

  • Architecture: ResNet18 with the final fully-connected layer replaced by a 10-class output layer for CIFAR-10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
  • Pretrained Weights: ImageNet-1K
  • Fine-Tuning: The model was fine-tuned on CIFAR-10 images resized to 128×128 pixels.
  • Data Augmentation: Random horizontal flip and random rotation; inputs are then normalized with mean=0.5 and std=0.5 per channel.

Intended uses & limitations

  • Intended use: Educational/demo purposes or as a starting point for further fine-tuning on similar image classification tasks.
  • Not intended for: Production-critical tasks without further evaluation. CIFAR-10 is a small-scale dataset (60,000 32×32 color images in 10 classes), so the model may not generalize to non-CIFAR data without additional fine-tuning.

Training procedure

Hyperparameters (approximate):

  • optimizer: Adam
  • learning_rate: 1e-3
  • batch_size: 32
  • num_epochs: 15
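A training loop using these hyperparameters might look like the sketch below. The card does not state the loss function; `nn.CrossEntropyLoss` is assumed here as the standard choice for single-label classification. Shuffling the training set and the names `train_one_epoch` and `fit` are also assumptions. Each epoch reports the same kind of quantities as the training log: mean loss and accuracy on the training set.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Hyperparameters listed in this card (approximate).
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
NUM_EPOCHS = 15


def train_one_epoch(model, loader, optimizer, criterion, device):
    """One pass over loader; returns (mean loss per sample, accuracy in %)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch mean is per sample.
        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, 100.0 * correct / seen


def fit(model: nn.Module, train_dataset: Dataset, device: torch.device,
        num_epochs: int = NUM_EPOCHS):
    loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()  # assumed; the card does not name the loss
    model.to(device)
    history = []
    for epoch in range(1, num_epochs + 1):
        loss, acc = train_one_epoch(model, loader, optimizer, criterion, device)
        print(f"Epoch {epoch}: loss={loss:.4f} acc={acc:.2f}%")
        history.append((loss, acc))
    return history
```

For CIFAR-10, `train_dataset` would be something like `torchvision.datasets.CIFAR10(root, train=True, transform=train_transform, download=True)`, with `train_transform` built from the augmentations in the Model description.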

GPU/CPU:

  • The training script uses a single GPU (torch.device("cuda")) when one is available and falls back to CPU otherwise.
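The GPU-or-CPU fallback is the usual PyTorch idiom. In the sketch below, a small `nn.Linear` stands in for the ResNet18 model so the snippet runs on its own; the point is that the model and every input batch must be moved to the same device.

```python
import torch
import torch.nn as nn

# Use the default CUDA GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model parameters and input batches must live on the same device.
model = nn.Linear(4, 2).to(device)  # stand-in for the ResNet18 model
batch = torch.randn(8, 4).to(device)
logits = model(batch)
print(device, tuple(logits.shape))
```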

Training logs (for each epoch on the training set):

Epoch  Training Loss  Training Accuracy  Validation Accuracy
1      0.7013         76.52%             -
2      0.4248         85.64%             -
3      0.3185         89.07%             -
4      0.2341         92.06%             -
5      0.1762         93.86%             -
6      0.1302         95.55%             -
7      0.1085         96.31%             -
8      0.0925         96.82%             -
9      0.0765         97.37%             -
10     0.0683         97.68%             -
11     0.0655         97.83%             -
12     0.0548         98.18%             -
13     0.0513         98.27%             -
14     0.0461         98.49%             -
15     0.0470         98.41%             88.60%

Note: Validation accuracy was computed at the end of training (final epoch).
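A single end-of-training evaluation like the one noted above amounts to a top-1 accuracy pass with the model in eval mode and gradients disabled. This is a sketch; `evaluate` is a hypothetical name, and the loader is assumed to yield held-out images preprocessed without augmentation (for example, from `torchvision.datasets.CIFAR10(root, train=False, transform=...)`).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients needed for evaluation
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Top-1 accuracy (%) of model over every sample in loader."""
    model.eval()  # BatchNorm switches to its running statistics
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
    return 100.0 * correct / seen
```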


Usage

Below is a sample usage snippet in Python. Replace username/model_repo_name with the model's repo id on the Hugging Face Hub (teguhteja/ttm_cnn_model for this model).

import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Download the weights (a state_dict) from the Hugging Face Hub
ckpt_path = hf_hub_download(repo_id="username/model_repo_name", filename="cnn_model.pth")

# Rebuild the same architecture; weights=None skips the ImageNet download
# (torchvision releases before 0.13 use pretrained=False instead)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10-class head for CIFAR-10
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval()  # inference mode: BatchNorm uses its running statistics

# Preprocessing must match training: 128x128 resize, mean/std of 0.5 per channel
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Example inference
image = Image.open("your_image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 128, 128)
with torch.no_grad():
    logits = model(input_tensor)
    predicted_class = logits.argmax(dim=1).item()

print("Predicted class ID:", predicted_class)
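The snippet above prints only a class index. Assuming the checkpoint uses torchvision's CIFAR-10 label order (the same order as the class list in the Model description), the index can be mapped to a name, and a softmax over the logits gives a ranked top-k list. This is a minimal sketch; `CIFAR10_CLASSES` and `top_k_predictions` are hypothetical helper names.

```python
import torch

# torchvision's CIFAR-10 label order (assumed to match the checkpoint's indices).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def top_k_predictions(logits: torch.Tensor, k: int = 3):
    """Return [(class_name, probability), ...] for one image's logits of shape (1, 10)."""
    probs = torch.softmax(logits, dim=1)[0]  # probabilities summing to 1
    values, indices = probs.topk(k)          # sorted, highest first
    return [(CIFAR10_CLASSES[i], p) for p, i in zip(values.tolist(), indices.tolist())]
```

With the `logits` from the usage snippet, `top_k_predictions(logits)` returns the three most likely classes with their probabilities.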
Model repository: teguhteja/ttm_cnn_model