# camie-tagger / onnx_inference.py
# Camais03 - V1.5 - 29b445b (verified)
import onnxruntime as ort
import torch
import json
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os
import time
def preprocess_image(image_path, image_size=512):
    """Load an image and letterbox it onto a square black canvas.

    The image is resized so that its longer side equals ``image_size`` while
    keeping the aspect ratio, centred on an ``image_size`` x ``image_size``
    black RGB canvas, and passed through torchvision's ``ToTensor``.

    Args:
        image_path: Path of the image file to load.
        image_size: Side length, in pixels, of the square output canvas.

    Returns:
        The result of applying ``transforms.ToTensor()`` to the padded image.

    Raises:
        ValueError: If ``image_path`` does not exist.
        Exception: If opening, resizing or converting the image fails; the
            original error is attached as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    # Initialize transform
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    try:
        with Image.open(image_path) as img:
            # Convert RGBA or Palette images to RGB
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            # Get original dimensions
            width, height = img.size
            aspect_ratio = width / height

            # Calculate new dimensions to maintain aspect ratio.
            # The short side is clamped to 1 px: for very elongated images
            # int() would otherwise truncate it to 0, which is not a valid
            # resize dimension.
            if aspect_ratio > 1:
                new_width = image_size
                new_height = max(1, int(new_width / aspect_ratio))
            else:
                new_height = image_size
                new_width = max(1, int(new_height * aspect_ratio))

            # Resize with LANCZOS filter
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Create new image with padding, centring the resized image
            new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Apply transforms
            img_tensor = transform(new_image)
            return img_tensor
    except Exception as e:
        # Keep the original exception type/message for callers, but chain the
        # cause so the underlying traceback is not lost.
        raise Exception(f"Error processing {image_path}: {str(e)}") from e
def _create_session(model_path):
    """Create an ONNX Runtime session, preferring CUDA and falling back to CPU."""
    try:
        # Try with CUDA
        session = ort.InferenceSession(
            model_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        print(f"Using providers: {session.get_providers()}")
    except Exception as e:
        # Best-effort fallback: any failure above is retried on CPU only.
        print(f"CUDA not available, using CPU: {e}")
        session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        print(f"Using providers: {session.get_providers()}")
    return session


def _sigmoid(logits):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
    # exp(-x) overflows to inf for strongly negative logits; the quotient is
    # then exactly 0.0, which is the correct limit, so only the overflow
    # warning is silenced and the computed values are unchanged.
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-logits))


def _group_tags_by_category(probs, metadata, threshold):
    """Group tags whose probability is >= threshold by category, best first."""
    tags_by_category = {}
    for idx in np.where(probs >= threshold)[0]:
        # Metadata keys are looked up by the string form of the index.
        tag_name = metadata['idx_to_tag'].get(str(idx), f"unknown-{idx}")
        category = metadata['tag_to_category'].get(tag_name, "general")
        tags_by_category.setdefault(category, []).append(
            (tag_name, float(probs[idx]))
        )

    # Sort by probability, highest first
    for tags in tags_by_category.values():
        tags.sort(key=lambda item: item[1], reverse=True)
    return tags_by_category


def test_onnx_model(model_path, metadata_path, image_path, threshold=0.325):
    """Test an ONNX model with a single image.

    Loads the tag metadata and the ONNX model, preprocesses the image with
    ``preprocess_image``, runs one inference pass, applies a sigmoid to the
    model output and prints every tag whose probability reaches
    ``threshold``, grouped by category.

    Args:
        model_path: Path to the ``.onnx`` model file.
        metadata_path: Path to a JSON file providing the ``idx_to_tag``
            (string index -> tag name) and ``tag_to_category``
            (tag name -> category) mappings.
        image_path: Path of the image to tag.
        threshold: Minimum probability for a tag to be reported.

    Returns:
        A dict mapping each category name to a list of ``(tag, probability)``
        tuples sorted by descending probability.
    """
    # Load metadata (explicit UTF-8 so decoding does not depend on the
    # platform's default locale encoding)
    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Load ONNX model
    print(f"Loading ONNX model from {model_path}")
    session = _create_session(model_path)

    # Preprocess image
    print(f"Processing image: {image_path}")
    img_tensor = preprocess_image(image_path)
    img_numpy = img_tensor.unsqueeze(0).numpy()  # Add batch dimension and convert to numpy

    # Get input name
    input_name = session.get_inputs()[0].name
    print(f"Input name: {input_name}")

    # Run inference (perf_counter is monotonic, so the measured interval is
    # not affected by system clock adjustments)
    print("Running inference...")
    start_time = time.perf_counter()
    outputs = session.run(None, {input_name: img_numpy})
    inference_time = time.perf_counter() - start_time
    print(f"Inference completed in {inference_time:.4f} seconds")

    # Process outputs: use the second output when the model provides one,
    # otherwise the first, and apply the sigmoid only to that one.
    logits = outputs[1] if len(outputs) > 1 else outputs[0]
    refined_probs = _sigmoid(logits)

    # Apply threshold and group by category (first batch entry only)
    tags_by_category = _group_tags_by_category(refined_probs[0], metadata, threshold)

    # Print results
    print("\nPredicted tags:")
    for category in sorted(tags_by_category.keys()):
        print(f"\n{category.capitalize()}:")
        for tag, prob in tags_by_category[category]:
            print(f" {tag}: {prob:.3f}")
    return tags_by_category
# Example usage: runs only when this file is executed as a script, so that
# importing the module does not trigger model loading and inference.
if __name__ == "__main__":
    test_onnx_model('model_initial.onnx', 'model_initial_metadata.json', 'test_image.jpg')