YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
import torch
from PIL import Image
from safetensors.torch import load_file as safetensors_load_file
from torch import nn
from torchvision import transforms
from transformers import ViTModel, ViTConfig

# Preprocessing pipeline applied to every input image before inference:
# resize to 224x224, then convert the PIL image to a tensor.
# NOTE(review): no Normalize step is applied here — confirm this matches the
# preprocessing that was used during training.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

class ViTSalesModel(nn.Module):
    """ViT backbone with a single-output linear head for sales regression.

    The embedding at sequence position 0 (the [CLS] token) of the backbone's
    last hidden state is passed through a linear layer that emits one scalar
    per image.
    """

    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and regression head.

        Args:
            pixel_values: Batch of image tensors forwarded to the ViT backbone.
            labels: Optional regression targets, one per image. When given,
                a mean-squared-error loss is computed against the predictions.

        Returns:
            The predictions of shape ``(batch, 1)`` when ``labels`` is None,
            otherwise a ``(loss, predictions)`` tuple.
        """
        outputs = self.vit(pixel_values=pixel_values)
        cls_output = outputs.last_hidden_state[:, 0, :]  # Take the [CLS] token
        sales = self.classifier(cls_output)
        if labels is None:
            return sales

        # Flatten with reshape (works for non-contiguous tensors too) and cast
        # the targets to the prediction dtype so integer or float64 labels do
        # not promote the loss or break the backward pass.
        targets = labels.reshape(-1).to(dtype=sales.dtype)
        loss_fct = nn.MSELoss()
        loss = loss_fct(sales.reshape(-1), targets)
        return loss, sales

# Build the network, then replace its weights with the fine-tuned checkpoint.
model = ViTSalesModel()

checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
state_dict = safetensors_load_file(checkpoint_path)
model.load_state_dict(state_dict)

# Put the module in evaluation mode before running predictions.
model.eval()

# Scale factor used to map the model output back to a sales figure.
# Replace with the actual max sales value used during training.
max_sales_value = 100_000

def predict_sales(image_path):
    """Predict the de-normalized sales value for a single image file.

    Args:
        image_path: Path of the image to open with PIL.

    Returns:
        The model's scalar output multiplied by ``max_sales_value``.
    """
    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image
    # that stays valid after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(image)

    print(prediction)
    # De-normalize the prediction
    sales_prediction = prediction.item() * max_sales_value
    return sales_prediction

# Example usage: run one prediction on a sample image and print the result.
# NOTE(review): these statements execute whenever the file is run or imported;
# there is no `if __name__ == "__main__":` guard.
image_path = "/content/0000.png"
predicted_sales = predict_sales(image_path)
print(f"Predicted sales: {predicted_sales}")
Downloads last month

-

Downloads are not tracked for this model. How to track
Safetensors
Model size
86.4M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support