import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from safetensors.torch import load_file


class MultiheadAttentionNoFlash(nn.Module):
    """Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""

    def __init__(self, dim, num_heads=8, dropout=0.0):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Dropout on the attention probabilities; a no-op in eval mode,
        # so it does not affect the exported graph.
        self.attn_drop = nn.Dropout(dropout)

        # Separate Q/K/V projections keep the traced graph simple.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, query, key=None, value=None):
        # Default to self-attention when key/value are omitted.
        if key is None:
            key = query
        if value is None:
            value = key

        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        # Split into heads: (batch, seq, dim) -> (batch, heads, seq, head_dim).
        B, S_q, _ = Q.shape
        _, S_k, _ = K.shape
        Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention via explicit matmul/softmax, which
        # exports to standard ONNX ops with no custom kernels.
        attn_weights = torch.matmul(Q, K.transpose(2, 3)) * self.scale
        attn_probs = self.attn_drop(F.softmax(attn_weights, dim=-1))
        attn_output = torch.matmul(attn_probs, V)

        # Merge heads back to (batch, seq, dim) and apply the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
        return self.out_proj(attn_output)
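
# A minimal shape check for the attention block (illustrative, not part of
# the original script); with the tagger's settings below, the block maps
# (batch, seq, dim) back to (batch, seq, dim):
#   attn = MultiheadAttentionNoFlash(dim=1280, num_heads=16)
#   attn(torch.randn(2, 256, 1280)).shape  # torch.Size([2, 256, 1280])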


class ImageTaggerRefinedONNX(nn.Module):
    """
    Refined CAMIE Image Tagger model without FlashAttention.

    - EfficientNetV2 backbone
    - Initial classifier for preliminary tag logits
    - Multi-head self-attention over the top predicted tag embeddings
    - Multi-head cross-attention between the image feature and tag embeddings
    - Refined classifier for final tag logits
    """

    def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
        super().__init__()
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # feature width of EfficientNetV2-L

        # Backbone with its classification head removed.
        backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Global average pooling over the spatial feature map.
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial classifier: pooled image feature -> logits over all tags.
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # One learned embedding per tag; only the top-k predicted tags are
        # embedded during the refinement pass.
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)

        # Self-attention over the selected tag embeddings.
        self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Projects the image feature into the cross-attention query space.
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )

        # Cross-attention: image-feature queries attend to the tag context.
        self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier over the concatenated image and fused features.
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Learned temperature for scaling logits before clamping.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, images):
        # Backbone features, pooled to a (batch, embedding_dim) vector.
        feats = self.backbone.features(images)
        feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)

        # Preliminary logits, temperature-scaled and clamped for stability.
        initial_logits = self.initial_classifier(feats)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # Pick the top-k most probable tags and look up their embeddings.
        probs = torch.sigmoid(initial_preds)
        _, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
        tag_embeds = self.tag_embedding(topk_indices)

        # Refine the tag context with self-attention.
        attn_tags = self.tag_attention(tag_embeds)
        attn_tags = self.tag_norm(attn_tags)

        # Cross-attention: the projected image feature, broadcast across the
        # tag context, queries the refined tag embeddings.
        feat_q = self.cross_proj(feats)
        feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1)
        cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags)
        cross_attn = self.cross_norm(cross_attn)

        # Pool the attended context, fuse with the image feature, and classify.
        fused_feature = cross_attn.mean(dim=1)
        combined = torch.cat([feats, fused_feature], dim=1)

        refined_logits = self.refined_classifier(combined)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds
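
# Note (not in the original script): both outputs are clamped, temperature-
# scaled logits; downstream consumers typically apply torch.sigmoid(...) and
# threshold per tag for multi-label predictions.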


# Build the model and load the trained weights from the safetensors checkpoint.
total_tags = 70527
safetensors_path = 'model_refined.safetensors'
state_dict = load_file(safetensors_path, device='cpu')

model = ImageTaggerRefinedONNX(total_tags=total_tags)
model.load_state_dict(state_dict)
# eval() makes dropout an identity op, so the exported graph is deterministic.
model.eval()
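
# Optional eager-mode smoke test before export (illustrative sketch):
#   with torch.no_grad():
#       init_out, ref_out = model(torch.randn(1, 3, 512, 512))
#   init_out.shape, ref_out.shape  # both (1, 70527)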

# Export with a fixed 512x512 input resolution and a dynamic batch axis.
dummy_input = torch.randn(1, 3, 512, 512)
output_onnx_file = "camie_refined_no_flash_v15.onnx"
torch.onnx.export(
    model, dummy_input, output_onnx_file,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["image"],
    output_names=["initial_tags", "refined_tags"],
    dynamic_axes={
        "image": {0: "batch"},
        "initial_tags": {0: "batch"},
        "refined_tags": {0: "batch"}
    }
)
print(f"ONNX model exported to {output_onnx_file}")
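
# Optional parity check, assuming the `onnxruntime` and `numpy` packages are
# installed; a sketch, not part of the original export script. It runs the
# exported graph on the dummy input and compares against PyTorch's outputs.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(output_onnx_file, providers=["CPUExecutionProvider"])
with torch.no_grad():
    torch_initial, torch_refined = model(dummy_input)
ort_initial, ort_refined = sess.run(
    ["initial_tags", "refined_tags"], {"image": dummy_input.numpy()}
)
np.testing.assert_allclose(torch_initial.numpy(), ort_initial, rtol=1e-3, atol=1e-4)
np.testing.assert_allclose(torch_refined.numpy(), ort_refined, rtol=1e-3, atol=1e-4)
print("ONNX outputs match PyTorch within tolerance")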