# camie-tagger-onnxruntime / model_no_flash.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
class MultiheadAttentionNoFlash(nn.Module):
"""Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""
def __init__(self, dim, num_heads=8, dropout=0.0):
super().__init__()
assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5 # scaling factor for dot-product attention
# Define separate projections for query, key, value, and output (no biases to match FlashAttention)
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
# (Note: We omit dropout in attention computation for ONNX simplicity; model should be set to eval mode anyway.)
def forward(self, query, key=None, value=None):
# Allow usage as self-attention if key/value not provided
if key is None:
key = query
if value is None:
value = key
# Linear projections
Q = self.q_proj(query) # [B, S_q, dim]
K = self.k_proj(key) # [B, S_k, dim]
        V = self.v_proj(value)   # [B, S_k, dim] (value and key must have the same sequence length)
# Reshape into (B, num_heads, S, head_dim) for computing attention per head
B, S_q, _ = Q.shape
_, S_k, _ = K.shape
Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, S_q, head_dim]
K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, S_k, head_dim]
V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, S_k, head_dim]
# Scaled dot-product attention: compute attention weights
attn_weights = torch.matmul(Q, K.transpose(2, 3)) # [B, heads, S_q, S_k]
attn_weights = attn_weights * self.scale
attn_probs = F.softmax(attn_weights, dim=-1) # softmax over S_k (key length)
# Apply attention weights to values
attn_output = torch.matmul(attn_probs, V) # [B, heads, S_q, head_dim]
# Reshape back to [B, S_q, dim]
attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
# Output projection
output = self.out_proj(attn_output) # [B, S_q, dim]
return output
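# --- Optional sanity check (illustrative sketch, not required for export) ---
# Compares the custom attention above against PyTorch's fused
# F.scaled_dot_product_attention (torch >= 2.0) on the same projected Q/K/V.
# The helper name, shapes, and tolerance below are arbitrary choices for a quick smoke test.
def _check_attention_equivalence(dim=64, num_heads=8, batch=2, seq=10, atol=1e-4):
    torch.manual_seed(0)
    attn = MultiheadAttentionNoFlash(dim, num_heads=num_heads).eval()
    x = torch.randn(batch, seq, dim)
    with torch.no_grad():
        out_custom = attn(x)
        # Reference path: same projections, fused attention kernel (default scale = head_dim ** -0.5)
        q = attn.q_proj(x).view(batch, seq, num_heads, attn.head_dim).transpose(1, 2)
        k = attn.k_proj(x).view(batch, seq, num_heads, attn.head_dim).transpose(1, 2)
        v = attn.v_proj(x).view(batch, seq, num_heads, attn.head_dim).transpose(1, 2)
        ref = F.scaled_dot_product_attention(q, k, v)
        ref = attn.out_proj(ref.transpose(1, 2).reshape(batch, seq, dim))
    assert torch.allclose(out_custom, ref, atol=atol), "custom attention diverges from fused reference"
# _check_attention_equivalence()  # uncomment to run the check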
class ImageTaggerRefinedONNX(nn.Module):
"""
Refined CAMIE Image Tagger model without FlashAttention.
- EfficientNetV2 backbone
- Initial classifier for preliminary tag logits
    - Multi-head self-attention over the embeddings of the top-k predicted tags
    - Multi-head cross-attention between the image feature (query) and the tag embeddings (key/value)
- Refined classifier for final tag logits
"""
def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
super().__init__()
self.tag_context_size = tag_context_size
self.embedding_dim = 1280 # EfficientNetV2-L feature dimension
# Backbone feature extractor (EfficientNetV2-L)
backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
backbone.classifier = nn.Identity() # remove final classification head
self.backbone = backbone
# Spatial pooling to get a single feature vector per image (1x1 avg pool)
self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
# Initial classifier (two-layer MLP) to predict tags from image feature
self.initial_classifier = nn.Sequential(
nn.Linear(self.embedding_dim, self.embedding_dim * 2),
nn.LayerNorm(self.embedding_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(self.embedding_dim * 2, self.embedding_dim),
nn.LayerNorm(self.embedding_dim),
nn.GELU(),
nn.Linear(self.embedding_dim, total_tags) # outputs raw logits for all tags
)
# Embedding for tags (each tag gets an embedding vector, used for attention)
self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
# Self-attention over the selected tag embeddings (replaces FlashAttention)
self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
self.tag_norm = nn.LayerNorm(self.embedding_dim)
# Projection from image feature to query vector for cross-attention
self.cross_proj = nn.Sequential(
nn.Linear(self.embedding_dim, self.embedding_dim * 2),
nn.LayerNorm(self.embedding_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(self.embedding_dim * 2, self.embedding_dim)
)
# Cross-attention between image feature (as query) and tag features (as key/value)
self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
self.cross_norm = nn.LayerNorm(self.embedding_dim)
# Refined classifier (takes concatenated original & attended features)
self.refined_classifier = nn.Sequential(
nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
nn.LayerNorm(self.embedding_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(self.embedding_dim * 2, self.embedding_dim),
nn.LayerNorm(self.embedding_dim),
nn.GELU(),
nn.Linear(self.embedding_dim, total_tags)
)
# Temperature parameter for scaling logits (to calibrate confidence)
self.temperature = nn.Parameter(torch.ones(1) * 1.5)
def forward(self, images):
# 1. Feature extraction
feats = self.backbone.features(images) # [B, 1280, H/32, W/32] features
feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1) # [B, 1280] global feature vector per image
# 2. Initial tag prediction
initial_logits = self.initial_classifier(feats) # [B, total_tags]
        # Scale by temperature and clamp to stabilize extreme logits (as in the original model)
initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
# 3. Select top-k predicted tags for context (tag_context_size)
probs = torch.sigmoid(initial_preds) # convert logits to probabilities
# Get indices of top `tag_context_size` tags for each sample
_, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
# 4. Embed selected tags
tag_embeds = self.tag_embedding(topk_indices) # [B, tag_context_size, embedding_dim]
# 5. Self-attention on tag embeddings (to refine tag representation)
attn_tags = self.tag_attention(tag_embeds) # [B, tag_context_size, embedding_dim]
attn_tags = self.tag_norm(attn_tags) # layer norm
# 6. Cross-attention between image feature and attended tags
# Expand image features to have one per tag position
feat_q = self.cross_proj(feats) # [B, embedding_dim]
# Repeat each image feature vector tag_context_size times to form a sequence
feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1) # [B, tag_context_size, embedding_dim]
# Use image features as queries, tag embeddings as keys and values
cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags) # [B, tag_context_size, embedding_dim]
cross_attn = self.cross_norm(cross_attn)
# 7. Fuse features: average the cross-attended tag outputs, and combine with original features
fused_feature = cross_attn.mean(dim=1) # [B, embedding_dim]
combined = torch.cat([feats, fused_feature], dim=1) # [B, embedding_dim*2]
# 8. Refined tag prediction
refined_logits = self.refined_classifier(combined) # [B, total_tags]
refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
return initial_preds, refined_preds
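# --- Optional usage sketch (illustrative; not executed here) ---
# A minimal sketch of how the model is called, assuming a 512x512 RGB input:
#   tagger = ImageTaggerRefinedONNX(total_tags=70527).eval()
#   with torch.no_grad():
#       initial_preds, refined_preds = tagger(torch.randn(1, 3, 512, 512))
#   # both outputs have shape [1, 70527]; apply torch.sigmoid() to obtain per-tag probabilities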
# --- Load the pretrained refined model weights ---
total_tags = 70527 # total number of tags in the dataset (Danbooru 2024)
from safetensors.torch import load_file
safetensors_path = 'model_refined.safetensors'
state_dict = load_file(safetensors_path, device='cpu') # Load the saved weights as a flat dict of tensors (state dict)
#state_dict = torch.load("model_refined.pt", map_location="cpu") # Load the saved weights (should be an OrderedDict)
# Initialize our model and load weights
model = ImageTaggerRefinedONNX(total_tags=total_tags)
model.load_state_dict(state_dict)
model.eval() # set to evaluation mode (disable dropout)
# (Optional) Cast to float32 if weights were in half precision
# model = model.float()
# --- Export to ONNX ---
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False) # dummy batch of 1 image (3x512x512)
output_onnx_file = "camie_refined_no_flash_v15.onnx"
torch.onnx.export(
model, dummy_input, output_onnx_file,
export_params=True, # store trained parameter weights inside the model file
opset_version=17, # ONNX opset version (ensure support for needed ops)
do_constant_folding=True, # optimize constant expressions
input_names=["image"],
output_names=["initial_tags", "refined_tags"],
dynamic_axes={ # set batch dimension to be dynamic
"image": {0: "batch"},
"initial_tags": {0: "batch"},
"refined_tags": {0: "batch"}
}
)
print(f"ONNX model exported to {output_onnx_file}")