---
license: gpl-3.0
datasets:
- p1atdev/danbooru-2024
metrics:
- f1
tags:
- art
- code
---
# Anime Image Tagger
An advanced deep learning model for automatically tagging anime/manga illustrations with relevant tags across multiple categories, achieving **61% F1 score** across 70,000+ possible tags on a test set of 20,116 samples.
## Key Highlights
- **Efficient Training**: Trained entirely on a single RTX 3060 GPU (12GB VRAM)
- **Fast Convergence**: Trained on 7,024,392 samples (3.52 epochs) in 1,756,098 batches
- **Comprehensive Coverage**: 70,000+ tags across 7 categories (general, character, copyright, artist, meta, rating, year)
- **Innovative Architecture**: Two-stage prediction model with cross-attention for tag context
- **User-Friendly Interface**: Easy-to-use application with customizable thresholds
*This project demonstrates that high-quality anime image tagging models can be trained on consumer hardware with the right optimization techniques.*
## Features
- **Multi-category tagging system**: Handles general tags, characters, copyright (series), artists, meta information, and content ratings
- **High performance**: 61% F1 score across 70,000+ possible tags
- **Dual-mode operation**: Full model for best quality or Initial-only mode for reduced VRAM usage
- **Windows compatibility**: Initial-only mode works on Windows without Flash Attention
- **Streamlit web interface**: User-friendly UI for uploading and analyzing images
- **Adjustable threshold profiles**: Overall, Weighted, Category-specific, High Precision, and High Recall profiles
- **Fine-grained control**: Per-category threshold adjustments for precision-recall tradeoffs
## Loss Function
The model employs a specialized `UnifiedFocalLoss` to address the extreme class imbalance inherent in multi-label tag prediction:
```python
class UnifiedFocalLoss(nn.Module):
def __init__(self, device=None, gamma=2.0, alpha=0.25, lambda_initial=0.4):
# Implementation details...
```
### Key Components
1. **Focal Loss Mechanism**:
- Down-weights well-classified examples (γ=2.0) to focus training on difficult tags
- Addresses the extreme imbalance between positive and negative examples (often 100:1 or worse)
- Uses α=0.25 to balance positive/negative examples across 70,000+ possible tags
2. **Two-stage Weighting**:
- Combines losses from both prediction stages (`initial_predictions` and `refined_predictions`)
- Uses λ=0.4 to weight the initial prediction loss, giving more importance (0.6) to refined predictions
- This encourages the model to improve predictions in the refinement stage while still maintaining strong initial predictions
3. **Per-sample Statistics**:
- Tracks separate metrics for positive and negative samples
- Provides detailed debugging information about prediction distributions
- Enables analysis of which tag categories are performing well/poorly
This loss function was essential for achieving high F1 scores across diverse tag categories despite the extreme classification challenge of 70,000+ possible tags.
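For reference, the sketch below shows how these components could fit together. It is an illustrative simplification, not the repository's exact implementation; in particular, the per-sample statistics tracking described above is omitted:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnifiedFocalLossSketch(nn.Module):
    """Illustrative sketch of a two-stage focal loss (not the exact
    implementation used in this repository)."""
    def __init__(self, gamma=2.0, alpha=0.25, lambda_initial=0.4):
        super().__init__()
        self.gamma = gamma                    # down-weights easy examples
        self.alpha = alpha                    # balances positives vs. negatives
        self.lambda_initial = lambda_initial  # weight on the initial-stage loss

    def _focal(self, logits, targets):
        # Standard binary focal loss computed element-wise over all tags
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = targets * p + (1 - targets) * (1 - p)  # probability of the true class
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

    def forward(self, initial_logits, refined_logits, targets):
        loss_initial = self._focal(initial_logits, targets)
        loss_refined = self._focal(refined_logits, targets)
        # lambda=0.4 on the initial stage leaves 0.6 for the refined stage
        return self.lambda_initial * loss_initial + (1 - self.lambda_initial) * loss_refined
```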
## DeepSpeed Configuration
Microsoft DeepSpeed was crucial for training this model on consumer hardware. The project uses a carefully tuned configuration to maximize efficiency:
```python
def create_deepspeed_config(
config_path,
learning_rate=3e-4,
weight_decay=0.01,
num_train_samples=None,
micro_batch_size=4,
grad_accum_steps=8
):
# Implementation details...
```
### Key Optimizations
1. **Memory Efficiency**:
- **ZeRO Stage 2**: Partitions optimizer states and gradients, dramatically reducing memory requirements
- **Activation Checkpointing**: Trades computation for memory by recomputing activations during backpropagation
- **Contiguous Memory Optimization**: Reduces memory fragmentation
2. **Mixed Precision Training**:
- **FP16 Mode**: Uses half-precision (16-bit) for most calculations, with automatic loss scaling
- **Initial Scale Power**: Set to 16 for stable convergence with large batch sizes
3. **Gradient Accumulation**:
- Micro-batch size of 4 with 8 gradient accumulation steps
- Effective batch size of 32 while only requiring memory for 4 samples at once
4. **Learning Rate Schedule**:
- WarmupLR scheduler with gradual increase from 3e-6 to 3e-4
- Warmup over 1/4 of an epoch to stabilize early training
This configuration allowed the model to train efficiently with only 12GB of VRAM while maintaining numerical stability across millions of training examples with 70,000+ output dimensions.
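To make the settings above concrete, here is a sketch of the kind of DeepSpeed JSON such a helper might emit, populated with the values quoted in this section. The field names follow DeepSpeed's documented configuration schema, but the project's actual generated file may differ; `num_warmup_steps` is a placeholder the caller would compute from the dataset size:
```python
import json

def create_deepspeed_config_sketch(config_path, num_warmup_steps):
    """Hypothetical sketch of the generated DeepSpeed config."""
    config = {
        "train_micro_batch_size_per_gpu": 4,
        "gradient_accumulation_steps": 8,       # effective batch size of 32
        "zero_optimization": {
            "stage": 2,                         # partition optimizer states and gradients
            "contiguous_gradients": True,       # reduce memory fragmentation
        },
        "fp16": {
            "enabled": True,                    # half-precision with automatic loss scaling
            "initial_scale_power": 16,          # starting loss scale of 2**16
        },
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": 3e-4, "weight_decay": 0.01},
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 3e-6,
                "warmup_max_lr": 3e-4,
                "warmup_num_steps": num_warmup_steps,  # ~1/4 of an epoch
            },
        },
        "activation_checkpointing": {
            "partition_activations": True,
            "contiguous_memory_optimization": True,
        },
    }
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```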
## Dataset
The model was trained on a carefully filtered subset of the [Danbooru 2024 dataset](https://huggingface.co/datasets/p1atdev/danbooru-2024), which contains a vast collection of anime/manga illustrations with comprehensive tagging.
### Filtering Process
The dataset was filtered with the following constraints:
```python
# Minimum tags per category required for each image
min_tag_counts = {
'general': 25,
'character': 1,
'copyright': 1,
'artist': 0,
'meta': 0
}
# Minimum samples per tag required for tag to be included
min_tag_samples = {
'general': 20,
'character': 40,
'copyright': 50,
'artist': 200,
'meta': 50
}
```
This filtering process:
1. First removed low-sample tags (tags with fewer occurrences than specified in `min_tag_samples`)
2. Then removed images with insufficient tags per category (as specified in `min_tag_counts`)
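A minimal sketch of this two-pass filter, assuming each image is represented as a dict mapping category names to tag lists (the project's actual data structures may differ):
```python
from collections import Counter

def filter_dataset_sketch(images, min_tag_samples, min_tag_counts):
    """Hypothetical sketch of the two-pass filtering described above."""
    # Pass 1: count occurrences per category and keep only sufficiently common tags
    counts = {cat: Counter() for cat in min_tag_samples}
    for img in images:
        for cat, tags in img.items():
            if cat in counts:
                counts[cat].update(tags)
    keep = {cat: {t for t, n in counts[cat].items() if n >= min_tag_samples[cat]}
            for cat in min_tag_samples}

    # Pass 2: strip dropped tags, then discard images with too few tags per category
    filtered = []
    for img in images:
        pruned = {cat: [t for t in tags if cat not in keep or t in keep[cat]]
                  for cat, tags in img.items()}
        if all(len(pruned.get(cat, [])) >= n for cat, n in min_tag_counts.items()):
            filtered.append(pruned)
    return filtered
```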
### Training Data
- **Starting dataset size**: ~3,000,000 filtered images
- **Training subset**: 2,000,000 images (due to storage and time constraints)
- **Training duration**: 3.5 epochs
The model could potentially achieve even higher accuracy with more training epochs and the full dataset.
### Preprocessing
Images were preprocessed with minimal transformations:
- Tensor normalization (scaled to 0-1 range)
- Resized while maintaining original aspect ratio
- No additional augmentations were applied
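A sketch of what this minimal pipeline might look like; the target resolution here is a placeholder, as the exact value is not stated in this section:
```python
from PIL import Image
import torchvision.transforms.functional as TF

def preprocess_sketch(path, target_size=512):
    """Hypothetical sketch: resize preserving aspect ratio, then scale to 0-1."""
    img = Image.open(path).convert("RGB")
    w, h = img.size
    scale = target_size / max(w, h)            # longer side -> target_size
    img = img.resize((round(w * scale), round(h * scale)), Image.LANCZOS)
    return TF.to_tensor(img)                   # float tensor in the 0-1 range
```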
## Model Architecture
The model uses a novel two-stage prediction approach that achieves superior performance compared to traditional single-stage models:
### Image Feature Extraction
- **Backbone**: EfficientNet V2-L extracts high-quality visual features from input images
- **Spatial Pooling**: Adaptive averaging converts spatial features to a compact 1280-dimensional embedding
### Initial Prediction Stage
- Direct classification from image features through a multi-layer classifier
- Bottleneck architecture with LayerNorm and GELU activations between linear layers
- Outputs initial tag probabilities across all 70,000+ possible tags
### Tag Context Mechanism
- Top predicted tags are embedded using a shared embedding space
- Self-attention layer allows tags to influence each other based on co-occurrence patterns
- Normalized tag embeddings represent a coherent "tag context" for the image
### Cross-Attention Refinement
- Image features and tag embeddings interact through cross-attention
- Each dimension of the image features attends to relevant dimensions in the tag space
- This creates a bidirectional flow of information between visual features and semantic tags
### Refined Predictions
- Fused features (original + cross-attended) feed into a final classifier
- Residual connection ensures initial predictions are preserved when beneficial
- Temperature scaling provides calibrated probability outputs
This dual-stage approach allows the model to leverage tag co-occurrence patterns and semantic relationships, improving accuracy without increasing the parameter count significantly.
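The sketch below compresses this data flow into a single module for illustration. Hidden sizes, attention head counts, and the top-k value are placeholders rather than the project's actual hyperparameters, and temperature scaling is omitted for brevity:
```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_l

class TwoStageTaggerSketch(nn.Module):
    """Illustrative sketch of the two-stage architecture (not the exact model)."""
    def __init__(self, num_tags=70000, embed_dim=1280, top_k=50):
        super().__init__()
        self.features = efficientnet_v2_l(weights=None).features  # visual backbone
        self.pool = nn.AdaptiveAvgPool2d(1)                        # spatial pooling -> 1280-d
        self.initial_head = nn.Sequential(                         # bottleneck classifier
            nn.Linear(embed_dim, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(),
            nn.Linear(embed_dim, num_tags),
        )
        self.tag_embed = nn.Embedding(num_tags, embed_dim)         # shared tag embedding
        self.tag_self_attn = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.refined_head = nn.Linear(embed_dim * 2, num_tags)
        self.top_k = top_k

    def forward(self, x):
        feats = self.pool(self.features(x)).flatten(1)             # (B, 1280)
        initial_logits = self.initial_head(feats)

        # Tag context: embed the top-k initial predictions, let them interact
        top_idx = initial_logits.topk(self.top_k, dim=-1).indices
        tags = self.tag_embed(top_idx)                             # (B, k, 1280)
        tags, _ = self.tag_self_attn(tags, tags, tags)             # co-occurrence context

        # Cross-attention: image features attend to the tag context
        q = feats.unsqueeze(1)                                     # (B, 1, 1280)
        attended, _ = self.cross_attn(q, tags, tags)
        fused = torch.cat([feats, attended.squeeze(1)], dim=-1)    # original + attended

        # Residual connection preserves the initial predictions when beneficial
        refined_logits = self.refined_head(fused) + initial_logits
        return initial_logits, refined_logits
```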
## Installation
Simply run the included setup script to install all dependencies:
```
setup.bat
```
This will automatically set up all necessary packages for the application.
### Requirements
- **Python 3.11.9 specifically** (newer versions are incompatible)
- PyTorch 1.10+
- Streamlit
- PIL/Pillow
- NumPy
- Flash Attention (note: doesn't work properly on Windows)
### Running the Application
The application is located in the `app` folder and can be launched via the setup script:
1. Run `setup.bat` to install dependencies
2. The Streamlit interface will automatically open in your browser
3. If the browser doesn't open automatically, navigate to http://localhost:8501
## Usage
Once the application is running, the web interface lets you:
- Upload your own images or select from example images
- Choose different threshold profiles
- Adjust category-specific thresholds
- View predictions organized by category
- Filter and sort tags based on confidence


## Model Details
### Tag Categories
The model recognizes tags across these categories:
- **General**: Visual elements, concepts, clothing, etc.
- **Character**: Individual characters appearing in the image
- **Copyright**: Source material (anime, manga, game)
- **Artist**: Creator of the artwork
- **Meta**: Meta information about the image
- **Rating**: Content rating
- **Year**: Year of upload
### Performance Notes
The full model with refined predictions outperforms the initial-only mode, though the performance gap is surprisingly small. This is an interesting architectural finding: the refinement stage improves predictions without adding substantial computational overhead.
This efficiency makes the initial-only model particularly valuable for Windows users or systems with limited VRAM, as they can still achieve near-optimal performance without requiring Flash Attention.
In benchmarks, the model achieved a 61% F1 score across all categories, which is remarkable considering the extreme multi-label classification challenge of 70,000+ possible tags. The model performs particularly well on general tags and character recognition.
### Threshold Profiles
- **Overall**: Single threshold applied to all categories
- **Weighted**: Threshold optimized for balanced performance across categories
- **Category-specific**: Different thresholds for each category
- **High Precision**: Higher thresholds for more confident predictions
- **High Recall**: Lower thresholds to capture more potential tags
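Whatever the profile, it ultimately reduces to a mapping from category to threshold, applied as in this minimal sketch (the threshold values shown are illustrative only):
```python
def apply_thresholds_sketch(probs_by_category, profile):
    """Keep tags whose probability meets the threshold for their category."""
    return {
        category: [(tag, p) for tag, p in tag_probs if p >= profile[category]]
        for category, tag_probs in probs_by_category.items()
    }

# Example of a hypothetical category-specific profile
profile = {"general": 0.35, "character": 0.50, "copyright": 0.50,
           "artist": 0.60, "meta": 0.40, "rating": 0.50, "year": 0.50}
```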
## Windows Compatibility
The full model uses Flash Attention, which does not work properly on Windows. For Windows users:
- The application automatically defaults to the Initial-only model
- Performance difference is minimal (0.2% absolute F1 score reduction, from 61.6% to 61.4%)
- The Initial-only model still uses the same powerful EfficientNet backbone and initial classifier
## Web Interface Guide
The interface is divided into three main sections:
1. **Model Selection** (Sidebar)
- Choose between Full Model or Initial-only Model
- View model information and memory usage
2. **Image Upload** (Left Panel)
- Upload your own images or select from examples
- View the selected image
3. **Tagging Controls** (Right Panel)
- Select threshold profile
- Adjust thresholds for precision-recall tradeoff
- Configure display options
- View predictions organized by category
### Display Options
- **Show all tags**: Display all tags including those below threshold
- **Compact view**: Hide progress bars for cleaner display
- **Minimum confidence**: Filter out low-confidence predictions
- **Category selection**: Choose which categories to include in the summary
### Interface Screenshots


## Training Environment
The model was trained using surprisingly modest hardware:
- **GPU**: Single NVIDIA RTX 3060 (12GB VRAM)
- **RAM**: 64GB system memory
- **Platform**: Windows with WSL (Windows Subsystem for Linux)
- **Libraries**:
- Microsoft DeepSpeed for memory-efficient training
- PyTorch with CUDA acceleration
- Flash Attention for optimized attention computation
### Training Notebooks
The repository includes two main training notebooks:
1. **CAMIE Tagger.ipynb**
- Main training notebook
- Dataset loading and preprocessing
- Model initialization
- Initial training loop with DeepSpeed integration
- Tag selection optimization
- Metric tracking and visualization
2. **Camie Tagger Cont and Evals.ipynb**
- Continuation of training from checkpoints
- Comprehensive model evaluation
- Per-category performance metrics
- Threshold optimization
- Model conversion for deployment in the app
- Export functionality for the standalone application
### Training Monitor
The project includes a real-time training monitor accessible via browser at `localhost:5000` during training.

#### Performance Tips
⚠️ **Important**: For optimal training speed, keep VSCode minimized and the training monitor open in your browser. This can improve iteration speed by **3-5x** due to how the Windows/WSL graphics stack handles window focus and CUDA kernel execution.
#### Monitor Features
The training monitor provides three main views:
##### 1. Overview Tab

- **Training Progress**: Real-time metrics including epoch, batch, speed, and time estimates
- **Loss Chart**: Training and validation loss visualization
- **F1 Scores**: Initial and refined F1 metrics for both training and validation
##### 2. Predictions Tab

- **Image Preview**: Shows the current sample being analyzed
- **Prediction Controls**: Toggle between initial and refined predictions
- **Tag Analysis**:
- Color-coded tag results (correct, incorrect, missing)
- Confidence visualization with probability bars
- Category-based organization
- Filtering options for error analysis
##### 3. Selection Analysis Tab

- **Selection Metrics**: Statistics on tag selection quality
- Ground truth recall
- Average probability for ground truth vs. non-ground truth tags
- Unique tags selected
- **Selection Graph**: Trends in selection quality over time
- **Selected Tags Details**: Detailed view of model-selected tags with confidence scores
The monitor provides invaluable insights into how the two-stage prediction model is performing, particularly how the tag selection process is working between the initial and refined prediction stages.
### Training Notes
- Training notebooks require WSL and likely 32GB+ of RAM to handle the dataset
- Microsoft DeepSpeed was crucial for fitting the model and batches into the available VRAM
- Despite hardware limitations, the model achieves impressive results
- With more computational resources, the model could be trained longer on the full dataset
## Support
I plan to move on to LLMs after this project, as I have many ideas for improving them. I will update this model depending on community interest.
If you'd like to support further training on the complete dataset or my future projects, consider [buying me a coffee](https://www.buymeacoffee.com/yourusername).
## Acknowledgments
- [Danbooru](https://danbooru.donmai.us/) for the incredible dataset of tagged anime images
- [p1atdev](https://huggingface.co/p1atdev) for the processed Danbooru 2024 dataset
- Microsoft for DeepSpeed, which made training possible on consumer hardware
- PyTorch and the open-source ML community