import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights

class MultiheadAttentionNoFlash(nn.Module):
    """Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""
    def __init__(self, dim, num_heads=8, dropout=0.0):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # scaling factor for dot-product attention

        # Define separate projections for query, key, value, and output (no biases to match FlashAttention)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        # (Note: We omit dropout in attention computation for ONNX simplicity; model should be set to eval mode anyway.)

    def forward(self, query, key=None, value=None):
        # Allow usage as self-attention if key/value not provided
        if key is None:
            key = query
        if value is None:
            value = key

        # Linear projections
        Q = self.q_proj(query)   # [B, S_q, dim]
        K = self.k_proj(key)     # [B, S_k, dim]
        V = self.v_proj(value)   # [B, S_v, dim]

        # Reshape into (B, num_heads, S, head_dim) for computing attention per head
        B, S_q, _ = Q.shape
        _, S_k, _ = K.shape
        Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_q, head_dim]
        K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
        V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]

        # Scaled dot-product attention: compute attention weights
        attn_weights = torch.matmul(Q, K.transpose(2, 3))  # [B, heads, S_q, S_k]
        attn_weights = attn_weights * self.scale
        attn_probs = F.softmax(attn_weights, dim=-1)       # softmax over S_k (key length)

        # Apply attention weights to values
        attn_output = torch.matmul(attn_probs, V)  # [B, heads, S_q, head_dim]

        # Reshape back to [B, S_q, dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
        # Output projection
        output = self.out_proj(attn_output)  # [B, S_q, dim]
        return output
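
# --- Optional sanity check (a minimal sketch, not part of the original script): the custom
# attention block should preserve the [B, S, dim] shape for both self- and cross-attention. ---
_attn_check = MultiheadAttentionNoFlash(dim=64, num_heads=8)
_q, _kv = torch.randn(2, 5, 64), torch.randn(2, 7, 64)
assert _attn_check(_q).shape == (2, 5, 64)            # self-attention: key/value default to query
assert _attn_check(_q, _kv, _kv).shape == (2, 5, 64)  # cross-attention: query length sets output length
del _attn_check, _q, _kv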

class ImageTaggerRefinedONNX(nn.Module):
    """

    Refined CAMIE Image Tagger model without FlashAttention.

    - EfficientNetV2 backbone

    - Initial classifier for preliminary tag logits

    - Multi-head self-attention on top predicted tag embeddings

    - Multi-head cross-attention between image feature and tag embeddings

    - Refined classifier for final tag logits

    """
    def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
        super().__init__()
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension

        # Backbone feature extractor (EfficientNetV2-L)
        backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
        backbone.classifier = nn.Identity()  # remove final classification head
        self.backbone = backbone

        # Spatial pooling to get a single feature vector per image (1x1 avg pool)
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial classifier (two-layer MLP) to predict tags from image feature
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)  # outputs raw logits for all tags
        )

        # Embedding for tags (each tag gets an embedding vector, used for attention)
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)

        # Self-attention over the selected tag embeddings (replaces FlashAttention)
        self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Projection from image feature to query vector for cross-attention
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )
        # Cross-attention between image feature (as query) and tag features (as key/value)
        self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier (takes concatenated original & attended features)
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature parameter for scaling logits (to calibrate confidence)
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, images):
        # 1. Feature extraction
        feats = self.backbone.features(images)               # [B, 1280, H/32, W/32] features
        feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)  # [B, 1280] global feature vector per image

        # 2. Initial tag prediction
        initial_logits = self.initial_classifier(feats)      # [B, total_tags]
        # Scale by temperature and clamp (to stabilize extreme values, as in original)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # 3. Select top-k predicted tags for context (tag_context_size)
        probs = torch.sigmoid(initial_preds)                 # convert logits to probabilities
        # Get indices of top `tag_context_size` tags for each sample
        _, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
        # 4. Embed selected tags
        tag_embeds = self.tag_embedding(topk_indices)        # [B, tag_context_size, embedding_dim]

        # 5. Self-attention on tag embeddings (to refine tag representation)
        attn_tags = self.tag_attention(tag_embeds)           # [B, tag_context_size, embedding_dim]
        attn_tags = self.tag_norm(attn_tags)                 # layer norm

        # 6. Cross-attention between image feature and attended tags
        # Expand image features to have one per tag position
        feat_q = self.cross_proj(feats)                      # [B, embedding_dim]
        # Repeat each image feature vector tag_context_size times to form a sequence
        feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1)  # [B, tag_context_size, embedding_dim]
        # Use image features as queries, tag embeddings as keys and values
        cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags)  # [B, tag_context_size, embedding_dim]
        cross_attn = self.cross_norm(cross_attn)

        # 7. Fuse features: average the cross-attended tag outputs, and combine with original features
        fused_feature = cross_attn.mean(dim=1)               # [B, embedding_dim]
        combined = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]

        # 8. Refined tag prediction
        refined_logits = self.refined_classifier(combined)   # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds

# --- Load the pretrained refined model weights ---
total_tags = 70527  # total number of tags in the dataset (Danbooru 2024)
from safetensors.torch import load_file
safetensors_path = 'model_refined.safetensors'
state_dict = load_file(safetensors_path, device='cpu')  # Load the saved weights (should be an OrderedDict)
#state_dict = torch.load("model_refined.pt", map_location="cpu")  # Load the saved weights (should be an OrderedDict)

# Initialize our model and load weights
model = ImageTaggerRefinedONNX(total_tags=total_tags)
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode (disable dropout)

# (Optional) Cast to float32 if weights were in half precision
# model = model.float()

# --- Export to ONNX ---
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False)  # dummy batch of 1 image (3x512x512)
output_onnx_file = "camie_refined_no_flash_v15.onnx"
torch.onnx.export(
    model, dummy_input, output_onnx_file,
    export_params=True,        # store trained parameter weights inside the model file
    opset_version=17,          # ONNX opset version (ensure support for needed ops)
    do_constant_folding=True,  # optimize constant expressions
    input_names=["image"], 
    output_names=["initial_tags", "refined_tags"],
    dynamic_axes={             # set batch dimension to be dynamic
        "image": {0: "batch"},
        "initial_tags": {0: "batch"},
        "refined_tags": {0: "batch"}
    }
)
print(f"ONNX model exported to {output_onnx_file}")