import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from PIL import Image
from typing import Optional
import torchvision.transforms as transforms
import os
import json


class InitialOnlyImageTagger(nn.Module):
    """
    A lightweight version of ImageTagger that only includes the backbone and initial classifier.
    This model uses significantly less VRAM than the full model.
    """
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 dropout=0.1, pretrained=True):
        super().__init__()

        self._flags = {
            'debug': False,
            'model_stats': False
        }

        self.dataset = dataset
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension

        # Backbone feature extractor with the stock classifier head removed
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag classifier over pooled backbone features
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Learnable temperature used to scale logits before clamping
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using the same preprocessing as training."""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Resize so the longer side equals image_size, preserving aspect ratio
                width, height = img.size
                aspect_ratio = width / height

                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Pad to a black square canvas with the image centered
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with only the initial predictions."""
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # Returned twice to mirror ImageTagger's (initial, refined) interface
        return initial_preds, initial_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)

        with torch.no_grad():
            initial_preds, _ = self.forward(img_tensor)

            initial_probs = torch.sigmoid(initial_preds)

            if category_thresholds:
                initial_binary = torch.zeros_like(initial_probs)

                for category, cat_threshold in category_thresholds.items():
                    category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)

                    for tag_idx in range(initial_probs.size(-1)):
                        try:
                            _, tag_category = self.dataset.get_tag_info(tag_idx)
                            if tag_category == category:
                                category_mask[:, tag_idx] = True
                        except Exception:
                            continue

                    cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                    initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)

                predictions = initial_binary
            else:
                threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
                predictions = (initial_probs >= threshold_tensor).to(dtype)

        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': initial_probs,
            'predictions': predictions
        }
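
    # Note on category_thresholds (hedged; the format is inferred from the loop above and
    # from __main__ below): it maps a category name to its own sigmoid cutoff, e.g.
    # {"general": 0.30, "character": 0.40}. Tags whose category has no entry remain 0 on
    # this per-category path.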

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        if predictions.dim() > 1:
            predictions = predictions[0]

        # Indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        # Sort tags within each category by probability, highest first
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result
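
# Minimal usage sketch for the lightweight tagger (hedged: assumes a checkpoint whose keys
# match this architecture and a TagDataset built from the exported metadata; the file name
# "initial_only.pt" is purely illustrative):
#
#   dataset = TagDataset(total_tags, idx_to_tag, tag_to_category)
#   tagger = InitialOnlyImageTagger(total_tags, dataset, pretrained=False).eval()
#   tagger.load_state_dict(torch.load("initial_only.pt", map_location="cpu"), strict=False)
#   results = tagger.predict("example.jpg", threshold=0.325)
#   tags = tagger.get_tags_from_predictions(results['predictions'])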


class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)

        self.scale = self.head_dim ** -0.5
        self.debug = False

    def _debug_print(self, name, tensor):
        """Debug helper"""
        if self.debug:
            print(f"\n{name}:")
            print(f"Shape: {tensor.shape}")
            print(f"Device: {tensor.device}")
            print(f"Dtype: {tensor.dtype}")
            if tensor.is_floating_point():
                with torch.no_grad():
                    print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"Mean: {tensor.mean().item():.3f}")
                    print(f"Std: {tensor.std().item():.3f}")

    def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape input tensor to (batch, num_heads, seq_len, head_dim)"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.transpose(1, 2)
        return x.contiguous()

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with flash attention"""
        if self.debug:
            print("\nFlashAttention Forward Pass")

        batch_size = query.size(0)

        # Default to self-attention when key/value are not provided
        key = query if key is None else key
        value = query if value is None else value

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        if self.debug:
            self._debug_print("Query before reshape", q)

        q = self._reshape_for_flash(q)
        k = self._reshape_for_flash(k)
        v = self._reshape_for_flash(v)

        if self.debug:
            self._debug_print("Query after reshape", q)

        if mask is not None:
            # Normalize the mask toward a (batch, 1, seq_len, seq_len) layout
            if mask.dim() == 2:
                mask = mask.view(batch_size, 1, -1, 1)
            elif mask.dim() == 3:
                mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
            elif mask.dim() == 5:
                mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))

            mask = mask.to(q.dtype)

            if self.debug:
                self._debug_print("Prepared mask", mask)
                print(f"q shape: {q.shape}, mask shape: {mask.shape}")

            seq_len = q.size(2)
            if mask.size(-1) != seq_len:
                new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
                                       device=mask.device, dtype=mask.dtype)
                min_len = min(seq_len, mask.size(-1))
                new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
                mask = new_mask

            # Zero out fully-masked key/value positions rather than passing an attention mask
            key_padding_mask = mask.squeeze(1).sum(-1) > 0
            key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)

            k = k * key_padding_mask
            v = v * key_padding_mask

        if self.debug:
            self._debug_print("Query before attention", q)
            self._debug_print("Key before attention", k)
            self._debug_print("Value before attention", v)

        # Scaled dot-product attention. The original code called flash_attn_func from the
        # flash-attn package, which is not imported in this file and expects a
        # (batch, seq, heads, head_dim) layout; PyTorch's built-in kernel accepts the
        # (batch, heads, seq, head_dim) layout produced above and dispatches to
        # FlashAttention-style kernels when available. Its default scale of
        # 1/sqrt(head_dim) matches self.scale.
        dropout_p = self.dropout if self.training else 0.0
        output = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=dropout_p,
            is_causal=False
        )

        if self.debug:
            self._debug_print("Output after attention", output)

        # Back to (batch, seq_len, dim)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.dim)

        output = self.out_proj(output)

        if self.debug:
            self._debug_print("Final output", output)

        return output
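
# Shape sketch for the attention block above (illustrative values):
#   attn = FlashAttention(dim=1280, num_heads=8)
#   x = torch.randn(2, 16, 1280)   # (batch, seq_len, dim)
#   y = attn(x)                    # self-attention; output is again (2, 16, 1280)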


class OptimizedTagEmbedding(nn.Module):
    def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
        super().__init__()

        self.embedding = nn.Embedding(num_tags, embedding_dim)
        self.attention = FlashAttention(embedding_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Learnable per-tag importance weights, passed through a sigmoid at use time
        self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)

        # Projection applied to the pooled tag context
        self.context_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        self.importance_scale = nn.Parameter(torch.tensor(0.1))
        self.context_scale = nn.Parameter(torch.tensor(1.0))
        self.debug = False

    def _debug_print(self, name, tensor, extra_info=None):
        """Memory efficient debug printing with type handling"""
        if self.debug:
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            if isinstance(tensor, torch.Tensor):
                with torch.no_grad():
                    print(f"- Device: {tensor.device}")
                    print(f"- Dtype: {tensor.dtype}")

                    # Cast integer tensors to float before computing statistics
                    if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
                        calc_tensor = tensor.float()
                    else:
                        calc_tensor = tensor

                    try:
                        min_val = calc_tensor.min().item()
                        max_val = calc_tensor.max().item()
                        mean_val = calc_tensor.mean().item()
                        std_val = calc_tensor.std().item()
                        norm_val = torch.norm(calc_tensor).item()

                        print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
                        print(f"- Mean: {mean_val:.3f}")
                        print(f"- Std: {std_val:.3f}")
                        print(f"- L2 Norm: {norm_val:.3f}")

                        if extra_info:
                            print(f"- Additional info: {extra_info}")
                    except Exception as e:
                        print(f"- Could not compute statistics: {str(e)}")

    def _debug_tensor(self, name, tensor):
        """Debug helper with dtype-specific analysis"""
        if self.debug and isinstance(tensor, torch.Tensor):
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            print(f"- Device: {tensor.device}")
            print(f"- Dtype: {tensor.dtype}")
            with torch.no_grad():
                has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
                has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
                print(f"- Contains NaN: {has_nan}")
                print(f"- Contains Inf: {has_inf}")

                if tensor.is_floating_point():
                    print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"- Mean: {tensor.mean().item():.3f}")
                    print(f"- Std: {tensor.std().item():.3f}")
                else:
                    print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
                    print(f"- Unique values: {tensor.unique().numel()}")

    def _process_category(self, indices, masks):
        """Process a single category of tags"""
        embeddings = self.embedding(indices)

        if self.debug:
            self._debug_tensor("Category embeddings", embeddings)

        # Weight embeddings by learned per-tag importance
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[indices].unsqueeze(-1)

        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        # Self-attention over the tags in this category
        if embeddings.size(1) > 1:
            if masks is not None:
                attention_mask = torch.einsum('bi,bj->bij', masks, masks)
                attended = self.attention(embeddings, mask=attention_mask)
            else:
                attended = self.attention(embeddings)
            embeddings = self.norm2(attended)

        # Masked mean pooling over the tag dimension
        if masks is not None:
            masked_embeddings = embeddings * masks.unsqueeze(-1)
            pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            pooled = embeddings.mean(dim=1)

        return pooled, embeddings

    def forward(self, tag_indices_dict, tag_masks_dict=None):
        """
        Process all tags in a unified embedding space
        Args:
            tag_indices_dict: dict of {category: tensor of indices}
            tag_masks_dict: dict of {category: tensor of masks}
        """
        if self.debug:
            print("\nOptimizedTagEmbedding Forward Pass")

        # Concatenate indices (and masks) from all categories
        all_indices = []
        all_masks = []
        batch_size = None

        for category, indices in tag_indices_dict.items():
            if batch_size is None:
                batch_size = indices.size(0)
            all_indices.append(indices)
            if tag_masks_dict:
                all_masks.append(tag_masks_dict[category])

        combined_indices = torch.cat(all_indices, dim=1)
        if tag_masks_dict:
            combined_masks = torch.cat(all_masks, dim=1)

        if self.debug:
            self._debug_tensor("Combined indices", combined_indices)
            if tag_masks_dict:
                self._debug_tensor("Combined masks", combined_masks)

        embeddings = self.embedding(combined_indices)

        if self.debug:
            self._debug_tensor("Base embeddings", embeddings)

        # Weight embeddings by learned per-tag importance
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[combined_indices].unsqueeze(-1)

        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        if self.debug:
            self._debug_tensor("Weighted embeddings", embeddings)

        # Self-attention over all tags
        if tag_masks_dict:
            attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
            attended = self.attention(embeddings, mask=attention_mask)
        else:
            attended = self.attention(embeddings)

        attended = self.norm2(attended)

        if self.debug:
            self._debug_tensor("Attended embeddings", attended)

        # Masked mean pooling to a single context vector per sample
        if tag_masks_dict:
            masked_embeddings = attended * combined_masks.unsqueeze(-1)
            tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            tag_context = attended.mean(dim=1)

        tag_context = self.context_proj(tag_context)
        context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
        tag_context = tag_context * context_scale

        if self.debug:
            self._debug_tensor("Final tag context", tag_context)

        return tag_context, attended
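
# Note: OptimizedTagEmbedding is not referenced by the inference classes below in this
# file; ImageTagger builds its tag pathway directly from nn.Embedding and FlashAttention.
# It is kept here as-is, presumably for training/checkpoint compatibility.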


class TagDataset:
    """Lightweight dataset wrapper for inference only"""
    def __init__(self, total_tags, idx_to_tag, tag_to_category):
        self.total_tags = total_tags
        # idx_to_tag keys arrive as strings when loaded from metadata.json, while lookups
        # in get_tag_info use integer indices, so normalize the keys to int here.
        self.idx_to_tag = {int(k): v for k, v in idx_to_tag.items()}
        self.tag_to_category = tag_to_category

    def get_tag_info(self, idx):
        """Get tag name and category for a given index"""
        tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
        category = self.tag_to_category.get(tag_name, "general")
        return tag_name, category
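
# TagDataset is rebuilt from the exported metadata in load_model() below, e.g.:
#   dataset = TagDataset(
#       total_tags=metadata['total_tags'],
#       idx_to_tag=metadata['idx_to_tag'],           # {"0": "tag_name", ...} from JSON
#       tag_to_category=metadata['tag_to_category']  # {"tag_name": "category", ...}
#   )
#   dataset.get_tag_info(0)  # -> (tag_name, category)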


class ImageTagger(nn.Module):
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 num_heads=16, dropout=0.1, pretrained=True,
                 tag_context_size=256):
        super().__init__()

        self._flags = {
            'debug': False,
            'model_stats': False
        }

        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension

        # Backbone feature extractor with the stock classifier head removed
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag classifier over pooled backbone features
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Embeddings and self-attention over the top-K initially-predicted tags
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Projection of image features used as cross-attention queries
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )

        self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier over [image features, fused tag context]
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Learnable temperature used to scale logits before clamping
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def _get_selected_tags(self, logits):
        """Select top-K tags based on prediction confidence"""
        probs = torch.sigmoid(logits)

        batch_size = logits.size(0)
        topk_values, topk_indices = torch.topk(
            probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
        )

        return topk_indices, topk_values

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using the same preprocessing as training."""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Resize so the longer side equals image_size, preserving aspect ratio
                width, height = img.size
                aspect_ratio = width / height

                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Pad to a black square canvas with the image centered
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with simplified feature handling"""
        # Optional collectors; not populated in this inference-only forward pass
        model_stats = {} if self.model_stats else None
        debug_tensors = {} if self.debug else None

        # Image features
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # Initial (image-only) predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # Embed the top-K predicted tags and let them attend to each other
        pred_tag_indices, _ = self._get_selected_tags(initial_preds)
        tag_embeddings = self.tag_embedding(pred_tag_indices)

        attended_tags = self.tag_attention(tag_embeddings)
        attended_tags = self.tag_norm(attended_tags)

        # Cross-attention between projected image features and the attended tags
        features_proj = self.cross_proj(features)
        features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)

        cross_attended = self.cross_attention(features_expanded, attended_tags)
        cross_attended = self.cross_norm(cross_attended)

        # Fuse image features with the tag context and refine the predictions
        fused_features = cross_attended.mean(dim=1)
        combined_features = torch.cat([features, fused_features], dim=-1)

        refined_logits = self.refined_classifier(combined_features)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds
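
    # Data flow in forward(), for reference (B = batch size, K = tag_context_size = 256
    # by default, D = 1280):
    #   features:        (B, D)      pooled backbone features
    #   initial_preds:   (B, total_tags)
    #   tag_embeddings:  (B, K, D)   embeddings of the top-K initially-predicted tags
    #   cross_attended:  (B, K, D)   image-feature queries attending over those tags
    #   combined:        (B, 2*D) -> refined_preds: (B, total_tags)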

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)

        with torch.no_grad():
            initial_preds, refined_preds = self.forward(img_tensor)

            initial_probs = torch.sigmoid(initial_preds)
            refined_probs = torch.sigmoid(refined_preds)

            if category_thresholds:
                refined_binary = torch.zeros_like(refined_probs)

                for category, cat_threshold in category_thresholds.items():
                    category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)

                    for tag_idx in range(refined_probs.size(-1)):
                        try:
                            _, tag_category = self.dataset.get_tag_info(tag_idx)
                            if tag_category == category:
                                category_mask[:, tag_idx] = True
                        except Exception:
                            continue

                    cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                    refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)

                predictions = refined_binary
            else:
                threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
                predictions = (refined_probs >= threshold_tensor).to(dtype)

        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': refined_probs,
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        if predictions.dim() > 1:
            predictions = predictions[0]

        # Indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        # Sort tags within each category by probability, highest first
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result


def load_model(model_dir, device='cuda'):
    """Load model with better error handling and warnings"""
    print(f"Loading model from {model_dir}")

    try:
        # Tag metadata (names and categories) is required
        metadata_path = os.path.join(model_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        # Architecture settings are optional; fall back to defaults if absent
        model_info_path = os.path.join(model_dir, "model_info.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, 'r') as f:
                model_info = json.load(f)
        else:
            print("WARNING: Model info file not found, using default settings")
            model_info = {
                "tag_context_size": 256,
                "num_heads": 16,
                "precision": "float16"
            }

        dataset = TagDataset(
            total_tags=metadata['total_tags'],
            idx_to_tag=metadata['idx_to_tag'],
            tag_to_category=metadata['tag_to_category']
        )

        model = ImageTagger(
            total_tags=metadata['total_tags'],
            dataset=dataset,
            num_heads=model_info.get('num_heads', 16),
            tag_context_size=model_info.get('tag_context_size', 256),
            pretrained=False
        )

        state_dict_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")

        state_dict = torch.load(state_dict_path, map_location=device)

        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model state dict loaded with strict=True successfully")
        except Exception as e:
            print(f"! Strict loading failed: {str(e)}")
            print("Attempting non-strict loading...")

            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

            print("Non-strict loading completed with:")
            print(f"- {len(missing_keys)} missing keys")
            print(f"- {len(unexpected_keys)} unexpected keys")

            if len(missing_keys) > 0:
                print(f"Sample missing keys: {missing_keys[:5]}")
            if len(unexpected_keys) > 0:
                print(f"Sample unexpected keys: {unexpected_keys[:5]}")

        model = model.to(device)

        if model_info.get('precision') == 'float16':
            model = model.half()
            print("✓ Model converted to half precision")

        model.eval()
        print("✓ Model set to evaluation mode")

        param_dtype = next(model.parameters()).dtype
        print(f"✓ Model loaded with precision: {param_dtype}")

        return model, dataset

    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
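
# Expected layout of model_dir, as read by load_model():
#   model_dir/
#     metadata.json     # total_tags, idx_to_tag, tag_to_category
#     model_info.json   # optional: tag_context_size, num_heads, precision
#     model.pt          # state dict for ImageTagger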


if __name__ == "__main__":
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"

    model, dataset = load_model(model_dir)

    # Optional per-category thresholds. load_model() does not return them, so look for a
    # thresholds file next to the model; the "thresholds.json" name and its
    # {'categories': {category: {'balanced': {'threshold': ...}}}} layout are assumptions
    # based on how the dictionary is consumed below.
    thresholds = None
    thresholds_path = os.path.join(model_dir, "thresholds.json")
    if os.path.exists(thresholds_path):
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)

    print("\nModel information:")
    print(f"  Total tags: {dataset.total_tags}")
    print(f"  Device: {next(model.parameters()).device}")
    print(f"  Precision: {next(model.parameters()).dtype}")

    if len(sys.argv) > 2:
        image_path = sys.argv[2]
        print(f"\nRunning inference on {image_path}")

        if thresholds and 'categories' in thresholds:
            category_thresholds = {cat: opt['balanced']['threshold']
                                   for cat, opt in thresholds['categories'].items()}
            results = model.predict(image_path, category_thresholds=category_thresholds)
        else:
            results = model.predict(image_path)

        tags = model.get_tags_from_predictions(results['predictions'])

        print("\nPredicted tags:")
        for category, category_tags in tags.items():
            print(f"\n{category.capitalize()}:")
            for tag, prob in category_tags:
                # prob comes from the binary prediction tensor here, so it prints as 1.000;
                # pass results['refined_probabilities'] * results['predictions'] instead to
                # see the underlying scores.
                print(f"  {tag}: {prob:.3f}")