#!/usr/bin/env python3 """ Enhanced Essence Generator for Tag Collector Game """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision.transforms.functional import to_pil_image from PIL import Image import numpy as np import os import re import math import json import streamlit as st from tqdm import tqdm from scipy.ndimage import gaussian_filter from functools import wraps import time import tag_storage # Import for saving game state from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON from tag_categories import TAG_CATEGORIES # Define essence quality levels with thresholds and styles ESSENCE_QUALITY_LEVELS = { "ZAYIN": {"threshold": 0.0, "color": "#1CFC00", "description": "Basic representation with minimal details."}, "TETH": {"threshold": 3.0, "color": "#389DDF", "description": "Clear representation with recognizable features."}, "HE": {"threshold": 5.0, "color": "#FEF900", "description": "Refined representation with distinctive elements."}, "WAW": {"threshold": 10.0, "color": "#7930F1", "description": "Advanced representation with precise details."}, "ALEPH": {"threshold": 12.0, "color": "#FF0000", "description": "Perfect representation with extraordinary precision."} } # Essence generation costs in enkephalin based on tag rarity ESSENCE_COSTS = { "Special": 0, "Canard": 100, # Common tags "Urban Myth": 125, # Uncommon tags "Urban Legend": 150, # Rare tags "Urban Plague": 200, # Very rare tags "Urban Nightmare": 250, # Extremely rare tags "Star of the City": 300, # Nearly mythical tags "Impuritas Civitas": 400 # Legendary tags } # Default essence generation settings DEFAULT_ESSENCE_SETTINGS = { "scales": 1, # Number of scales for multiscale optimization "iterations": 256, # Iterations per scale "image_size": 512, # Always use 512x512 resolution "lr": 0.1, # Learning rate "layer_emphasis": "auto" # Default to auto-detection } def initialize_essence_settings(): """Initialize essence generator settings if not already present""" if 'essence_custom_settings' not in st.session_state: # Try to load from storage first loaded_state = tag_storage.load_essence_state() if loaded_state and 'essence_custom_settings' in loaded_state: st.session_state.essence_custom_settings = loaded_state['essence_custom_settings'] else: st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy() # Replace initialize_manual_tags with: def initialize_manual_tags(): """Initialize manual tags if not already present""" if 'manual_tags' not in st.session_state: # Try to load from storage first loaded_state = tag_storage.load_essence_state() if loaded_state and 'manual_tags' in loaded_state: st.session_state.manual_tags = loaded_state['manual_tags'] else: st.session_state.manual_tags = { "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"}, } def timeout(seconds, fallback_value=None): """ Simple timeout utility for functions. Warns if a function takes longer than expected but doesn't interrupt it. Args: seconds: Expected maximum seconds the function should take fallback_value: Not used, just for API compatibility """ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): start_time = time.time() result = func(*args, **kwargs) elapsed = time.time() - start_time if elapsed > seconds: print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)") return result return wrapper return decorator # Core Classes for Essence Generation class LayerHook: """Helper class to store the outputs of a layer via forward hook.""" def __init__(self, layer): self.layer = layer self.features = None self.hook = layer.register_forward_hook(self.hook_fn) def hook_fn(self, module, input, output): self.features = output def close(self): self.hook.remove() class FullModelHook: """Hook all layers in a model and track their responses to inputs.""" def __init__(self, model): self.model = model self.hooks = {} self.activations = {} self.layer_scores = {} # Recursively register hooks for all eligible layers self._register_hooks(model) print(f"FullModelHook initialized with {len(self.hooks)} hooks") def _register_hooks(self, module, prefix=''): """Recursively register hooks on all suitable layers.""" for name, child in module.named_children(): layer_name = f"{prefix}.{name}" if prefix else name # Only hook layers that produce activations # Avoid hooking containers like Sequential if isinstance(child, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.BatchNorm2d, torch.nn.LayerNorm)): self.hooks[layer_name] = child.register_forward_hook( lambda m, inp, out, layer=layer_name: self._hook_fn(layer, out) ) # Recurse into children self._register_hooks(child, layer_name) def _hook_fn(self, layer_name, output): """Store activations for each layer.""" # For convolutional layers, compute channel-wise mean activations if len(output.shape) == 4: # [batch, channels, height, width] # Store mean activation per channel self.activations[layer_name] = output.mean(dim=[2, 3]).detach() else: # For other layers, store as is self.activations[layer_name] = output.detach() class EssenceGenerator: """ Enhanced Essence Generator optimized for anime characters. Includes improvements for more vibrant colors and recognizable features. """ def __init__( self, model, tag_to_name=None, iterations=256, scales=3, learning_rate=0.03, # Lower learning rate for better convergence decay_power=1.5, # Stronger emphasis on low frequencies tv_weight=5e-4, # Stronger total variation for clearer structures layers_to_hook=None, layer_weights=None, color_boost=1.5 # Color boosting factor ): """Initialize the Enhanced Essence Generator""" self.model = model self.tag_to_name = tag_to_name self.iterations = iterations self.scales = scales self.lr = learning_rate self.decay_power = decay_power self.tv_weight = tv_weight self.layers_to_hook = layers_to_hook self.layer_weights = layer_weights self.color_boost = color_boost # Set device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.eval().to(self.device) # Initialize hooks self.hooks = {} # Enhanced color correlation matrix for anime-style colors # More saturated colors with stronger correlations self.color_correlation_matrix = torch.tensor([ [1.0000, 0.9522, 0.9156], [0.9522, 1.0000, 0.9708], [0.9156, 0.9708, 1.0000]], device=self.device) # Setup hooks if specified if self.layers_to_hook: self.setup_hooks(self.layers_to_hook) def setup_hooks(self, layers_to_hook): """Setup hooks on the specified layers.""" # Close any existing hooks self.close_hooks() # Create new hooks for layer_name in layers_to_hook: try: # Try to get layer by navigating the model hierarchy parts = layer_name.split('.') layer = self.model for part in parts: layer = getattr(layer, part) self.hooks[layer_name] = LayerHook(layer) print(f"Setup hook for layer: {layer_name}") except Exception as e: print(f"Failed to setup hook for {layer_name}: {e}") def setup_auto_hooks(self, tag_idx): """ Automatically detect the most responsive layers for a specific tag. This simplified version selects a few key layers based on model architecture. """ # Close any existing hooks self.close_hooks() # If we already have layer weights from initialization, use those if self.layers_to_hook and self.layer_weights: for layer_name in self.layers_to_hook: try: # Try to get layer by navigating the model hierarchy parts = layer_name.split('.') layer = self.model for part in parts: layer = getattr(layer, part) self.hooks[layer_name] = LayerHook(layer) print(f"Setup hook for layer: {layer_name}") except Exception as e: print(f"Failed to setup hook for {layer_name}: {e}") return self.layer_weights # Otherwise, detect layers automatically # Get all named modules all_layers = [] for name, module in self.model.named_modules(): if not name: # Skip empty name (the model itself) continue # Only consider certain layer types that typically have meaningful features if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)): all_layers.append((name, module)) # If the model is too large, select strategic layers selected_layers = [] layer_weights = {} if len(all_layers) > 30: # For large models, select a subset of layers # 1. Try to find classifier/final layer classifier_layers = [(name, module) for name, module in all_layers if any(x in name.lower() for x in ["classifier", "fc", "linear", "output", "logits"])] if classifier_layers: selected_layers.append(classifier_layers[-1]) layer_weights[classifier_layers[-1][0]] = 1.0 # Highest weight # 2. Find some mid to late convolutional layers conv_layers = [(name, module) for name, module in all_layers if isinstance(module, nn.Conv2d)] if conv_layers: # Take some layers from the second half half_idx = len(conv_layers) // 2 selected_idx = [half_idx, 3*len(conv_layers)//4, -1] # middle, 3/4, and last for idx in selected_idx: if idx < len(conv_layers) and conv_layers[idx] not in selected_layers: selected_layers.append(conv_layers[idx]) # Later layers get higher weights pos = selected_idx.index(idx) layer_weights[conv_layers[idx][0]] = 0.5 + 0.5 * (pos / max(1, len(selected_idx) - 1)) else: # For smaller models, use more layers # Take a sample across the network depth step = max(1, len(all_layers) // 5) indices = list(range(0, len(all_layers), step)) if len(all_layers) - 1 not in indices: indices.append(len(all_layers) - 1) # Always include the last layer for idx in indices: selected_layers.append(all_layers[idx]) # Later layers get higher weights layer_weights[all_layers[idx][0]] = 0.5 + 0.5 * (idx / max(1, len(all_layers) - 1)) # Create hooks for selected layers print(f"Setting up {len(selected_layers)} auto-detected layers:") for name, module in selected_layers: self.hooks[name] = LayerHook(module) print(f" - {name} (weight: {layer_weights.get(name, 0.5):.2f})") return layer_weights def close_hooks(self): """Clean up hooks to avoid memory leaks.""" for hook in self.hooks.values(): hook.close() self.hooks.clear() def total_variation_loss(self, img): """ Total variation loss for smoother images but preserving edges. Modified version that better preserves strong edges. """ diff_y = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]) diff_x = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1]) # Use a more gentle version of total variation that allows for edges # but still penalizes noise (using square root reduces the penalty on large differences) tv = torch.mean(torch.sqrt(diff_y + 1e-8)) + torch.mean(torch.sqrt(diff_x + 1e-8)) return tv def create_fft_spectrum_initializer(self, size, batch_size=1): """Enhanced frequency domain initialization for better essence generation with more color and coverage""" fft_size = size // 2 + 1 # Initialize frequency components with a natural image prior # This biases toward more natural-looking essences spectrum_scale = torch.zeros(batch_size, 3, size, fft_size, 2, device=self.device) # Use 1/f spectrum characteristic of natural images (pink noise) for h in range(size): for w in range(fft_size): # Calculate distance from DC component dist = np.sqrt((h/size)**2 + (w/fft_size)**2) + 1e-5 # Pink noise falls off as 1/f, but with higher amplitude for better initial coverage weight = 1.0 / dist # Add random phase but weighted amplitude - increased amplitude for better initial values spectrum_scale[:, :, h, w, 0] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15 # Increased from 0.05 spectrum_scale[:, :, h, w, 1] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15 # Increased from 0.05 # Initialize DC component (average color) with higher values for better color saturation # Give distinct colors to each channel for more vibrant initialization spectrum_scale[:, 0, 0, 0, 0] = 0.5 # Red channel spectrum_scale[:, 1, 0, 0, 0] = 0.4 # Green channel spectrum_scale[:, 2, 0, 0, 0] = 0.6 # Blue channel spectrum_scale[:, :, 0, 0, 1] = 0 spectrum_scale.requires_grad_(True) # Phase component - increased for more variation spectrum_shift = torch.randn(batch_size, 3, size, fft_size, 2, device=self.device) * 0.05 # Increased from 0.02 spectrum_shift.requires_grad_(True) return spectrum_scale, spectrum_shift def create_spectrum_weights(self, size, decay_power=1.0): """Create weights for the spectrum that emphasize lower frequencies but preserve more details.""" freqs_x = torch.fft.rfftfreq(size).view(1, -1).to(self.device) freqs_y = torch.fft.fftfreq(size).view(-1, 1).to(self.device) dist_from_center = torch.sqrt(freqs_x**2 + freqs_y**2) # Modified weight calculation that allows for more mid-frequency details # and preserves more high frequencies for better detail weights = 1.0 / (dist_from_center + 1e-8) ** decay_power # Significantly increase mid-frequency weights for more details and coverage mid_freq_mask = (dist_from_center > 0.05) & (dist_from_center < 0.3) weights = weights * (1.0 + 1.0 * mid_freq_mask.float()) # Increased from 0.5 # Add some weight to high frequencies for texture details high_freq_mask = (dist_from_center >= 0.3) & (dist_from_center < 0.7) weights = weights * (1.0 + 0.3 * high_freq_mask.float()) # New addition weights = weights / weights.max() weights[0, 0] = 0.8 # Higher DC component for better color coherence (increased from 0.7) return weights def fft_to_rgb(self, spectrum_scale, spectrum_shift, size, spectrum_weight=None): """Convert FFT spectrum parameters to an RGB image.""" batch_size = spectrum_scale.shape[0] if spectrum_weight is not None: spectrum_scale = spectrum_scale * spectrum_weight.unsqueeze(0).unsqueeze(0).unsqueeze(-1) image = torch.zeros(batch_size, 3, size, size, device=self.device) spectrum_complex = torch.complex( spectrum_scale[..., 0], spectrum_scale[..., 1] ) phase_shift = torch.complex( torch.cos(spectrum_shift[..., 0]), torch.sin(spectrum_shift[..., 1]) ) spectrum_complex = spectrum_complex * phase_shift for b in range(batch_size): for c in range(3): channel_spectrum = spectrum_complex[b, c] channel_image = torch.fft.irfft2(channel_spectrum, s=(size, size)) channel_min = channel_image.min() channel_max = channel_image.max() if channel_max > channel_min: channel_image = (channel_image - channel_min) / (channel_max - channel_min) else: channel_image = torch.zeros_like(channel_image) image[b, c] = channel_image return image def apply_color_correlation(self, image): """ Apply color correlation to produce more vibrant, colorful images. """ batch_size, _, height, width = image.shape # 1. Apply basic color correlation flat_image = image.view(batch_size, 3, -1) correlated = torch.matmul(self.color_correlation_matrix, flat_image) correlated_image = correlated.view(batch_size, 3, height, width) # 2. Apply stronger color boost (increase saturation) # Calculate luminance (0.3R + 0.59G + 0.11B) luminance = 0.3 * image[:, 0:1] + 0.59 * image[:, 1:2] + 0.11 * image[:, 2:3] # Boost colors away from luminance (enhances saturation) # Increase the boost factor for more vibrant colors boosted_image = luminance + (correlated_image - luminance) * self.color_boost * 1.5 # Apply additional 1.5x boost # 3. Apply a gentle S-curve for better contrast # This helps make the colors "pop" more boosted_image = 0.5 + torch.tanh((boosted_image - 0.5) * 2) * 0.5 # Ensure values are in [0, 1] range boosted_image = torch.clamp(boosted_image, 0, 1) return boosted_image def apply_transforms(self, img): """Apply random transformations to the image for robustness.""" batch_size, c, h, w = img.shape # 1. Padding pad = 16 padded = F.pad(img, (pad, pad, pad, pad), mode='reflect') # 2. Random jitter jitter = 16 h_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device) w_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device) # Create sampling grid rows = torch.arange(h, device=self.device).view(1, 1, -1, 1).repeat(batch_size, 1, 1, w) cols = torch.arange(w, device=self.device).view(1, 1, 1, -1).repeat(batch_size, 1, h, 1) rows = rows + h_jitter.view(-1, 1, 1, 1) + pad cols = cols + w_jitter.view(-1, 1, 1, 1) + pad # Get transformed image (simplified implementation) grid_h = torch.clamp(rows, 0, padded.shape[2] - 1).long() grid_w = torch.clamp(cols, 0, padded.shape[3] - 1).long() # Apply second jitter for more randomness jitter2 = 8 h_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device) w_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device) grid_h = torch.clamp(grid_h + h_jitter2.view(-1, 1, 1, 1), 0, padded.shape[2] - 1).long() grid_w = torch.clamp(grid_w + w_jitter2.view(-1, 1, 1, 1), 0, padded.shape[3] - 1).long() # Gather values transformed = torch.zeros_like(img) for b in range(batch_size): transformed[b] = padded[b, :, grid_h[b, 0], grid_w[b, 0]] return transformed def add_spatial_prior(self, img, strength=0.05): """ Add a spatial prior to encourage character-like structures with better composition and fuller image coverage. """ batch_size, c, h, w = img.shape # Create normalized coordinate grids y_indices = torch.arange(h, device=self.device).float() x_indices = torch.arange(w, device=self.device).float() y = (2.0 * y_indices / h) - 1.0 # Normalize to [-1, 1] x = (2.0 * x_indices / w) - 1.0 # Normalize to [-1, 1] # Expand to 2D grid y_grid = y.view(-1, 1).repeat(1, w) x_grid = x.view(1, -1).repeat(h, 1) # Center bias (gentler in the middle to allow more coverage) center_dist = torch.sqrt(x_grid.pow(2) + y_grid.pow(2)) # Wider center bias for better coverage center_value = torch.exp(-0.8 * center_dist) # Reduced from -1.5 for wider coverage # Full-image utilization bias (higher values further from edge) edge_dist_x = torch.min(torch.abs(x_grid - 1.0), torch.abs(x_grid + 1.0)) edge_dist_y = torch.min(torch.abs(y_grid - 1.0), torch.abs(y_grid + 1.0)) edge_dist = torch.min(edge_dist_x, edge_dist_y) edge_value = torch.clamp(edge_dist * 5.0, 0.2, 1.0) # Higher values away from edges # Rule of thirds with wider peaks (subtle enhancement at thirds points) thirds_x = torch.exp(-10 * (x_grid - 1/3).pow(2)) + torch.exp(-10 * (x_grid + 1/3).pow(2)) # Reduced from -30 thirds_y = torch.exp(-10 * (y_grid - 1/3).pow(2)) + torch.exp(-10 * (y_grid + 1/3).pow(2)) # Reduced from -30 thirds_value = (thirds_x + thirds_y) / 2 # Combine the different priors with rebalanced weights to favor coverage prior = 0.4 * center_value + 0.4 * edge_value + 0.2 * thirds_value # More weight on edge_value for better coverage # Normalize the prior prior = prior / prior.max() # Expand to match the input dimensions prior = prior.unsqueeze(0).unsqueeze(0) prior = prior.repeat(batch_size, c, 1, 1) # Apply the prior with increased strength result = img * (1.0 - strength*1.5) + prior * strength*1.5 # Increased strength by 50% return result def get_layer_activations(self, tag_idx, layer_weights): """ Get activations from all hooked layers for the target tag. Returns a weighted sum of activations based on layer weights. """ activation_sum = 0.0 for layer_name, hook in self.hooks.items(): if hook.features is None: continue # Get weight for this layer weight = layer_weights.get(layer_name, 0.5) # Handle different layer types if len(hook.features.shape) <= 2: # For fully connected layers, focus on the target class logit if hook.features.size(1) > tag_idx: activation = hook.features[0, tag_idx].item() activation_sum += weight * activation else: # For convolutional layers, focus on overall activation strength channel_means = hook.features.mean(dim=[2, 3]) # Make sure we don't request more channels than exist num_channels = min(5, channel_means.size(1)) _, top_indices = torch.topk(channel_means, num_channels) # Process each top channel individually for idx in range(min(3, len(top_indices[0]))): # Use up to 3 top channels channel_idx = top_indices[0, idx] channel_activation = hook.features[:, channel_idx].mean().item() # Weight decreases for less important channels channel_weight = 1.0 if idx == 0 else (0.5 if idx == 1 else 0.25) activation_sum += weight * channel_activation * channel_weight return activation_sum def generate_essence(self, tag_idx, image_size=512, return_score=True, progress_callback=None): """Generate an essence visualization using enhanced techniques.""" # Get tag name for logging (if available) tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}" print(f"Generating enhanced essence for '{tag_name}'...") # Auto-detect and set up hooks for responsive layers if not already set up layer_weights = self.layer_weights or self.setup_auto_hooks(tag_idx) # Determine scale sizes scale_sizes = [] for s in range(self.scales): # Start small and progressively increase size scale_size = max(32, image_size // (2 ** (self.scales - s - 1))) scale_sizes.append(scale_size) print(f"Processing scales: {scale_sizes}") # Create frequency spectrum weights spectrum_weights = {} for size in scale_sizes: spectrum_weights[size] = self.create_spectrum_weights(size, decay_power=self.decay_power) # Track best result best_score = -float('inf') best_img = None # Process each scale independently for scale_idx, size in enumerate(scale_sizes): # Initialize parameters for this scale spectrum_scale, spectrum_shift = self.create_fft_spectrum_initializer(size) # Create optimizer optimizer = torch.optim.Adam([spectrum_scale, spectrum_shift], lr=self.lr) # Use learning rate scheduler for better convergence scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=self.iterations, eta_min=self.lr * 0.1 ) # Current scale's spectrum weights current_weights = spectrum_weights[size] # Iterations for this scale iterations = self.iterations # Epoch tracking for early stopping no_improvement_streak = 0 plateau_threshold = 128 # Increased from 64 - give more iterations before stopping scale_best_score = -float('inf') scale_best_img = None for i in range(iterations): # Clear gradients optimizer.zero_grad() # Convert FFT parameters to RGB image img = self.fft_to_rgb(spectrum_scale, spectrum_shift, size, current_weights) # # Apply color correlation with boosted colors for anime-style # img = self.apply_color_correlation(img) # Add spatial prior to encourage character-like patterns with better coverage if size >= 32: # Apply stronger spatial prior for better coverage img = self.add_spatial_prior(img, strength=0.25) # Increased from 0.15 # Apply transformations for robustness if size >= 32: img = self.apply_transforms(img) # Reset hooks for hook in self.hooks.values(): hook.features = None # Forward pass outputs = self.model(img) # Get target tag activation in final layer if isinstance(outputs, (list, tuple)): predictions = outputs[0] else: predictions = outputs # Get tag activation from final layer tag_activation = predictions[0, tag_idx] # Get activations from earlier layers layer_activation = self.get_layer_activations(tag_idx, layer_weights) # Combined loss to maximize activations # Increase weight on layer_activation for better feature emphasis activation_loss = -(tag_activation + 1.5 + layer_activation * 2.0) # Double layer activation weight # Add regularization term - total variation for smoothness # Reduce TV weight slightly to allow more details and coverage tv_loss = self.total_variation_loss(img) * (self.tv_weight * 0.7) # Reduce to 70% of original # Total loss total_loss = activation_loss + tv_loss # Backpropagation total_loss.backward() # Update parameters optimizer.step() scheduler.step() # Track best result for this scale current_score = tag_activation.item() if current_score > scale_best_score + 1e-4: scale_best_score = current_score scale_best_img = img.detach().clone() no_improvement_streak = 0 else: no_improvement_streak += 1 # More patient early stopping if no_improvement_streak >= plateau_threshold: print(f"Early stopping at iteration {i}/{iterations} due to plateau") break # Report progress if progress_callback and i % max(1, iterations // 10) == 0: progress_callback( scale_idx=scale_idx, scale_count=len(scale_sizes), iter_idx=i, iter_count=iterations, score=current_score ) print(f"Scale {scale_idx+1}/{len(scale_sizes)} completed. Score: {scale_best_score:.4f}") # Update overall best if this scale improved the score if scale_best_score > best_score: best_score = scale_best_score # For the final scale, keep the full resolution image if scale_idx == len(scale_sizes) - 1: best_img = scale_best_img # Otherwise, upscale to the final size for the return value else: with torch.no_grad(): best_img = F.interpolate(scale_best_img, size=(image_size, image_size), mode='bilinear', align_corners=False) # In case all scales failed, create an empty image if best_img is None: final_img = torch.zeros((1, 3, image_size, image_size), device=self.device) else: final_img = best_img # Convert to PIL image pil_img = to_pil_image(final_img[0].cpu()) # Clean up hooks self.close_hooks() if return_score: return pil_img, best_score else: return pil_img # Utility Functions for Model Analysis and Layer Selection def get_model_layers(model): """Utility function to get all available layers in a model.""" layers = [] for name, _ in model.named_modules(): if name: # Skip empty name (the model itself) layers.append(name) return layers def get_key_layers(model, max_layers=15): """ Get a curated list of the most relevant layers for visualization. """ all_layers = get_model_layers(model) # For models with hundreds of layers, we need to be selective if len(all_layers) > 30: # Extract patterns to identify layer types block_patterns = {} # Find common patterns in layer names for layer in all_layers: # Extract the main component (e.g., "backbone.features") parts = layer.split(".") if len(parts) >= 2: prefix = ".".join(parts[:2]) if prefix not in block_patterns: block_patterns[prefix] = [] block_patterns[prefix].append(layer) # Now select representative layers from each major block key_layers = { "early": [], "middle": [], "late": [] } # For each major block, select layers at strategic positions for prefix, layers in block_patterns.items(): if len(layers) > 3: # Only process significant blocks # Sort by natural depth (assuming numerical components indicate depth) layers.sort(key=lambda x: [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', x)]) # Get layers at strategic positions early = layers[0] middle = layers[len(layers) // 2] late = layers[-1] key_layers["early"].append(early) key_layers["middle"].append(middle) key_layers["late"].append(late) # Ensure we don't have too many layers # If we need to reduce further, prioritize middle and late layers flattened = [] for _, group_layers in key_layers.items(): flattened.extend(group_layers) if len(flattened) > max_layers: # Calculate how many to keep from each group total = len(flattened) # Prioritize keeping late layers (for character recognition) late_count = min(len(key_layers["late"]), max_layers // 3) # Allocate remaining slots between early and middle remaining = max_layers - late_count middle_count = min(len(key_layers["middle"]), remaining // 2) early_count = min(len(key_layers["early"]), remaining - middle_count) # Take only the needed number from each category key_layers["early"] = key_layers["early"][:early_count] key_layers["middle"] = key_layers["middle"][:middle_count] key_layers["late"] = key_layers["late"][:late_count] else: # For simpler models, use standard distribution n = len(all_layers) key_layers = { "early": all_layers[:n//3][:3], # First few layers "middle": all_layers[n//3:2*n//3][:4], # Middle layers "late": all_layers[2*n//3:][:3] # Last few layers } # Try to identify the classifier/final layer classifier_layers = [layer for layer in all_layers if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits", "head"])] if classifier_layers: key_layers["classifier"] = [classifier_layers[-1]] return key_layers def get_suggested_layers(model, layer_type="balanced"): """ Get suggested layers based on the desired feature type. """ key_layers = get_key_layers(model) # Flatten all layers for reference all_key_layers = [] for layers in key_layers.values(): all_key_layers.extend(layers) # Choose layers based on the requested emphasis if layer_type == "low": # Focus on early visual features (textures, patterns, colors) selected = key_layers.get("early", []) # Add one middle layer for stability if "middle" in key_layers and key_layers["middle"]: selected.append(key_layers["middle"][0]) elif layer_type == "mid": # Focus on mid-level features (parts, components) selected = key_layers.get("middle", []) # Add one early layer for context if "early" in key_layers and key_layers["early"]: selected.append(key_layers["early"][-1]) elif layer_type == "high": # Focus on high-level semantic features (objects, characters) selected = key_layers.get("late", []) selected.extend(key_layers.get("classifier", [])) # Add one middle layer for context if "middle" in key_layers and key_layers["middle"]: selected.append(key_layers["middle"][-1]) else: # balanced # Use a mix of early, middle and late layers selected = [] for category in ["early", "middle", "late", "classifier"]: if category in key_layers and key_layers[category]: # Take one from each category selected.append(key_layers[category][0]) # For middle and late, also take the last one if different if category in ["middle", "late"] and len(key_layers[category]) > 1: selected.append(key_layers[category][-1]) # Ensure we have at least one layer if not selected and all_key_layers: selected = [all_key_layers[-1]] # Use the last layer as fallback return selected def get_quality_level(score): """ Determine the quality level of an essence based on its score """ for level in reversed(list(ESSENCE_QUALITY_LEVELS.keys())): if score >= ESSENCE_QUALITY_LEVELS[level]["threshold"]: return level return "ZAYIN" # Default to lowest level def get_essence_cost(rarity): """ Calculate the cost to generate an essence image based on tag rarity """ return ESSENCE_COSTS.get(rarity, 100) # Default to 100 if rarity unknown # Game UI and Integration Functions def initialize_essence_settings(): """Initialize essence generator settings if not already present""" if 'essence_custom_settings' not in st.session_state: st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy() def save_essence_to_game_folder(image, tag, score, quality_level): """ Save the generated essence image to a persistent game folder Args: image: PIL Image of the essence tag: The tag name score: The generation score quality_level: The quality classification (ZAYIN, TETH, etc.) Returns: Path to the saved image """ # Create game folder if it doesn't exist game_folder = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "game_data") essence_folder = os.path.join(game_folder, "essences") # Make directories if they don't exist os.makedirs(essence_folder, exist_ok=True) # Create filename with quality level and score safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_') timestamp = time.strftime("%Y%m%d_%H%M%S") filename = f"{safe_tag}_{quality_level}_{score:.2f}_{timestamp}.png" filepath = os.path.join(essence_folder, filename) # Save the image image.save(filepath) return filepath def generate_essence_for_tag(tag, model, dataset, custom_settings=None): """ Generate an essence image for a specific tag using the improved generator Args: tag: The tag name or index model: The model to use dataset: The dataset containing tag information custom_settings: Optional dictionary with custom generation settings Returns: PIL Image of the generated essence, score, quality level """ print(f"\n=== Starting essence generation for tag '{tag}' ===") # Check if tag is discovered or a manual tag is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags if not is_discovered and not is_manual_tag: st.error(f"Tag '{tag}' has not been discovered yet.") return None, 0, None # Get tag rarity and calculate cost if is_discovered: rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard") elif is_manual_tag: rarity = st.session_state.manual_tags[tag].get("rarity", "Canard") else: rarity = "Canard" # Calculate cost based on rarity cost = get_essence_cost(rarity) # Check if player has enough Enkephalin if st.session_state.enkephalin < cost: st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.") return None, 0, None # Use provided settings or defaults settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy() print(f"Using settings: {settings}") # Extract settings iterations = settings.get("iterations", 256) scales = settings.get("scales", 5) layer_emphasis = settings.get("layer_emphasis", "auto") # UI containers for progress preview_container = st.empty() progress_container = st.empty() message_container = st.empty() # If multiple layer emphasis types are requested, show tabs if layer_emphasis == "compare": message_container.info("Generating essences with different layer emphasis types...") tabs_container = st.empty() tabs = None tab_images = {} best_score = -float('inf') best_image = None best_emphasis = None try: # Show generation information if layer_emphasis != "compare": message_container.info(f"Generating essence for '{tag}' with {layer_emphasis} layer emphasis...") # Progress callback function for essence generation def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score): # Update progress bar progress = ((scale_idx * iter_count) + iter_idx) / (scale_count * iter_count) progress_container.progress(progress, f"Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}") message_container.info(f"Current score: {score:.4f}") # Print status to console too if iter_idx % 20 == 0: print(f"Progress: Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}, Score: {score:.4f}") # Convert tag name to index if needed tag_idx = None # Try to find tag in various places if isinstance(tag, str): print(f"Converting tag name '{tag}' to index...") # Standard lookup methods if hasattr(dataset, 'tag_to_idx') and tag in dataset.tag_to_idx: tag_idx = dataset.tag_to_idx[tag] print(f"Found tag index from dataset.tag_to_idx: {tag_idx}") # Session state metadata lookup if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata: tag_idx = st.session_state.metadata['tag_to_idx'].get(tag) if tag_idx is not None: print(f"Found tag index from session_state.metadata['tag_to_idx']: {tag_idx}") # Lookup from idx_to_tag if tag_idx is None and hasattr(st.session_state, 'metadata') and 'idx_to_tag' in st.session_state.metadata: idx_to_tag = st.session_state.metadata['idx_to_tag'] tag_to_idx = {v: int(k) for k, v in idx_to_tag.items()} tag_idx = tag_to_idx.get(tag) if tag_idx is not None: print(f"Found tag index from inverted idx_to_tag: {tag_idx}") # Try case-insensitive if tag_idx is None: tag_lower = tag.lower() for t, idx in tag_to_idx.items(): if t.lower() == tag_lower: tag_idx = idx print(f"Found tag index using case-insensitive match: {tag_idx}") break # For manual tags that aren't in the model's tag list, # we might need to find a semantically similar tag or use a generic index if tag_idx is None and is_manual_tag: # For demonstration, we could map manual tags to known similar tags manual_tag_mapping = { "hatsune_miku": "hatsune_miku", # Try to find this in the dataset "lamp": "lamp", # Try to find this in the dataset "blue_gloves": "gloves", # Fallback to a more generic tag } fallback_tag = manual_tag_mapping.get(tag) if fallback_tag: # Try to find the fallback tag if hasattr(dataset, 'tag_to_idx') and fallback_tag in dataset.tag_to_idx: tag_idx = dataset.tag_to_idx[fallback_tag] print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}") # Try session state metadata if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata: tag_idx = st.session_state.metadata['tag_to_idx'].get(fallback_tag) if tag_idx is not None: print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}") # If still not found, use a generic index (this is a last resort) if tag_idx is None: # Try to use a category-specific generic tag if "hair" in tag.lower(): generic_tag = "blue_hair" # A common tag that might be in the model elif "gloves" in tag.lower(): generic_tag = "gloves" elif "miku" in tag.lower(): generic_tag = "twintails" # A feature of Hatsune Miku else: generic_tag = "1girl" # A very common tag # Try to find this generic tag if hasattr(dataset, 'tag_to_idx') and generic_tag in dataset.tag_to_idx: tag_idx = dataset.tag_to_idx[generic_tag] print(f"Using generic tag '{generic_tag}' with index: {tag_idx}") elif hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata: tag_idx = st.session_state.metadata['tag_to_idx'].get(generic_tag) if tag_idx is not None: print(f"Using generic tag '{generic_tag}' with index: {tag_idx}") # If still not found, show error if tag_idx is None: st.error(f"Tag '{tag}' index not found. Cannot generate essence.") print(f"ERROR: Tag '{tag}' index not found") return None, 0, None else: # Tag is already an index tag_idx = tag print(f"Using provided tag index: {tag_idx}") # Generate the essence - either one or multiple depending on settings if layer_emphasis == "compare": # Generate essences with different layer emphasis types results = try_different_layer_emphasis( model=model, tag_idx=tag_idx, tag_name=tag, image_size=512, # Always use 512x512 iterations=iterations, scales=scales, progress_callback=progress_callback ) # Create tabs to display the results tab_names = [] tab_contents = [] for emphasis_type, result in results.items(): image = result["image"] score = result["score"] # Store results tab_images[emphasis_type] = image # Track best score if score > best_score: best_score = score best_image = image best_emphasis = emphasis_type # Add to tabs tab_names.append(f"{emphasis_type.capitalize()} ({score:.2f})") tab_contents.append(image) # Show tabs with results tabs = tabs_container.tabs(tab_names) for i, tab in enumerate(tabs): with tab: st.image(tab_contents[i], caption=f"Essence with {list(results.keys())[i]} layer emphasis", use_container_width=True) # Use the best-scored image as the final result image = best_image score = best_score # Show which emphasis type worked best st.success(f"Best results achieved with {best_emphasis} layer emphasis (score: {best_score:.2f})") else: # Generate single essence with specified layer emphasis color_boost = 1.5 # Default color boost tv_weight = 5e-4 # Default TV weight # Adjust parameters based on layer emphasis if layer_emphasis == "low": color_boost = 1.3 tv_weight = 2e-4 elif layer_emphasis == "high": color_boost = 1.7 tv_weight = 8e-4 image, score = generate_essence_with_emphasis( model=model, tag_idx=tag_idx, tag_name=tag, image_size=512, # Always use 512x512 iterations=iterations, scales=scales, progress_callback=progress_callback, layer_emphasis=layer_emphasis, color_boost=color_boost, tv_weight=tv_weight ) # Determine quality level quality_level = get_quality_level(score) # Deduct enkephalin cost st.session_state.enkephalin -= cost st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost # Increment essence counter st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1 # Save to persistent location filepath = save_essence_to_game_folder(image, tag, score, quality_level) print(f"Saved essence to: {filepath}") # Update UI with final result if not showing comparison tabs if layer_emphasis != "compare": preview_container.image(image, caption=f"Essence of '{tag}' - Quality: {quality_level}", width=512) # Clear progress elements progress_container.empty() message_container.empty() # Store in session state if 'generated_essences' not in st.session_state: st.session_state.generated_essences = {} st.session_state.generated_essences[tag] = { "path": filepath, "score": score, "quality": quality_level, "rarity": rarity, "settings": settings, "generated_time": time.strftime("%Y-%m-%d %H:%M:%S") } # Show success message st.success(f"Successfully generated {quality_level} essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}") print(f"=== Essence generation complete for '{tag}' ===\n") # Add at the end of generate_essence_for_tag function, just before returning: tag_storage.save_essence_state(session_state=st.session_state) return image, score, quality_level except Exception as e: st.error(f"Error generating essence: {str(e)}") print(f"EXCEPTION in generate_essence_for_tag: {str(e)}") import traceback err_traceback = traceback.format_exc() print(err_traceback) st.code(err_traceback) return None, 0, None def display_essence_generator(): """ Display the essence generator interface """ # Initialize settings initialize_essence_settings() st.title("🎨 Tag Essence Generator") st.write("Generate visual representations of what the AI model recognizes for specific tags.") # Add detailed explanation of what essences are for with st.expander("What are Tag Essences & How to Use Them", expanded=True): st.markdown(""" ### 💡 Understanding Tag Essences Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy! **How to use Tag Essences:** 1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library) 2. **Save the essence image** to your computer 3. **Upload the essence image** back into the tagger 4. The tagger will **almost always detect the original tag** 5. It will often also **detect related rare tags** from the same category **Strategic Value:** - Character essences can help unlock other tags associated with that character - Category essences can help discover rare tags within that category - High-quality essences (WAW, ALEPH) have the strongest effect **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning! """) # Check for model availability model_available = hasattr(st.session_state, 'model') if not model_available: st.warning("Model not available. You can browse your tags but cannot generate essences.") # Create tabs for the different sections tabs = st.tabs(["Generate Essence", "My Essences"]) with tabs[0]: # Check for pending generation from previous interaction if hasattr(st.session_state, 'selected_tag') and st.session_state.selected_tag: tag = st.session_state.selected_tag st.subheader(f"Generating Essence for '{tag}'") # Generate the essence image, score, quality = generate_essence_for_tag( tag, st.session_state.model, st.session_state.model.dataset, st.session_state.essence_custom_settings ) # Show usage tips if successful if image is not None: with st.expander("Essence Usage", expanded=True): st.markdown(""" 💡 **Tag Essence Usage Tips:** 1. Look for similar patterns, colors, and elements in real images 2. The essence reveals what features the AI model recognizes for this tag 3. Use this as inspiration when creating or finding images to get this tag """) else: st.error("Essence generation failed. Please check the error messages above and try again with different settings.") # Clear selected tag st.session_state.selected_tag = None else: # Show the interface to select a tag selected_tag = display_essence_generation_interface(model_available) # If a tag was selected, store it for the next run and rerun if selected_tag: st.session_state.selected_tag = selected_tag st.rerun() with tabs[1]: display_saved_essences() def save_essence_to_game_folder(image, tag, score, quality_level): """ Save the generated essence image to a persistent game folder Args: image: PIL Image of the essence tag: The tag name score: The generation score quality_level: The quality classification (ZAYIN, TETH, etc.) Returns: Path to the saved image """ # Create game folder paths with better structure base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) game_data_dir = os.path.join(base_dir, "game_data") essence_folder = os.path.join(game_data_dir, "essences") # Make sure all parent directories exist os.makedirs(game_data_dir, exist_ok=True) os.makedirs(essence_folder, exist_ok=True) # Organize essences by quality level for easier browsing quality_folder = os.path.join(essence_folder, quality_level) os.makedirs(quality_folder, exist_ok=True) # Create filename with more details and better organization safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_') timestamp = time.strftime("%Y%m%d_%H%M%S") filename = f"{safe_tag}_{score:.2f}_{timestamp}.png" filepath = os.path.join(quality_folder, filename) # Save the image image.save(filepath) print(f"Saved essence to: {filepath}") return filepath def essence_folder_path(): """Get the path to the essence folder, creating it if necessary""" base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) game_data_dir = os.path.join(base_dir, "game_data") essence_folder = os.path.join(game_data_dir, "essences") # Make sure all directories exist os.makedirs(game_data_dir, exist_ok=True) os.makedirs(essence_folder, exist_ok=True) return essence_folder def display_saved_essences(): """Display the user's saved essence images""" st.subheader("My Generated Essences") if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences: st.info("You haven't generated any essences yet. Go to the Generate tab to create some!") return # Add usage instructions at the top st.markdown(""" ### How to Use Your Essences 1. **Click on any essence image** to open it in full size 2. **Save the image** to your computer (right-click → Save image) 3. **Go to the Scan Images tab** and upload the saved essence 4. The tagger will likely detect the original tag and potentially related rare tags! Higher quality essences (WAW, ALEPH) generally produce the best results. """) # Get the essence folder path essence_dir = essence_folder_path() # Try to locate any missing files for tag, info in st.session_state.generated_essences.items(): if "path" in info and not os.path.exists(info["path"]): # Try to find the file in the essence directory quality = info.get("quality", "ZAYIN") quality_dir = os.path.join(essence_dir, quality) if os.path.exists(quality_dir): # Check for files with this tag name safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_') matching_files = [f for f in os.listdir(quality_dir) if f.startswith(safe_tag)] if matching_files: # Use the most recent file if there are multiple matching_files.sort(reverse=True) info["path"] = os.path.join(quality_dir, matching_files[0]) print(f"Reconnected essence for {tag} to {info['path']}") # List essences by quality level essences_by_quality = {} for tag, info in st.session_state.generated_essences.items(): quality = info.get("quality", "ZAYIN") # Default to lowest if not set if quality not in essences_by_quality: essences_by_quality[quality] = [] essences_by_quality[quality].append((tag, info)) # Check if any essences exist on disk but are not tracked in session state try: untracked_essences = {} for quality in ESSENCE_QUALITY_LEVELS.keys(): quality_dir = os.path.join(essence_dir, quality) if os.path.exists(quality_dir): essence_files = os.listdir(quality_dir) # Filter to only show PNG files essence_files = [f for f in essence_files if f.lower().endswith('.png')] if essence_files: # Check if any of these files aren't in our tracked essences for filename in essence_files: # Extract tag name from filename parts = filename.split('_') if len(parts) >= 2: tag = parts[0].replace('_', ' ') # Check if file is already tracked is_tracked = False for tracked_tag, tracked_info in st.session_state.generated_essences.items(): if "path" in tracked_info and os.path.basename(tracked_info["path"]) == filename: is_tracked = True break if not is_tracked: if quality not in untracked_essences: untracked_essences[quality] = [] untracked_essences[quality].append((tag, { "path": os.path.join(quality_dir, filename), "quality": quality, "discovered_on_disk": True })) except Exception as e: print(f"Error checking for untracked essences: {e}") # Combine tracked and untracked essences for quality, essences in untracked_essences.items(): if quality not in essences_by_quality: essences_by_quality[quality] = [] for tag, info in essences: # Only add if we don't already have this tag in this quality level if not any(tracked_tag == tag for tracked_tag, _ in essences_by_quality[quality]): essences_by_quality[quality].append((tag, info)) # Show essences from highest to lowest quality for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]: if quality in essences_by_quality: essences = essences_by_quality[quality] color = ESSENCE_QUALITY_LEVELS[quality]["color"] with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]): # Create grid layout cols = st.columns(3) for i, (tag, info) in enumerate(sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)): col_idx = i % 3 with cols[col_idx]: try: # Try to load the image from path if "path" in info and os.path.exists(info["path"]): image = Image.open(info["path"]) rarity = info.get("rarity", "Canard") score = info.get("score", 0) # Get color for rarity rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Display the image with metadata st.image(image, caption=tag, use_container_width=True) # Use special styling for rare tags if rarity == "Impuritas Civitas": st.markdown(f""" {quality} | {rarity} | Score: {score:.2f} """, unsafe_allow_html=True) elif rarity == "Star of the City": st.markdown(f""" {quality} | {rarity} | Score: {score:.2f} """, unsafe_allow_html=True) elif rarity == "Urban Nightmare": st.markdown(f""" {quality} | {rarity} | Score: {score:.2f} """, unsafe_allow_html=True) elif rarity == "Urban Plague": st.markdown(f""" {quality} | {rarity} | Score: {score:.2f} """, unsafe_allow_html=True) else: st.markdown(f""" {quality} | {rarity} | Score: {score:.2f} """, unsafe_allow_html=True) # Add file info if "discovered_on_disk" in info and info["discovered_on_disk"]: st.info("Found on disk (not in session state)") # Add button to open folder if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"): folder_path = os.path.dirname(info["path"]) try: # Try different methods to open folder based on platform if os.name == 'nt': # Windows os.startfile(folder_path) elif os.name == 'posix': # macOS or Linux import subprocess if 'darwin' in os.sys.platform: # macOS subprocess.call(['open', folder_path]) else: # Linux subprocess.call(['xdg-open', folder_path]) st.success(f"Opened folder: {folder_path}") except Exception as e: st.error(f"Could not open folder: {str(e)}") # Provide the path for manual navigation st.code(folder_path) else: # Could not find image st.warning(f"Image file not found: {info.get('path', 'No path available')}") # Show quality and tag name st.markdown(f""" {quality} | {tag} """, unsafe_allow_html=True) # Only add reconnect button if we have some metadata if "rarity" in info and "score" in info: if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"): # Update path in session state safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_') score = info.get("score", 0) quality_dir = os.path.join(essence_dir, quality) # Create directory if it doesn't exist os.makedirs(quality_dir, exist_ok=True) # Set a path - user will need to manually add the image timestamp = time.strftime("%Y%m%d_%H%M%S") filename = f"{safe_tag}_{score:.2f}_{timestamp}.png" info["path"] = os.path.join(quality_dir, filename) st.info(f"Please save your image to this location: {info['path']}") st.session_state.generated_essences[tag] = info tag_storage.save_essence_state(session_state=st.session_state) st.rerun() except Exception as e: st.write(f"Error loading {tag}: {str(e)}") # Add option to clean up missing files st.divider() if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"): # Find all entries with missing files to_remove = [] for tag, info in st.session_state.generated_essences.items(): if "path" in info and not os.path.exists(info["path"]): to_remove.append(tag) # Remove them for tag in to_remove: del st.session_state.generated_essences[tag] # Save state tag_storage.save_essence_state(session_state=st.session_state) if to_remove: st.success(f"Removed {len(to_remove)} entries with missing files") else: st.success("No missing files found") st.rerun() def display_essence_generation_interface(model_available): """Display the interface for generating new essences""" # Initialize manual tags initialize_manual_tags() st.subheader("Generate Tag Essence") st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.") # Settings column col1, col2 = st.columns(2) with col1: # Simple settings st.write("Generation Settings:") # Basic settings scales = st.slider("Scales", 1, 5, DEFAULT_ESSENCE_SETTINGS["scales"], help="More scales produce more detailed essences") iterations = st.slider("Iterations", 64, 2048, DEFAULT_ESSENCE_SETTINGS["iterations"], 64, help="More iterations improve quality") # Layer emphasis selection - all options available including comparison layer_emphasis = st.selectbox( "Feature Targeting", options=["auto", "balanced", "high", "mid", "low", "compare", "custom"], index=0, # Default to auto format_func=lambda x: { "auto": "Auto-detect (best for each tag)", "balanced": "Balanced (mix of features)", "high": "High-level (characters, objects)", "mid": "Mid-level (parts, components)", "low": "Low-level (textures, patterns)", "compare": "Compare different approaches", "custom": "Custom layer selection" }.get(x, x), help="Controls which model features to emphasize in the essence" ) # Custom layer selection if needed custom_layers = [] if layer_emphasis == "custom" and model_available: st.write("Select Custom Layers:") # Get key layers (simplified approach) key_layers = get_key_layers(st.session_state.model, max_layers=15) # Show categories (early, middle, late, classifier) for category, layers in key_layers.items(): if layers: category_name = { "early": "Early Layers (textures, colors)", "middle": "Middle Layers (parts, components)", "late": "Late Layers (objects, characters)", "classifier": "Classifier (final recognition)" }.get(category, category.capitalize()) with st.expander(f"{category_name}", expanded=category in ["late", "classifier"]): select_all = st.checkbox(f"Select all {category} layers", key=f"select_all_{category}") for layer in layers: # Create a shortened display name parts = layer.split(".") display_name = f"...{parts[-2]}.{parts[-1]}" if len(parts) > 3 else layer if select_all or st.checkbox(display_name, key=f"layer_{layer}"): custom_layers.append(layer) # Show selected layers if custom_layers: st.success(f"Selected {len(custom_layers)} layers") else: st.warning("Please select at least one layer") # Save settings st.session_state.essence_custom_settings = { "scales": scales, "iterations": iterations, "image_size": 512, # Fixed "lr": 0.03, # Lower learning rate for better results "layer_emphasis": layer_emphasis, "custom_layers": custom_layers } with col2: # Show quality level descriptions st.write("Quality Levels:") for level, info in ESSENCE_QUALITY_LEVELS.items(): st.markdown(f"""
{level} ({info['threshold']:.0f} Score+): {info['description']}
""", unsafe_allow_html=True) # Feature targeting explanation st.write("Feature Targeting Explanation:") st.markdown(""" â„šī¸ **Feature targeting affects what the visualization emphasizes:** """) # Show current Enkephalin st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}") st.divider() # Add CSS for animations matching tag collection display st.markdown(""" """, unsafe_allow_html=True) # ----- NEW TAG COLLECTION DISPLAY ----- # Gather all tags for essence generation all_tags = [] # Process discovered tags if hasattr(st.session_state, 'discovered_tags'): for tag, info in st.session_state.discovered_tags.items(): tag_info = { "tag": tag, "rarity": info.get("rarity", "Unknown"), "category": info.get("category", "unknown"), "source": "discovered", "library_floor": info.get("library_floor", ""), "discovery_time": info.get("discovery_time", "") } all_tags.append(tag_info) # Process manual tags if hasattr(st.session_state, 'manual_tags'): for tag, info in st.session_state.manual_tags.items(): tag_info = { "tag": tag, "rarity": info.get("rarity", "Special"), "category": info.get("category", "special"), "source": "manual", "description": info.get("description", "") } all_tags.append(tag_info) # Count tags by rarity rarity_counts = {} for info in all_tags: rarity = info["rarity"] if rarity not in rarity_counts: rarity_counts[rarity] = 0 rarity_counts[rarity] += 1 # Display rarity counts at the top st.subheader("Available Tags for Essence Generation") st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!") # Display rarity distribution rarity_cols = st.columns(len(rarity_counts)) for i, (rarity, count) in enumerate(sorted(rarity_counts.items(), key=lambda x: list(RARITY_LEVELS.keys()).index(x[0]) if x[0] in RARITY_LEVELS else 999)): with rarity_cols[i]: # Get color with fallback color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888") # Apply special styling based on rarity style = f"color:{color};font-weight:bold;" class_name = "" if rarity == "Impuritas Civitas": class_name = "grid-impuritas" elif rarity == "Star of the City": class_name = "grid-star" elif rarity == "Urban Nightmare": class_name = "grid-nightmare" elif rarity == "Urban Plague": class_name = "grid-plague" if class_name: st.markdown( f"
{rarity.capitalize()}
{count}
", unsafe_allow_html=True ) else: st.markdown( f"
{rarity.capitalize()}
{count}
", unsafe_allow_html=True ) # Search box for all tags search_term = st.text_input("Search tags", "", key="essence_search_tags") # Sort options sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"] selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort") # Filter tags by search term if provided if search_term: all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()] selected_tag = None # Sort and group tags based on selection if selected_sort == "Category (rarest first)": # Group tags by category categories = {} for info in all_tags: category = info["category"] if category not in categories: categories[category] = [] categories[category].append(info) # Display tags by category in expanders for category, tags in sorted(categories.items()): # Get rarity order for sorting rarity_order = list(reversed(RARITY_LEVELS.keys())) # Sort tags by rarity (rarest first) def get_rarity_index(info): rarity = info["rarity"] if rarity in rarity_order: return len(rarity_order) - rarity_order.index(rarity) return 0 sorted_tags = sorted(tags, key=get_rarity_index, reverse=True) # Check if category has any rare tags has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"] for info in sorted_tags) # Get category info if available category_display = category.capitalize() if category in TAG_CATEGORIES: category_info = TAG_CATEGORIES[category] icon = category_info.get("icon", "") color = category_info.get("color", "#888888") category_display = f"{icon} {category.capitalize()}" # Create header with information about rare tags if present header = f"{category_display} ({len(tags)} tags)" if has_rare_tags: header += " ✨ Contains rare tags!" # Display category header and expander st.markdown(header, unsafe_allow_html=True) with st.expander("Show/Hide", expanded=has_rare_tags): # Create grid layout for tags cols = st.columns(3) for i, info in enumerate(sorted_tags): with cols[i % 3]: tag = info["tag"] rarity = info["rarity"] source = info["source"] # Get rarity color rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Check if this tag has an essence already has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences # Get cost for this tag cost = get_essence_cost(rarity) can_afford = st.session_state.enkephalin >= cost # Format tag display with special styling if rarity == "Impuritas Civitas": tag_display = f'{tag}' elif rarity == "Star of the City": tag_display = f'{tag}' elif rarity == "Urban Nightmare": tag_display = f'{tag}' elif rarity == "Urban Plague": tag_display = f'{tag}' else: tag_display = f'{tag}' # Show tag with rarity badge and cost st.markdown( f'{tag_display} {rarity.capitalize()} ({cost} {ENKEPHALIN_ICON})', unsafe_allow_html=True ) # Show discovery details if available if source == "discovered" and "library_floor" in info and info["library_floor"]: st.markdown(f'Found in: {info["library_floor"]}', unsafe_allow_html=True) elif source == "manual" and "description" in info and info["description"]: st.markdown(f'{info["description"]}', unsafe_allow_html=True) # Add generation button button_label = "Generate" if not has_essence else "Regenerate ✓" if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford): selected_tag = tag elif selected_sort == "Rarity": # Group tags by rarity rarity_groups = {} for info in all_tags: rarity = info["rarity"] if rarity not in rarity_groups: rarity_groups[rarity] = [] rarity_groups[rarity].append(info) # Get ordered rarities (rarest first) ordered_rarities = list(RARITY_LEVELS.keys()) ordered_rarities.reverse() # Reverse to show rarest first # Add any rarities not in RARITY_LEVELS for rarity in rarity_groups.keys(): if rarity not in ordered_rarities: ordered_rarities.append(rarity) # Display tags by rarity for rarity in ordered_rarities: if rarity in rarity_groups: tags = rarity_groups[rarity] color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Add special styling for rare rarities rarity_html = f"{rarity.capitalize()}" if rarity == "Impuritas Civitas": rarity_html = f"{rarity.capitalize()}" elif rarity == "Star of the City": rarity_html = f"{rarity.capitalize()}" elif rarity == "Urban Nightmare": rarity_html = f"{rarity.capitalize()}" # First create the title with HTML, then use it in the expander st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True) with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]): # Create grid layout for tags cols = st.columns(3) for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])): with cols[i % 3]: tag = info["tag"] source = info["source"] # Check if this tag has an essence already has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences # Get cost for this tag cost = get_essence_cost(rarity) can_afford = st.session_state.enkephalin >= cost # Show tag with cost st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})") # Show discovery details if available if source == "discovered" and "library_floor" in info and info["library_floor"]: st.markdown(f'Found in: {info["library_floor"]}', unsafe_allow_html=True) elif source == "manual" and "description" in info and info["description"]: st.markdown(f'{info["description"]}', unsafe_allow_html=True) # Add generation button button_label = "Generate" if not has_essence else "Regenerate ✓" if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford): selected_tag = tag elif selected_sort == "Discovery Time": # Filter to just discovered tags (manual tags don't have discovery time) discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info] # Sort all tags by discovery time (newest first) sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True) # Group by date date_groups = {} for info in sorted_tags: time_str = info["discovery_time"] # Extract just the date part if timestamp has date and time date = time_str.split()[0] if " " in time_str else time_str if date not in date_groups: date_groups[date] = [] date_groups[date].append(info) # Display tags grouped by discovery date for date, tags in date_groups.items(): date_display = date if date else "Unknown date" st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)") with st.expander("Show/Hide", expanded=date == list(date_groups.keys())[0]): # Expand most recent by default # Create grid layout for tags cols = st.columns(3) for i, info in enumerate(tags): with cols[i % 3]: tag = info["tag"] rarity = info["rarity"] # Get rarity color rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Check if this tag has an essence already has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences # Get cost for this tag cost = get_essence_cost(rarity) can_afford = st.session_state.enkephalin >= cost # Format tag display with special styling if rarity == "Impuritas Civitas": tag_display = f'{tag}' elif rarity == "Star of the City": tag_display = f'{tag}' elif rarity == "Urban Nightmare": tag_display = f'{tag}' elif rarity == "Urban Plague": tag_display = f'{tag}' else: tag_display = f'{tag}' # Show tag with rarity badge and cost st.markdown( f'{tag_display} {rarity.capitalize()} ({cost} {ENKEPHALIN_ICON})', unsafe_allow_html=True ) # Show discovery details if "library_floor" in info and info["library_floor"]: st.markdown(f'Found in: {info["library_floor"]}', unsafe_allow_html=True) # Add generation button button_label = "Generate" if not has_essence else "Regenerate ✓" if st.button(button_label, key=f"gen_{tag}_disc", disabled=not model_available or not can_afford): selected_tag = tag # Show manual tags separately if we have any manual_tags = [info for info in all_tags if info["source"] == "manual"] if manual_tags: st.markdown("### Manual Tags") with st.expander("Show/Hide"): # Create grid layout for tags cols = st.columns(3) for i, info in enumerate(manual_tags): with cols[i % 3]: tag = info["tag"] rarity = info["rarity"] # Get rarity color rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Check if this tag has an essence already has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences # Get cost for this tag cost = get_essence_cost(rarity) can_afford = st.session_state.enkephalin >= cost # Show tag with rarity badge and cost st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})") # Show description if available if "description" in info and info["description"]: st.markdown(f'{info["description"]}', unsafe_allow_html=True) # Add generation button button_label = "Generate" if not has_essence else "Regenerate ✓" if st.button(button_label, key=f"gen_{tag}_manual", disabled=not model_available or not can_afford): selected_tag = tag return selected_tag def generate_essence_with_emphasis(model, tag_idx, tag_name=None, image_size=512, iterations=256, scales=3, progress_callback=None, layer_emphasis="mid", color_boost=1.5, tv_weight=5e-4): """ Generate an essence visualization with specific layer emphasis and enhancements. Args: model: Neural network model to visualize tag_idx: Index of the tag to visualize tag_name: Optional name of the tag (for logging) image_size: Size of output image (default: 512) iterations: Number of iterations per scale (default: 256) scales: Number of scales to use (default: 5) progress_callback: Optional callback for progress updates layer_emphasis: Type of layers to use ("auto", "balanced", "high", "mid", "low") color_boost: Factor for boosting color saturation (default: 1.5) tv_weight: Total variation weight (default: 5e-4) Returns: PIL Image of the generated essence and the activation score """ # Create a tag-to-name mapping with the provided name tag_to_name = {tag_idx: tag_name} if tag_name else None # Determine layers to use if not auto layers_to_hook = None layer_weights = None if layer_emphasis != "auto": # Get layers based on the emphasis type layers_to_hook = get_suggested_layers(model, layer_emphasis) # Set layer weights based on position layer_weights = {} for i, layer in enumerate(layers_to_hook): # Base weight from position weight = 0.5 + 0.5 * (i / max(1, len(layers_to_hook) - 1)) # Boost weight for classifier layers if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits"]): weight *= 1.5 layer_weights[layer] = weight print(f"Using {len(layers_to_hook)} {layer_emphasis}-level layers") else: print("Using auto layer detection") # Create instance of the improved generator generator = EssenceGenerator( model=model, tag_to_name=tag_to_name, iterations=iterations, scales=scales, learning_rate=0.05, # Lower for better convergence decay_power=1.0, # Stronger decay power for cleaner images tv_weight=tv_weight, # Customizable TV weight layers_to_hook=layers_to_hook, layer_weights=layer_weights, color_boost=color_boost # Customizable color boost ) # Generate the essence print(f"Generating essence for tag {tag_name or tag_idx} with {layer_emphasis} emphasis...") image, score = generator.generate_essence( tag_idx=tag_idx, image_size=image_size, return_score=True, progress_callback=progress_callback ) print(f"Essence generation complete. Score: {score:.4f}") return image, score def try_different_layer_emphasis(model, tag_idx, tag_name=None, image_size=512, iterations=256, scales=4, progress_callback=None): """ Generate multiple essences with different layer emphasis types and return them all. Args: model: Neural network model to visualize tag_idx: Index of the tag to visualize tag_name: Optional name of the tag (for logging) image_size: Size of output image (default: 512) iterations: Number of iterations per scale (default: 256) scales: Number of scales to use (default: 4) progress_callback: Optional callback for progress updates Returns: Dictionary of PIL Images and scores for each layer emphasis type """ emphasis_types = [ {"name": "low", "color_boost": 1.3, "tv_weight": 2e-4}, # Low-level features (textures, colors) {"name": "mid", "color_boost": 1.5, "tv_weight": 5e-4}, # Mid-level features (parts, components) {"name": "high", "color_boost": 1.7, "tv_weight": 8e-4}, # High-level features (characters, objects) ] results = {} for emphasis in emphasis_types: print(f"\n=== Trying {emphasis['name']} layer emphasis ===") image, score = generate_essence_with_emphasis( model=model, tag_idx=tag_idx, tag_name=tag_name, image_size=image_size, iterations=iterations, scales=scales, progress_callback=progress_callback, layer_emphasis=emphasis["name"], color_boost=emphasis["color_boost"], tv_weight=emphasis["tv_weight"] ) results[emphasis["name"]] = { "image": image, "score": score } print(f"=== Completed {emphasis['name']} layer emphasis with score {score:.4f} ===") return results