
Model Card for Custom CNN Model for Garbage Classification

This model card provides information about a custom Convolutional Neural Network (CNN) designed for classifying images of garbage items into predefined categories.

Model Details

Model Description

The CNN architecture (CNNModel) consists of:

  • Four convolutional layers with batch normalization, ReLU activation, max pooling, and dropout for feature extraction.
  • Two fully connected layers for classification.
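The card does not give channel widths, kernel sizes, dropout rates, or the hidden layer size, so the following is only a hypothetical PyTorch sketch of the described layout: four convolutional blocks, then two fully connected layers. Every numeric choice in it is an assumption: 32/64/128/256 channels, 3x3 kernels, dropout of 0.25 and 0.5, a 512-unit hidden layer, and 224x224 inputs. Its weights will not load from best_model.pth unless those shapes happen to match the original training code.

```python
import torch
import torch.nn as nn


class CNNModel(nn.Module):
    """Hypothetical sketch: four conv blocks followed by two fully connected layers.

    Channel widths, dropout rates, hidden size, and the 224x224 input size are
    assumptions, not the values used to train best_model.pth.
    """

    def __init__(self, num_classes, in_size=224):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Conv -> BatchNorm -> ReLU -> MaxPool (halves H and W) -> Dropout
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
                nn.Dropout(0.25),
            )

        # Feature extraction: four convolutional blocks
        self.features = nn.Sequential(
            conv_block(3, 32),
            conv_block(32, 64),
            conv_block(64, 128),
            conv_block(128, 256),
        )
        # Four 2x2 poolings shrink the spatial size by 16 (224x224 -> 14x14)
        flat_dim = 256 * (in_size // 16) ** 2
        # Classification: two fully connected layers
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 512),  # first fully connected layer
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),  # second FC layer: one logit per class
        )

    def forward(self, x):
        return self.classifier(self.features(x))


if __name__ == "__main__":
    model = CNNModel(num_classes=6)  # hypothetical class count
    logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # torch.Size([2, 6])
```

Because the four max-pooling steps reduce a 224x224 input to 14x14, the first linear layer takes 256 * 14 * 14 inputs; a different input resolution needs a matching `in_size`.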

The model is trained using the Adam optimizer with cross-entropy loss and a ReduceLROnPlateau scheduler.
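A minimal sketch of that training setup, assuming a standard epoch loop in which the scheduler is stepped on the validation loss. The helper names `run_epoch` and `fit`, the learning rate, the `factor` and `patience` values, and the choice to save the lowest-validation-loss weights to `best_model.pth` are all assumptions for illustration, not the author's recorded settings.

```python
import torch
import torch.nn as nn


def run_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """Train for one epoch if an optimizer is given, else evaluate; returns mean loss."""
    training = optimizer is not None
    model.train(training)  # toggles dropout and batch-norm behaviour
    total, count = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            loss = criterion(model(images), labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item() * images.size(0)
            count += images.size(0)
    return total / count


def fit(model, train_loader, val_loader, epochs=10, lr=1e-3,
        device="cpu", ckpt_path="best_model.pth"):
    """Adam + cross-entropy loss, with ReduceLROnPlateau driven by validation loss.

    lr, factor, and patience are assumed values. The weights with the lowest
    validation loss are saved to ckpt_path.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=3
    )
    best_val = float("inf")
    history = []
    for _ in range(epochs):
        train_loss = run_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = run_epoch(model, val_loader, criterion, device=device)
        scheduler.step(val_loss)  # lowers the lr when val loss stops improving
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), ckpt_path)
        history.append((train_loss, val_loss))
    return history
```

Stepping the scheduler on validation loss rather than training loss means the learning rate only drops when the model stops generalizing better, which is the usual way ReduceLROnPlateau is paired with checkpointing the best epoch.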

Model Source

Uses

Direct Use

This model can be used to classify images of garbage items into specific categories.

Downstream Use

The model can be fine-tuned on a specific garbage classification dataset or integrated into a waste-management application.
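One common way to fine-tune is to keep the learned convolutional features and retrain a new output layer sized for the target dataset. The hypothetical helper below does that, assuming the last `nn.Linear` registered in the model is the classification head (true for a two-FC-layer design like the one described above, but check your own definition).

```python
import torch.nn as nn


def prepare_for_finetuning(model: nn.Module, num_new_classes: int,
                           freeze_backbone: bool = True) -> list:
    """Swap the last nn.Linear for a fresh head sized for the new dataset.

    Assumes the final nn.Linear in the model is the classifier output layer.
    Returns the parameters that should be passed to the optimizer.
    """
    head_name = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            head_name = name  # keeps the last Linear found
    if head_name is None:
        raise ValueError("model has no nn.Linear layer to replace")

    if freeze_backbone:
        for p in model.parameters():
            p.requires_grad = False  # keep pretrained weights fixed

    # Locate the parent module so the head can be replaced in place
    parent_name, _, attr = head_name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    old_head = getattr(parent, attr)
    # The new layer is created with requires_grad=True, so it is trainable
    setattr(parent, attr, nn.Linear(old_head.in_features, num_new_classes))
    return [p for p in model.parameters() if p.requires_grad]
```

For example, after loading the trained weights, call `params = prepare_for_finetuning(model, num_new_classes)`, move the model back to the device with `model.to(device)` (the new head is created on the CPU), and build the optimizer with `torch.optim.Adam(params, lr=1e-4)`; the smaller learning rate is again an assumption. Passing `freeze_backbone=False` fine-tunes all layers instead.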

Out-of-Scope Use

This model is not suitable for general image classification tasks outside of garbage classification.

Bias, Risks, and Limitations

The model's performance may be affected by biases in the training data, such as underrepresentation of certain garbage types.

Recommendations

Users should be aware of the model's limitations and consider domain-specific data augmentation to improve performance.

How to Get Started with the Model

Example Usage

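import torch
import torch.nn as nn
from torchvision import datasets, transforms
from PIL import Image

# Define your CNN model. Paste the exact CNNModel class used for training here:
# layer names and shapes must match best_model.pth, or load_state_dict will fail.
class CNNModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        ...  # Define layers as in the training code

    def forward(self, x):
        ...  # Define the forward pass as in the training code

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class names in training order. Assumes training used
# torchvision.datasets.ImageFolder (as train_dataset.classes suggests); it sorts
# class folders alphabetically, so rebuilding it from the same directory
# reproduces the class order.
train_dataset = datasets.ImageFolder('path_to_train_dir')  # Replace with your training data path
num_classes = len(train_dataset.classes)

# Load the model
model = CNNModel(num_classes=num_classes).to(device)

# Load the best trained weights; map_location lets GPU-trained weights load on a CPU
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Preprocess an image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_path = 'path_to_your_image.jpg'  # Replace with your image path
img = Image.open(img_path).convert('RGB')  # Ensure 3 channels (e.g. PNG with alpha, grayscale)
input_tensor = transform(img)
input_batch = input_tensor.unsqueeze(0).to(device)  # Add batch dimension: [1, 3, 224, 224]

# Use the model for prediction
with torch.no_grad():
    output = model(input_batch)

# Get the predicted class
_, predicted = torch.max(output, 1)
predicted_class = train_dataset.classes[predicted.item()]
print(f'Predicted class: {predicted_class}')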
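If you also want a confidence estimate, the logits can be converted to probabilities with a softmax. The helper name `top_k_predictions` is hypothetical and this is only a small sketch; note that softmax scores from a CNN are not necessarily well calibrated, so treat them as relative rankings rather than exact probabilities of correctness.

```python
import torch


def top_k_predictions(logits: torch.Tensor, class_names: list, k: int = 3):
    """Return the k most likely (class name, probability) pairs for one image.

    logits: raw model output of shape [1, num_classes], e.g. `output` above.
    """
    probs = torch.softmax(logits, dim=1)[0]  # convert logits to probabilities
    k = min(k, probs.numel())  # cannot ask for more classes than exist
    top_probs, top_idx = torch.topk(probs, k)
    return [(class_names[i], p) for p, i in zip(top_probs.tolist(), top_idx.tolist())]
```

For the example above, `top_k_predictions(output.cpu(), train_dataset.classes, k=3)` returns the three most likely categories with their scores.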

Dataset used to train naamalia23/06_computer_vision_cnn