|
|
|
"""
|
|
Enhanced Essence Generator for Tag Collector Game
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision.transforms.functional import to_pil_image
|
|
from PIL import Image
|
|
import numpy as np
|
|
import os
|
|
import re
|
|
import math
|
|
import json
|
|
import streamlit as st
|
|
from tqdm import tqdm
|
|
from scipy.ndimage import gaussian_filter
|
|
from functools import wraps
|
|
import time
|
|
import tag_storage
|
|
|
|
from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON
|
|
from tag_categories import TAG_CATEGORIES
|
|
|
|
|
|
ESSENCE_QUALITY_LEVELS = {
|
|
"ZAYIN": {"threshold": 0.0, "color": "#1CFC00", "description": "Basic representation with minimal details."},
|
|
"TETH": {"threshold": 3.0, "color": "#389DDF", "description": "Clear representation with recognizable features."},
|
|
"HE": {"threshold": 5.0, "color": "#FEF900", "description": "Refined representation with distinctive elements."},
|
|
"WAW": {"threshold": 10.0, "color": "#7930F1", "description": "Advanced representation with precise details."},
|
|
"ALEPH": {"threshold": 12.0, "color": "#FF0000", "description": "Perfect representation with extraordinary precision."}
|
|
}
|
|
|
|
|
|
ESSENCE_COSTS = {
|
|
"Special": 0,
|
|
"Canard": 100,
|
|
"Urban Myth": 125,
|
|
"Urban Legend": 150,
|
|
"Urban Plague": 200,
|
|
"Urban Nightmare": 250,
|
|
"Star of the City": 300,
|
|
"Impuritas Civitas": 400
|
|
}
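
# Illustrative sketch (hypothetical helper, not used by the game code): how the
# cost table above is meant to be read. The real UI compares the player's
# st.session_state.enkephalin balance against these costs directly.
def _example_can_afford(rarity, enkephalin_balance):
    """Return True if the balance covers the essence cost for this rarity."""
    return enkephalin_balance >= ESSENCE_COSTS.get(rarity, 100)
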
|
|
|
|
|
|
DEFAULT_ESSENCE_SETTINGS = {
|
|
"scales": 1,
|
|
"iterations": 256,
|
|
"image_size": 512,
|
|
"lr": 0.1,
|
|
"layer_emphasis": "auto"
|
|
}
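
# Illustrative sketch (hypothetical values): derive a per-run settings dict from
# the defaults above without mutating DEFAULT_ESSENCE_SETTINGS itself.
def _example_override_settings(iterations=512, scales=2):
    """Copy the defaults and override a couple of fields for one generation."""
    settings = DEFAULT_ESSENCE_SETTINGS.copy()
    settings.update({"iterations": iterations, "scales": scales})
    return settings
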
|
|
|
|
def initialize_essence_settings():
|
|
"""Initialize essence generator settings if not already present"""
|
|
if 'essence_custom_settings' not in st.session_state:
|
|
|
|
loaded_state = tag_storage.load_essence_state()
|
|
|
|
if loaded_state and 'essence_custom_settings' in loaded_state:
|
|
st.session_state.essence_custom_settings = loaded_state['essence_custom_settings']
|
|
else:
|
|
st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
|
|
|
|
|
|
def initialize_manual_tags():
|
|
"""Initialize manual tags if not already present"""
|
|
if 'manual_tags' not in st.session_state:
|
|
|
|
loaded_state = tag_storage.load_essence_state()
|
|
|
|
if loaded_state and 'manual_tags' in loaded_state:
|
|
st.session_state.manual_tags = loaded_state['manual_tags']
|
|
else:
|
|
st.session_state.manual_tags = {
|
|
"hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
|
|
}
|
|
|
|
|
|
def timeout(seconds, fallback_value=None):
|
|
"""
|
|
Simple timeout utility for functions.
|
|
Warns if a function takes longer than expected but doesn't interrupt it.
|
|
|
|
Args:
|
|
seconds: Expected maximum seconds the function should take
|
|
fallback_value: Not used, just for API compatibility
|
|
"""
|
|
def decorator(func):
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
start_time = time.time()
|
|
result = func(*args, **kwargs)
|
|
elapsed = time.time() - start_time
|
|
|
|
if elapsed > seconds:
|
|
print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
|
|
|
|
return result
|
|
return wrapper
|
|
return decorator
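
# Illustrative usage sketch for the timeout decorator above: it only logs a
# warning when the wrapped call overruns its budget; it never interrupts it.
# The sleep below is a hypothetical stand-in for real work.
@timeout(seconds=2.0)
def _example_slow_step():
    time.sleep(0.1)  # well under the 2-second budget, so no warning is printed
    return "done"
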
|
|
|
|
|
|
|
|
class LayerHook:
|
|
"""Helper class to store the outputs of a layer via forward hook."""
|
|
def __init__(self, layer):
|
|
self.layer = layer
|
|
self.features = None
|
|
self.hook = layer.register_forward_hook(self.hook_fn)
|
|
|
|
def hook_fn(self, module, input, output):
|
|
self.features = output
|
|
|
|
def close(self):
|
|
self.hook.remove()
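
# Illustrative sketch of LayerHook usage (hypothetical model and input batch):
# hook a single layer, run a forward pass, read the captured features, then
# remove the hook so it does not linger on the module.
def _example_layer_hook(model, images):
    # Assumes the model contains at least one Conv2d layer.
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    hook = LayerHook(first_conv)
    with torch.no_grad():
        model(images)
    features = hook.features  # output tensor captured by the forward hook
    hook.close()
    return features
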
|
|
|
|
class FullModelHook:
|
|
"""Hook all layers in a model and track their responses to inputs."""
|
|
|
|
def __init__(self, model):
|
|
self.model = model
|
|
self.hooks = {}
|
|
self.activations = {}
|
|
self.layer_scores = {}
|
|
|
|
|
|
self._register_hooks(model)
|
|
print(f"FullModelHook initialized with {len(self.hooks)} hooks")
|
|
|
|
def _register_hooks(self, module, prefix=''):
|
|
"""Recursively register hooks on all suitable layers."""
|
|
for name, child in module.named_children():
|
|
layer_name = f"{prefix}.{name}" if prefix else name
|
|
|
|
|
|
|
|
if isinstance(child, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.BatchNorm2d, torch.nn.LayerNorm)):
|
|
self.hooks[layer_name] = child.register_forward_hook(
|
|
lambda m, inp, out, layer=layer_name: self._hook_fn(layer, out)
|
|
)
|
|
|
|
|
|
self._register_hooks(child, layer_name)
|
|
|
|
def _hook_fn(self, layer_name, output):
|
|
"""Store activations for each layer."""
|
|
|
|
if len(output.shape) == 4:
|
|
|
|
self.activations[layer_name] = output.mean(dim=[2, 3]).detach()
|
|
else:
|
|
|
|
self.activations[layer_name] = output.detach()
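
# Illustrative sketch of FullModelHook usage (hypothetical model and batch).
# The class registers raw forward hooks but exposes no close() helper, so this
# sketch removes them manually via the stored handles to avoid leaking hooks.
def _example_full_model_hook(model, images):
    tracker = FullModelHook(model)
    with torch.no_grad():
        model(images)
    activations = dict(tracker.activations)  # per-layer (pooled) activations
    for handle in tracker.hooks.values():
        handle.remove()
    return activations
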
|
|
|
|
class EssenceGenerator:
|
|
"""
|
|
Enhanced Essence Generator optimized for anime characters.
|
|
Includes improvements for more vibrant colors and recognizable features.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model,
|
|
tag_to_name=None,
|
|
iterations=256,
|
|
scales=3,
|
|
learning_rate=0.03,
|
|
decay_power=1.5,
|
|
tv_weight=5e-4,
|
|
layers_to_hook=None,
|
|
layer_weights=None,
|
|
color_boost=1.5
|
|
):
|
|
"""Initialize the Enhanced Essence Generator"""
|
|
self.model = model
|
|
self.tag_to_name = tag_to_name
|
|
self.iterations = iterations
|
|
self.scales = scales
|
|
self.lr = learning_rate
|
|
self.decay_power = decay_power
|
|
self.tv_weight = tv_weight
|
|
self.layers_to_hook = layers_to_hook
|
|
self.layer_weights = layer_weights
|
|
self.color_boost = color_boost
|
|
|
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
self.model.eval().to(self.device)
|
|
|
|
|
|
self.hooks = {}
|
|
|
|
|
|
|
|
self.color_correlation_matrix = torch.tensor([
|
|
[1.0000, 0.9522, 0.9156],
|
|
[0.9522, 1.0000, 0.9708],
|
|
[0.9156, 0.9708, 1.0000]], device=self.device)
|
|
|
|
|
|
if self.layers_to_hook:
|
|
self.setup_hooks(self.layers_to_hook)
|
|
|
|
def setup_hooks(self, layers_to_hook):
|
|
"""Setup hooks on the specified layers."""
|
|
|
|
self.close_hooks()
|
|
|
|
|
|
for layer_name in layers_to_hook:
|
|
try:
|
|
|
|
parts = layer_name.split('.')
|
|
layer = self.model
|
|
for part in parts:
|
|
layer = getattr(layer, part)
|
|
|
|
self.hooks[layer_name] = LayerHook(layer)
|
|
print(f"Setup hook for layer: {layer_name}")
|
|
except Exception as e:
|
|
print(f"Failed to setup hook for {layer_name}: {e}")
|
|
|
|
def setup_auto_hooks(self, tag_idx):
|
|
"""
|
|
Automatically detect the most responsive layers for a specific tag.
|
|
This simplified version selects a few key layers based on model architecture.
|
|
"""
|
|
|
|
self.close_hooks()
|
|
|
|
|
|
if self.layers_to_hook and self.layer_weights:
|
|
for layer_name in self.layers_to_hook:
|
|
try:
|
|
|
|
parts = layer_name.split('.')
|
|
layer = self.model
|
|
for part in parts:
|
|
layer = getattr(layer, part)
|
|
|
|
self.hooks[layer_name] = LayerHook(layer)
|
|
print(f"Setup hook for layer: {layer_name}")
|
|
except Exception as e:
|
|
print(f"Failed to setup hook for {layer_name}: {e}")
|
|
|
|
return self.layer_weights
|
|
|
|
|
|
|
|
all_layers = []
|
|
for name, module in self.model.named_modules():
|
|
if not name:
|
|
continue
|
|
|
|
|
|
if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
|
|
all_layers.append((name, module))
|
|
|
|
|
|
selected_layers = []
|
|
layer_weights = {}
|
|
|
|
if len(all_layers) > 30:
|
|
|
|
|
|
classifier_layers = [(name, module) for name, module in all_layers
|
|
if any(x in name.lower() for x in ["classifier", "fc", "linear", "output", "logits"])]
|
|
if classifier_layers:
|
|
selected_layers.append(classifier_layers[-1])
|
|
layer_weights[classifier_layers[-1][0]] = 1.0
|
|
|
|
|
|
conv_layers = [(name, module) for name, module in all_layers if isinstance(module, nn.Conv2d)]
|
|
if conv_layers:
|
|
|
|
half_idx = len(conv_layers) // 2
|
|
selected_idx = [half_idx, 3*len(conv_layers)//4, -1]
|
|
for idx in selected_idx:
|
|
if idx < len(conv_layers) and conv_layers[idx] not in selected_layers:
|
|
selected_layers.append(conv_layers[idx])
|
|
|
|
pos = selected_idx.index(idx)
|
|
layer_weights[conv_layers[idx][0]] = 0.5 + 0.5 * (pos / max(1, len(selected_idx) - 1))
|
|
else:
|
|
|
|
|
|
step = max(1, len(all_layers) // 5)
|
|
indices = list(range(0, len(all_layers), step))
|
|
if len(all_layers) - 1 not in indices:
|
|
indices.append(len(all_layers) - 1)
|
|
|
|
for idx in indices:
|
|
selected_layers.append(all_layers[idx])
|
|
|
|
layer_weights[all_layers[idx][0]] = 0.5 + 0.5 * (idx / max(1, len(all_layers) - 1))
|
|
|
|
|
|
print(f"Setting up {len(selected_layers)} auto-detected layers:")
|
|
for name, module in selected_layers:
|
|
self.hooks[name] = LayerHook(module)
|
|
print(f" - {name} (weight: {layer_weights.get(name, 0.5):.2f})")
|
|
|
|
return layer_weights
|
|
|
|
def close_hooks(self):
|
|
"""Clean up hooks to avoid memory leaks."""
|
|
for hook in self.hooks.values():
|
|
hook.close()
|
|
self.hooks.clear()
|
|
|
|
def total_variation_loss(self, img):
|
|
"""
|
|
        Total variation loss that encourages smoother images while preserving edges.
        The square-root penalty grows sublinearly with gradient magnitude, so strong
        edges are penalized less than under standard TV and survive the optimization.
|
|
"""
|
|
diff_y = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :])
|
|
diff_x = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])
|
|
|
|
|
|
|
|
tv = torch.mean(torch.sqrt(diff_y + 1e-8)) + torch.mean(torch.sqrt(diff_x + 1e-8))
|
|
return tv
|
|
|
|
def create_fft_spectrum_initializer(self, size, batch_size=1):
|
|
"""Enhanced frequency domain initialization for better essence generation with more color and coverage"""
|
|
fft_size = size // 2 + 1
|
|
|
|
|
|
|
|
spectrum_scale = torch.zeros(batch_size, 3, size, fft_size, 2, device=self.device)
|
|
|
|
|
|
for h in range(size):
|
|
for w in range(fft_size):
|
|
|
|
dist = np.sqrt((h/size)**2 + (w/fft_size)**2) + 1e-5
|
|
|
|
weight = 1.0 / dist
|
|
|
|
spectrum_scale[:, :, h, w, 0] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15
|
|
spectrum_scale[:, :, h, w, 1] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15
|
|
|
|
|
|
|
|
spectrum_scale[:, 0, 0, 0, 0] = 0.5
|
|
spectrum_scale[:, 1, 0, 0, 0] = 0.4
|
|
spectrum_scale[:, 2, 0, 0, 0] = 0.6
|
|
spectrum_scale[:, :, 0, 0, 1] = 0
|
|
|
|
spectrum_scale.requires_grad_(True)
|
|
|
|
|
|
spectrum_shift = torch.randn(batch_size, 3, size, fft_size, 2, device=self.device) * 0.05
|
|
spectrum_shift.requires_grad_(True)
|
|
|
|
return spectrum_scale, spectrum_shift
|
|
|
|
def create_spectrum_weights(self, size, decay_power=1.0):
|
|
"""Create weights for the spectrum that emphasize lower frequencies but preserve more details."""
|
|
freqs_x = torch.fft.rfftfreq(size).view(1, -1).to(self.device)
|
|
freqs_y = torch.fft.fftfreq(size).view(-1, 1).to(self.device)
|
|
|
|
dist_from_center = torch.sqrt(freqs_x**2 + freqs_y**2)
|
|
|
|
|
|
|
|
weights = 1.0 / (dist_from_center + 1e-8) ** decay_power
|
|
|
|
|
|
mid_freq_mask = (dist_from_center > 0.05) & (dist_from_center < 0.3)
|
|
weights = weights * (1.0 + 1.0 * mid_freq_mask.float())
|
|
|
|
|
|
high_freq_mask = (dist_from_center >= 0.3) & (dist_from_center < 0.7)
|
|
weights = weights * (1.0 + 0.3 * high_freq_mask.float())
|
|
|
|
weights = weights / weights.max()
|
|
weights[0, 0] = 0.8
|
|
|
|
return weights
|
|
|
|
def fft_to_rgb(self, spectrum_scale, spectrum_shift, size, spectrum_weight=None):
|
|
"""Convert FFT spectrum parameters to an RGB image."""
|
|
batch_size = spectrum_scale.shape[0]
|
|
|
|
if spectrum_weight is not None:
|
|
spectrum_scale = spectrum_scale * spectrum_weight.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
|
|
|
|
image = torch.zeros(batch_size, 3, size, size, device=self.device)
|
|
|
|
spectrum_complex = torch.complex(
|
|
spectrum_scale[..., 0],
|
|
spectrum_scale[..., 1]
|
|
)
|
|
|
|
        # Unit-magnitude phase rotation: cos and sin of the same learned angle.
        phase_shift = torch.complex(
            torch.cos(spectrum_shift[..., 0]),
            torch.sin(spectrum_shift[..., 0])
        )
|
|
spectrum_complex = spectrum_complex * phase_shift
|
|
|
|
for b in range(batch_size):
|
|
for c in range(3):
|
|
channel_spectrum = spectrum_complex[b, c]
|
|
channel_image = torch.fft.irfft2(channel_spectrum, s=(size, size))
|
|
|
|
channel_min = channel_image.min()
|
|
channel_max = channel_image.max()
|
|
if channel_max > channel_min:
|
|
channel_image = (channel_image - channel_min) / (channel_max - channel_min)
|
|
else:
|
|
channel_image = torch.zeros_like(channel_image)
|
|
|
|
image[b, c] = channel_image
|
|
|
|
return image
|
|
|
|
def apply_color_correlation(self, image):
|
|
"""
|
|
Apply color correlation to produce more vibrant, colorful images.
|
|
"""
|
|
batch_size, _, height, width = image.shape
|
|
|
|
|
|
flat_image = image.view(batch_size, 3, -1)
|
|
correlated = torch.matmul(self.color_correlation_matrix, flat_image)
|
|
correlated_image = correlated.view(batch_size, 3, height, width)
|
|
|
|
|
|
|
|
luminance = 0.3 * image[:, 0:1] + 0.59 * image[:, 1:2] + 0.11 * image[:, 2:3]
|
|
|
|
|
|
|
|
boosted_image = luminance + (correlated_image - luminance) * self.color_boost * 1.5
|
|
|
|
|
|
|
|
boosted_image = 0.5 + torch.tanh((boosted_image - 0.5) * 2) * 0.5
|
|
|
|
|
|
boosted_image = torch.clamp(boosted_image, 0, 1)
|
|
|
|
return boosted_image
|
|
|
|
def apply_transforms(self, img):
|
|
"""Apply random transformations to the image for robustness."""
|
|
batch_size, c, h, w = img.shape
|
|
|
|
|
|
pad = 16
|
|
padded = F.pad(img, (pad, pad, pad, pad), mode='reflect')
|
|
|
|
|
|
jitter = 16
|
|
h_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
|
|
w_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
|
|
|
|
|
|
rows = torch.arange(h, device=self.device).view(1, 1, -1, 1).repeat(batch_size, 1, 1, w)
|
|
cols = torch.arange(w, device=self.device).view(1, 1, 1, -1).repeat(batch_size, 1, h, 1)
|
|
|
|
rows = rows + h_jitter.view(-1, 1, 1, 1) + pad
|
|
cols = cols + w_jitter.view(-1, 1, 1, 1) + pad
|
|
|
|
|
|
grid_h = torch.clamp(rows, 0, padded.shape[2] - 1).long()
|
|
grid_w = torch.clamp(cols, 0, padded.shape[3] - 1).long()
|
|
|
|
|
|
jitter2 = 8
|
|
h_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
|
|
w_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
|
|
|
|
grid_h = torch.clamp(grid_h + h_jitter2.view(-1, 1, 1, 1), 0, padded.shape[2] - 1).long()
|
|
grid_w = torch.clamp(grid_w + w_jitter2.view(-1, 1, 1, 1), 0, padded.shape[3] - 1).long()
|
|
|
|
|
|
transformed = torch.zeros_like(img)
|
|
for b in range(batch_size):
|
|
transformed[b] = padded[b, :, grid_h[b, 0], grid_w[b, 0]]
|
|
|
|
return transformed
|
|
|
|
def add_spatial_prior(self, img, strength=0.05):
|
|
"""
|
|
Add a spatial prior to encourage character-like structures with better composition
|
|
and fuller image coverage.
|
|
"""
|
|
batch_size, c, h, w = img.shape
|
|
|
|
|
|
y_indices = torch.arange(h, device=self.device).float()
|
|
x_indices = torch.arange(w, device=self.device).float()
|
|
|
|
y = (2.0 * y_indices / h) - 1.0
|
|
x = (2.0 * x_indices / w) - 1.0
|
|
|
|
|
|
y_grid = y.view(-1, 1).repeat(1, w)
|
|
x_grid = x.view(1, -1).repeat(h, 1)
|
|
|
|
|
|
center_dist = torch.sqrt(x_grid.pow(2) + y_grid.pow(2))
|
|
|
|
center_value = torch.exp(-0.8 * center_dist)
|
|
|
|
|
|
edge_dist_x = torch.min(torch.abs(x_grid - 1.0), torch.abs(x_grid + 1.0))
|
|
edge_dist_y = torch.min(torch.abs(y_grid - 1.0), torch.abs(y_grid + 1.0))
|
|
edge_dist = torch.min(edge_dist_x, edge_dist_y)
|
|
edge_value = torch.clamp(edge_dist * 5.0, 0.2, 1.0)
|
|
|
|
|
|
thirds_x = torch.exp(-10 * (x_grid - 1/3).pow(2)) + torch.exp(-10 * (x_grid + 1/3).pow(2))
|
|
thirds_y = torch.exp(-10 * (y_grid - 1/3).pow(2)) + torch.exp(-10 * (y_grid + 1/3).pow(2))
|
|
thirds_value = (thirds_x + thirds_y) / 2
|
|
|
|
|
|
prior = 0.4 * center_value + 0.4 * edge_value + 0.2 * thirds_value
|
|
|
|
|
|
prior = prior / prior.max()
|
|
|
|
|
|
prior = prior.unsqueeze(0).unsqueeze(0)
|
|
prior = prior.repeat(batch_size, c, 1, 1)
|
|
|
|
|
|
result = img * (1.0 - strength*1.5) + prior * strength*1.5
|
|
|
|
return result
|
|
|
|
def get_layer_activations(self, tag_idx, layer_weights):
|
|
"""
|
|
Get activations from all hooked layers for the target tag.
|
|
Returns a weighted sum of activations based on layer weights.
|
|
"""
|
|
activation_sum = 0.0
|
|
|
|
for layer_name, hook in self.hooks.items():
|
|
if hook.features is None:
|
|
continue
|
|
|
|
|
|
weight = layer_weights.get(layer_name, 0.5)
|
|
|
|
|
|
if len(hook.features.shape) <= 2:
|
|
|
|
if hook.features.size(1) > tag_idx:
|
|
activation = hook.features[0, tag_idx].item()
|
|
activation_sum += weight * activation
|
|
else:
|
|
|
|
channel_means = hook.features.mean(dim=[2, 3])
|
|
|
|
|
|
num_channels = min(5, channel_means.size(1))
|
|
_, top_indices = torch.topk(channel_means, num_channels)
|
|
|
|
|
|
for idx in range(min(3, len(top_indices[0]))):
|
|
channel_idx = top_indices[0, idx]
|
|
channel_activation = hook.features[:, channel_idx].mean().item()
|
|
|
|
channel_weight = 1.0 if idx == 0 else (0.5 if idx == 1 else 0.25)
|
|
activation_sum += weight * channel_activation * channel_weight
|
|
|
|
return activation_sum
|
|
|
|
def generate_essence(self, tag_idx, image_size=512, return_score=True, progress_callback=None):
|
|
"""Generate an essence visualization using enhanced techniques."""
|
|
|
|
tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
|
|
print(f"Generating enhanced essence for '{tag_name}'...")
|
|
|
|
|
|
layer_weights = self.layer_weights or self.setup_auto_hooks(tag_idx)
|
|
|
|
|
|
scale_sizes = []
|
|
for s in range(self.scales):
|
|
|
|
scale_size = max(32, image_size // (2 ** (self.scales - s - 1)))
|
|
scale_sizes.append(scale_size)
|
|
|
|
print(f"Processing scales: {scale_sizes}")
|
|
|
|
|
|
spectrum_weights = {}
|
|
for size in scale_sizes:
|
|
spectrum_weights[size] = self.create_spectrum_weights(size, decay_power=self.decay_power)
|
|
|
|
|
|
best_score = -float('inf')
|
|
best_img = None
|
|
|
|
|
|
for scale_idx, size in enumerate(scale_sizes):
|
|
|
|
spectrum_scale, spectrum_shift = self.create_fft_spectrum_initializer(size)
|
|
|
|
|
|
optimizer = torch.optim.Adam([spectrum_scale, spectrum_shift], lr=self.lr)
|
|
|
|
|
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
optimizer,
|
|
T_max=self.iterations,
|
|
eta_min=self.lr * 0.1
|
|
)
|
|
|
|
|
|
current_weights = spectrum_weights[size]
|
|
|
|
|
|
iterations = self.iterations
|
|
|
|
|
|
no_improvement_streak = 0
|
|
plateau_threshold = 128
|
|
scale_best_score = -float('inf')
|
|
scale_best_img = None
|
|
|
|
for i in range(iterations):
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
img = self.fft_to_rgb(spectrum_scale, spectrum_shift, size, current_weights)
|
|
|
|
|
|
|
|
|
|
|
|
if size >= 32:
|
|
|
|
img = self.add_spatial_prior(img, strength=0.25)
|
|
|
|
|
|
if size >= 32:
|
|
img = self.apply_transforms(img)
|
|
|
|
|
|
for hook in self.hooks.values():
|
|
hook.features = None
|
|
|
|
|
|
outputs = self.model(img)
|
|
|
|
|
|
if isinstance(outputs, (list, tuple)):
|
|
predictions = outputs[0]
|
|
else:
|
|
predictions = outputs
|
|
|
|
|
|
tag_activation = predictions[0, tag_idx]
|
|
|
|
|
|
layer_activation = self.get_layer_activations(tag_idx, layer_weights)
|
|
|
|
|
|
|
|
activation_loss = -(tag_activation + 1.5 + layer_activation * 2.0)
|
|
|
|
|
|
|
|
tv_loss = self.total_variation_loss(img) * (self.tv_weight * 0.7)
|
|
|
|
|
|
total_loss = activation_loss + tv_loss
|
|
|
|
|
|
total_loss.backward()
|
|
|
|
|
|
optimizer.step()
|
|
scheduler.step()
|
|
|
|
|
|
current_score = tag_activation.item()
|
|
if current_score > scale_best_score + 1e-4:
|
|
scale_best_score = current_score
|
|
scale_best_img = img.detach().clone()
|
|
no_improvement_streak = 0
|
|
else:
|
|
no_improvement_streak += 1
|
|
|
|
|
|
if no_improvement_streak >= plateau_threshold:
|
|
print(f"Early stopping at iteration {i}/{iterations} due to plateau")
|
|
break
|
|
|
|
|
|
if progress_callback and i % max(1, iterations // 10) == 0:
|
|
progress_callback(
|
|
scale_idx=scale_idx,
|
|
scale_count=len(scale_sizes),
|
|
iter_idx=i,
|
|
iter_count=iterations,
|
|
score=current_score
|
|
)
|
|
|
|
print(f"Scale {scale_idx+1}/{len(scale_sizes)} completed. Score: {scale_best_score:.4f}")
|
|
|
|
|
|
if scale_best_score > best_score:
|
|
best_score = scale_best_score
|
|
|
|
if scale_idx == len(scale_sizes) - 1:
|
|
best_img = scale_best_img
|
|
|
|
else:
|
|
with torch.no_grad():
|
|
best_img = F.interpolate(scale_best_img, size=(image_size, image_size),
|
|
mode='bilinear', align_corners=False)
|
|
|
|
|
|
if best_img is None:
|
|
final_img = torch.zeros((1, 3, image_size, image_size), device=self.device)
|
|
else:
|
|
final_img = best_img
|
|
|
|
|
|
pil_img = to_pil_image(final_img[0].cpu())
|
|
|
|
|
|
self.close_hooks()
|
|
|
|
if return_score:
|
|
return pil_img, best_score
|
|
else:
|
|
return pil_img
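
# Illustrative end-to-end sketch for EssenceGenerator (not executed on import).
# `my_model`, `idx_to_name`, and `tag_idx` are hypothetical stand-ins for a
# loaded tagger model, its index-to-tag-name mapping, and a target tag index;
# in the game these come from st.session_state.
def _example_generate_essence(my_model, idx_to_name, tag_idx):
    generator = EssenceGenerator(
        model=my_model,
        tag_to_name=idx_to_name,
        iterations=128,  # fewer iterations than the default, for a quick preview
        scales=2,
    )
    image, score = generator.generate_essence(tag_idx, image_size=256)
    return image, get_quality_level(score)
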
|
|
|
|
|
|
|
|
def get_model_layers(model):
|
|
"""Utility function to get all available layers in a model."""
|
|
layers = []
|
|
for name, _ in model.named_modules():
|
|
if name:
|
|
layers.append(name)
|
|
return layers
|
|
|
|
def get_key_layers(model, max_layers=15):
|
|
"""
|
|
Get a curated list of the most relevant layers for visualization.
|
|
"""
|
|
all_layers = get_model_layers(model)
|
|
|
|
|
|
if len(all_layers) > 30:
|
|
|
|
block_patterns = {}
|
|
|
|
|
|
for layer in all_layers:
|
|
|
|
parts = layer.split(".")
|
|
if len(parts) >= 2:
|
|
prefix = ".".join(parts[:2])
|
|
if prefix not in block_patterns:
|
|
block_patterns[prefix] = []
|
|
block_patterns[prefix].append(layer)
|
|
|
|
|
|
key_layers = {
|
|
"early": [],
|
|
"middle": [],
|
|
"late": []
|
|
}
|
|
|
|
|
|
for prefix, layers in block_patterns.items():
|
|
if len(layers) > 3:
|
|
|
|
layers.sort(key=lambda x: [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', x)])
|
|
|
|
|
|
early = layers[0]
|
|
middle = layers[len(layers) // 2]
|
|
late = layers[-1]
|
|
|
|
key_layers["early"].append(early)
|
|
key_layers["middle"].append(middle)
|
|
key_layers["late"].append(late)
|
|
|
|
|
|
|
|
flattened = []
|
|
for _, group_layers in key_layers.items():
|
|
flattened.extend(group_layers)
|
|
|
|
if len(flattened) > max_layers:
|
|
|
|
total = len(flattened)
|
|
|
|
late_count = min(len(key_layers["late"]), max_layers // 3)
|
|
|
|
remaining = max_layers - late_count
|
|
middle_count = min(len(key_layers["middle"]), remaining // 2)
|
|
early_count = min(len(key_layers["early"]), remaining - middle_count)
|
|
|
|
|
|
key_layers["early"] = key_layers["early"][:early_count]
|
|
key_layers["middle"] = key_layers["middle"][:middle_count]
|
|
key_layers["late"] = key_layers["late"][:late_count]
|
|
else:
|
|
|
|
n = len(all_layers)
|
|
key_layers = {
|
|
"early": all_layers[:n//3][:3],
|
|
"middle": all_layers[n//3:2*n//3][:4],
|
|
"late": all_layers[2*n//3:][:3]
|
|
}
|
|
|
|
|
|
classifier_layers = [layer for layer in all_layers if any(x in layer.lower()
|
|
for x in ["classifier", "fc", "linear", "output", "logits", "head"])]
|
|
if classifier_layers:
|
|
key_layers["classifier"] = [classifier_layers[-1]]
|
|
|
|
return key_layers
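
# Illustrative usage sketch for get_key_layers with a hypothetical model: pull
# out the late/classifier candidates, which are usually the most tag-specific.
def _example_key_layers(model):
    key_layers = get_key_layers(model, max_layers=15)
    return key_layers.get("late", []) + key_layers.get("classifier", [])
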
|
|
|
|
def get_suggested_layers(model, layer_type="balanced"):
|
|
"""
|
|
Get suggested layers based on the desired feature type.
|
|
"""
|
|
key_layers = get_key_layers(model)
|
|
|
|
|
|
all_key_layers = []
|
|
for layers in key_layers.values():
|
|
all_key_layers.extend(layers)
|
|
|
|
|
|
if layer_type == "low":
|
|
|
|
selected = key_layers.get("early", [])
|
|
|
|
if "middle" in key_layers and key_layers["middle"]:
|
|
selected.append(key_layers["middle"][0])
|
|
|
|
elif layer_type == "mid":
|
|
|
|
selected = key_layers.get("middle", [])
|
|
|
|
if "early" in key_layers and key_layers["early"]:
|
|
selected.append(key_layers["early"][-1])
|
|
|
|
elif layer_type == "high":
|
|
|
|
selected = key_layers.get("late", [])
|
|
selected.extend(key_layers.get("classifier", []))
|
|
|
|
if "middle" in key_layers and key_layers["middle"]:
|
|
selected.append(key_layers["middle"][-1])
|
|
|
|
else:
|
|
|
|
selected = []
|
|
for category in ["early", "middle", "late", "classifier"]:
|
|
if category in key_layers and key_layers[category]:
|
|
|
|
selected.append(key_layers[category][0])
|
|
|
|
if category in ["middle", "late"] and len(key_layers[category]) > 1:
|
|
selected.append(key_layers[category][-1])
|
|
|
|
|
|
if not selected and all_key_layers:
|
|
selected = [all_key_layers[-1]]
|
|
|
|
return selected
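
# Illustrative sketch: feed suggested layer names into EssenceGenerator as
# layers_to_hook, with a flat hypothetical weight of 1.0 per layer.
def _example_generator_with_suggested_layers(model, layer_type="high"):
    layers = get_suggested_layers(model, layer_type=layer_type)
    weights = {name: 1.0 for name in layers}
    return EssenceGenerator(model, layers_to_hook=layers, layer_weights=weights)
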
|
|
|
|
def get_quality_level(score):
|
|
"""
|
|
Determine the quality level of an essence based on its score
|
|
"""
|
|
for level in reversed(list(ESSENCE_QUALITY_LEVELS.keys())):
|
|
if score >= ESSENCE_QUALITY_LEVELS[level]["threshold"]:
|
|
return level
|
|
return "ZAYIN"
|
|
|
|
def get_essence_cost(rarity):
|
|
"""
|
|
Calculate the cost to generate an essence image based on tag rarity
|
|
"""
|
|
return ESSENCE_COSTS.get(rarity, 100)
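
# Illustrative usage sketch for the two helpers above; the scores are
# hypothetical and the expected values follow from ESSENCE_QUALITY_LEVELS and
# ESSENCE_COSTS as defined in this module.
def _example_quality_and_cost():
    assert get_quality_level(0.5) == "ZAYIN"    # below every non-zero threshold
    assert get_quality_level(11.0) == "WAW"     # >= 10.0 but < 12.0
    assert get_quality_level(13.0) == "ALEPH"   # >= 12.0
    return get_essence_cost("Star of the City")  # 300 per ESSENCE_COSTS
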
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
|
|
"""
|
|
Generate an essence image for a specific tag using the improved generator
|
|
|
|
Args:
|
|
tag: The tag name or index
|
|
model: The model to use
|
|
dataset: The dataset containing tag information
|
|
custom_settings: Optional dictionary with custom generation settings
|
|
|
|
Returns:
|
|
PIL Image of the generated essence, score, quality level
|
|
"""
|
|
|
|
print(f"\n=== Starting essence generation for tag '{tag}' ===")
|
|
|
|
|
|
is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
|
|
is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
|
|
|
|
if not is_discovered and not is_manual_tag:
|
|
st.error(f"Tag '{tag}' has not been discovered yet.")
|
|
return None, 0, None
|
|
|
|
|
|
if is_discovered:
|
|
rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard")
|
|
elif is_manual_tag:
|
|
rarity = st.session_state.manual_tags[tag].get("rarity", "Canard")
|
|
else:
|
|
rarity = "Canard"
|
|
|
|
|
|
cost = get_essence_cost(rarity)
|
|
|
|
|
|
if st.session_state.enkephalin < cost:
|
|
st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
|
|
return None, 0, None
|
|
|
|
|
|
settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
|
|
print(f"Using settings: {settings}")
|
|
|
|
|
|
iterations = settings.get("iterations", 256)
|
|
scales = settings.get("scales", 5)
|
|
layer_emphasis = settings.get("layer_emphasis", "auto")
|
|
|
|
|
|
preview_container = st.empty()
|
|
progress_container = st.empty()
|
|
message_container = st.empty()
|
|
|
|
|
|
if layer_emphasis == "compare":
|
|
message_container.info("Generating essences with different layer emphasis types...")
|
|
tabs_container = st.empty()
|
|
|
|
tabs = None
|
|
tab_images = {}
|
|
best_score = -float('inf')
|
|
best_image = None
|
|
best_emphasis = None
|
|
|
|
try:
|
|
|
|
if layer_emphasis != "compare":
|
|
message_container.info(f"Generating essence for '{tag}' with {layer_emphasis} layer emphasis...")
|
|
|
|
|
|
def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
|
|
|
|
progress = ((scale_idx * iter_count) + iter_idx) / (scale_count * iter_count)
|
|
progress_container.progress(progress, f"Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}")
|
|
message_container.info(f"Current score: {score:.4f}")
|
|
|
|
|
|
if iter_idx % 20 == 0:
|
|
print(f"Progress: Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")
|
|
|
|
|
|
tag_idx = None
|
|
|
|
|
|
if isinstance(tag, str):
|
|
print(f"Converting tag name '{tag}' to index...")
|
|
|
|
if hasattr(dataset, 'tag_to_idx') and tag in dataset.tag_to_idx:
|
|
tag_idx = dataset.tag_to_idx[tag]
|
|
print(f"Found tag index from dataset.tag_to_idx: {tag_idx}")
|
|
|
|
|
|
if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
|
|
tag_idx = st.session_state.metadata['tag_to_idx'].get(tag)
|
|
if tag_idx is not None:
|
|
print(f"Found tag index from session_state.metadata['tag_to_idx']: {tag_idx}")
|
|
|
|
|
|
if tag_idx is None and hasattr(st.session_state, 'metadata') and 'idx_to_tag' in st.session_state.metadata:
|
|
idx_to_tag = st.session_state.metadata['idx_to_tag']
|
|
tag_to_idx = {v: int(k) for k, v in idx_to_tag.items()}
|
|
tag_idx = tag_to_idx.get(tag)
|
|
if tag_idx is not None:
|
|
print(f"Found tag index from inverted idx_to_tag: {tag_idx}")
|
|
|
|
|
|
            # Case-insensitive fallback over whichever tag mapping is available
            if tag_idx is None:
                candidate_map = getattr(dataset, 'tag_to_idx', None)
                if not candidate_map and hasattr(st.session_state, 'metadata'):
                    candidate_map = st.session_state.metadata.get('tag_to_idx', {})
                tag_lower = tag.lower()
                for t, idx in (candidate_map or {}).items():
                    if t.lower() == tag_lower:
                        tag_idx = idx
                        print(f"Found tag index using case-insensitive match: {tag_idx}")
                        break
|
|
|
|
|
|
|
|
if tag_idx is None and is_manual_tag:
|
|
|
|
manual_tag_mapping = {
|
|
"hatsune_miku": "hatsune_miku",
|
|
"lamp": "lamp",
|
|
"blue_gloves": "gloves",
|
|
}
|
|
|
|
fallback_tag = manual_tag_mapping.get(tag)
|
|
if fallback_tag:
|
|
|
|
if hasattr(dataset, 'tag_to_idx') and fallback_tag in dataset.tag_to_idx:
|
|
tag_idx = dataset.tag_to_idx[fallback_tag]
|
|
print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
|
|
|
|
|
|
if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
|
|
tag_idx = st.session_state.metadata['tag_to_idx'].get(fallback_tag)
|
|
if tag_idx is not None:
|
|
print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
|
|
|
|
|
|
if tag_idx is None:
|
|
|
|
if "hair" in tag.lower():
|
|
generic_tag = "blue_hair"
|
|
elif "gloves" in tag.lower():
|
|
generic_tag = "gloves"
|
|
elif "miku" in tag.lower():
|
|
generic_tag = "twintails"
|
|
else:
|
|
generic_tag = "1girl"
|
|
|
|
|
|
if hasattr(dataset, 'tag_to_idx') and generic_tag in dataset.tag_to_idx:
|
|
tag_idx = dataset.tag_to_idx[generic_tag]
|
|
print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
|
|
elif hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
|
|
tag_idx = st.session_state.metadata['tag_to_idx'].get(generic_tag)
|
|
if tag_idx is not None:
|
|
print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
|
|
|
|
|
|
if tag_idx is None:
|
|
st.error(f"Tag '{tag}' index not found. Cannot generate essence.")
|
|
print(f"ERROR: Tag '{tag}' index not found")
|
|
return None, 0, None
|
|
else:
|
|
|
|
tag_idx = tag
|
|
print(f"Using provided tag index: {tag_idx}")
|
|
|
|
|
|
if layer_emphasis == "compare":
|
|
|
|
results = try_different_layer_emphasis(
|
|
model=model,
|
|
tag_idx=tag_idx,
|
|
tag_name=tag,
|
|
image_size=512,
|
|
iterations=iterations,
|
|
scales=scales,
|
|
progress_callback=progress_callback
|
|
)
|
|
|
|
|
|
tab_names = []
|
|
tab_contents = []
|
|
|
|
for emphasis_type, result in results.items():
|
|
image = result["image"]
|
|
score = result["score"]
|
|
|
|
|
|
tab_images[emphasis_type] = image
|
|
|
|
|
|
if score > best_score:
|
|
best_score = score
|
|
best_image = image
|
|
best_emphasis = emphasis_type
|
|
|
|
|
|
tab_names.append(f"{emphasis_type.capitalize()} ({score:.2f})")
|
|
tab_contents.append(image)
|
|
|
|
|
|
tabs = tabs_container.tabs(tab_names)
|
|
for i, tab in enumerate(tabs):
|
|
with tab:
|
|
st.image(tab_contents[i], caption=f"Essence with {list(results.keys())[i]} layer emphasis", use_container_width=True)
|
|
|
|
|
|
image = best_image
|
|
score = best_score
|
|
|
|
|
|
st.success(f"Best results achieved with {best_emphasis} layer emphasis (score: {best_score:.2f})")
|
|
|
|
else:
|
|
|
|
color_boost = 1.5
|
|
tv_weight = 5e-4
|
|
|
|
|
|
if layer_emphasis == "low":
|
|
color_boost = 1.3
|
|
tv_weight = 2e-4
|
|
elif layer_emphasis == "high":
|
|
color_boost = 1.7
|
|
tv_weight = 8e-4
|
|
|
|
image, score = generate_essence_with_emphasis(
|
|
model=model,
|
|
tag_idx=tag_idx,
|
|
tag_name=tag,
|
|
image_size=512,
|
|
iterations=iterations,
|
|
scales=scales,
|
|
progress_callback=progress_callback,
|
|
layer_emphasis=layer_emphasis,
|
|
color_boost=color_boost,
|
|
tv_weight=tv_weight
|
|
)
|
|
|
|
|
|
quality_level = get_quality_level(score)
|
|
|
|
|
|
st.session_state.enkephalin -= cost
|
|
st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost
|
|
|
|
|
|
st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1
|
|
|
|
|
|
filepath = save_essence_to_game_folder(image, tag, score, quality_level)
|
|
print(f"Saved essence to: {filepath}")
|
|
|
|
|
|
if layer_emphasis != "compare":
|
|
preview_container.image(image, caption=f"Essence of '{tag}' - Quality: {quality_level}", width=512)
|
|
|
|
|
|
progress_container.empty()
|
|
message_container.empty()
|
|
|
|
|
|
if 'generated_essences' not in st.session_state:
|
|
st.session_state.generated_essences = {}
|
|
|
|
st.session_state.generated_essences[tag] = {
|
|
"path": filepath,
|
|
"score": score,
|
|
"quality": quality_level,
|
|
"rarity": rarity,
|
|
"settings": settings,
|
|
"generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
|
|
}
|
|
|
|
|
|
st.success(f"Successfully generated {quality_level} essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
|
|
print(f"=== Essence generation complete for '{tag}' ===\n")
|
|
|
|
|
|
tag_storage.save_essence_state(session_state=st.session_state)
|
|
|
|
return image, score, quality_level
|
|
|
|
except Exception as e:
|
|
st.error(f"Error generating essence: {str(e)}")
|
|
print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
|
|
import traceback
|
|
err_traceback = traceback.format_exc()
|
|
print(err_traceback)
|
|
st.code(err_traceback)
|
|
return None, 0, None
|
|
|
|
def display_essence_generator():
|
|
"""
|
|
Display the essence generator interface
|
|
"""
|
|
|
|
initialize_essence_settings()
|
|
|
|
st.title("🎨 Tag Essence Generator")
|
|
st.write("Generate visual representations of what the AI model recognizes for specific tags.")
|
|
|
|
|
|
with st.expander("What are Tag Essences & How to Use Them", expanded=True):
|
|
st.markdown("""
|
|
### 💡 Understanding Tag Essences
|
|
|
|
Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!
|
|
|
|
**How to use Tag Essences:**
|
|
1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)
|
|
2. **Save the essence image** to your computer
|
|
3. **Upload the essence image** back into the tagger
|
|
4. The tagger will **almost always detect the original tag**
|
|
5. It will often also **detect related rare tags** from the same category
|
|
|
|
**Strategic Value:**
|
|
- Character essences can help unlock other tags associated with that character
|
|
- Category essences can help discover rare tags within that category
|
|
- High-quality essences (WAW, ALEPH) have the strongest effect
|
|
|
|
**This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!
|
|
""")
|
|
|
|
|
|
|
|
model_available = hasattr(st.session_state, 'model')
|
|
if not model_available:
|
|
st.warning("Model not available. You can browse your tags but cannot generate essences.")
|
|
|
|
|
|
tabs = st.tabs(["Generate Essence", "My Essences"])
|
|
|
|
with tabs[0]:
|
|
|
|
if hasattr(st.session_state, 'selected_tag') and st.session_state.selected_tag:
|
|
tag = st.session_state.selected_tag
|
|
|
|
st.subheader(f"Generating Essence for '{tag}'")
|
|
|
|
|
|
image, score, quality = generate_essence_for_tag(
|
|
tag,
|
|
st.session_state.model,
|
|
st.session_state.model.dataset,
|
|
st.session_state.essence_custom_settings
|
|
)
|
|
|
|
|
|
if image is not None:
|
|
with st.expander("Essence Usage", expanded=True):
|
|
st.markdown("""
|
|
💡 **Tag Essence Usage Tips:**
|
|
1. Look for similar patterns, colors, and elements in real images
|
|
2. The essence reveals what features the AI model recognizes for this tag
|
|
3. Use this as inspiration when creating or finding images to get this tag
|
|
""")
|
|
else:
|
|
st.error("Essence generation failed. Please check the error messages above and try again with different settings.")
|
|
|
|
|
|
st.session_state.selected_tag = None
|
|
else:
|
|
|
|
selected_tag = display_essence_generation_interface(model_available)
|
|
|
|
|
|
if selected_tag:
|
|
st.session_state.selected_tag = selected_tag
|
|
st.rerun()
|
|
|
|
with tabs[1]:
|
|
display_saved_essences()
|
|
|
|
def save_essence_to_game_folder(image, tag, score, quality_level):
|
|
"""
|
|
Save the generated essence image to a persistent game folder
|
|
|
|
Args:
|
|
image: PIL Image of the essence
|
|
tag: The tag name
|
|
score: The generation score
|
|
quality_level: The quality classification (ZAYIN, TETH, etc.)
|
|
|
|
Returns:
|
|
Path to the saved image
|
|
"""
|
|
|
|
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
game_data_dir = os.path.join(base_dir, "game_data")
|
|
essence_folder = os.path.join(game_data_dir, "essences")
|
|
|
|
|
|
os.makedirs(game_data_dir, exist_ok=True)
|
|
os.makedirs(essence_folder, exist_ok=True)
|
|
|
|
|
|
quality_folder = os.path.join(essence_folder, quality_level)
|
|
os.makedirs(quality_folder, exist_ok=True)
|
|
|
|
|
|
safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
|
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
|
filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
|
|
filepath = os.path.join(quality_folder, filename)
|
|
|
|
|
|
image.save(filepath)
|
|
|
|
print(f"Saved essence to: {filepath}")
|
|
return filepath
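
# Illustrative sketch: save a generated essence and report the quality-level
# subfolder it landed in. `pil_image` is a hypothetical PIL.Image returned by
# EssenceGenerator.generate_essence().
def _example_save_essence(pil_image, tag, score):
    quality = get_quality_level(score)
    path = save_essence_to_game_folder(pil_image, tag, score, quality)
    return path, os.path.basename(os.path.dirname(path))  # quality folder name
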
|
|
|
|
def essence_folder_path():
|
|
"""Get the path to the essence folder, creating it if necessary"""
|
|
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
game_data_dir = os.path.join(base_dir, "game_data")
|
|
essence_folder = os.path.join(game_data_dir, "essences")
|
|
|
|
|
|
os.makedirs(game_data_dir, exist_ok=True)
|
|
os.makedirs(essence_folder, exist_ok=True)
|
|
|
|
return essence_folder
|
|
|
|
def display_saved_essences():
|
|
"""Display the user's saved essence images"""
|
|
st.subheader("My Generated Essences")
|
|
|
|
if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
|
|
st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
|
|
return
|
|
|
|
|
|
st.markdown("""
|
|
### How to Use Your Essences
|
|
|
|
1. **Click on any essence image** to open it in full size
|
|
2. **Save the image** to your computer (right-click → Save image)
|
|
3. **Go to the Scan Images tab** and upload the saved essence
|
|
4. The tagger will likely detect the original tag and potentially related rare tags!
|
|
|
|
Higher quality essences (WAW, ALEPH) generally produce the best results.
|
|
""")
|
|
|
|
|
|
essence_dir = essence_folder_path()
|
|
|
|
|
|
for tag, info in st.session_state.generated_essences.items():
|
|
if "path" in info and not os.path.exists(info["path"]):
|
|
|
|
quality = info.get("quality", "ZAYIN")
|
|
quality_dir = os.path.join(essence_dir, quality)
|
|
|
|
if os.path.exists(quality_dir):
|
|
|
|
safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
|
|
matching_files = [f for f in os.listdir(quality_dir) if f.startswith(safe_tag)]
|
|
|
|
if matching_files:
|
|
|
|
matching_files.sort(reverse=True)
|
|
info["path"] = os.path.join(quality_dir, matching_files[0])
|
|
print(f"Reconnected essence for {tag} to {info['path']}")
|
|
|
|
|
|
essences_by_quality = {}
|
|
for tag, info in st.session_state.generated_essences.items():
|
|
quality = info.get("quality", "ZAYIN")
|
|
if quality not in essences_by_quality:
|
|
essences_by_quality[quality] = []
|
|
essences_by_quality[quality].append((tag, info))
|
|
|
|
|
|
try:
|
|
untracked_essences = {}
|
|
|
|
for quality in ESSENCE_QUALITY_LEVELS.keys():
|
|
quality_dir = os.path.join(essence_dir, quality)
|
|
if os.path.exists(quality_dir):
|
|
essence_files = os.listdir(quality_dir)
|
|
|
|
|
|
essence_files = [f for f in essence_files if f.lower().endswith('.png')]
|
|
|
|
if essence_files:
|
|
|
|
for filename in essence_files:
|
|
|
|
parts = filename.split('_')
|
|
if len(parts) >= 2:
|
|
tag = parts[0].replace('_', ' ')
|
|
|
|
|
|
is_tracked = False
|
|
for tracked_tag, tracked_info in st.session_state.generated_essences.items():
|
|
if "path" in tracked_info and os.path.basename(tracked_info["path"]) == filename:
|
|
is_tracked = True
|
|
break
|
|
|
|
if not is_tracked:
|
|
if quality not in untracked_essences:
|
|
untracked_essences[quality] = []
|
|
untracked_essences[quality].append((tag, {
|
|
"path": os.path.join(quality_dir, filename),
|
|
"quality": quality,
|
|
"discovered_on_disk": True
|
|
}))
|
|
except Exception as e:
|
|
print(f"Error checking for untracked essences: {e}")
|
|
|
|
|
|
for quality, essences in untracked_essences.items():
|
|
if quality not in essences_by_quality:
|
|
essences_by_quality[quality] = []
|
|
for tag, info in essences:
|
|
|
|
if not any(tracked_tag == tag for tracked_tag, _ in essences_by_quality[quality]):
|
|
essences_by_quality[quality].append((tag, info))
|
|
|
|
|
|
for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]:
|
|
if quality in essences_by_quality:
|
|
essences = essences_by_quality[quality]
|
|
color = ESSENCE_QUALITY_LEVELS[quality]["color"]
|
|
|
|
with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]):
|
|
|
|
cols = st.columns(3)
|
|
for i, (tag, info) in enumerate(sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)):
|
|
col_idx = i % 3
|
|
with cols[col_idx]:
|
|
try:
|
|
|
|
if "path" in info and os.path.exists(info["path"]):
|
|
image = Image.open(info["path"])
|
|
rarity = info.get("rarity", "Canard")
|
|
score = info.get("score", 0)
|
|
|
|
|
|
rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
|
|
|
|
|
|
st.image(image, caption=tag, use_container_width=True)
|
|
|
|
|
|
if rarity == "Impuritas Civitas":
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> |
|
|
<span style='animation: rainbow-text 4s linear infinite;font-weight:bold;'>{rarity}</span> |
|
|
Score: {score:.2f}
|
|
""", unsafe_allow_html=True)
|
|
elif rarity == "Star of the City":
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> |
|
|
<span style='color:{rarity_color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity}</span> |
|
|
Score: {score:.2f}
|
|
""", unsafe_allow_html=True)
|
|
elif rarity == "Urban Nightmare":
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> |
|
|
<span style='color:{rarity_color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity}</span> |
|
|
Score: {score:.2f}
|
|
""", unsafe_allow_html=True)
|
|
elif rarity == "Urban Plague":
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> |
|
|
<span style='color:{rarity_color};text-shadow:0 0 1px #9C27B0;font-weight:bold;'>{rarity}</span> |
|
|
Score: {score:.2f}
|
|
""", unsafe_allow_html=True)
|
|
else:
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> |
|
|
<span style='color:{rarity_color};font-weight:bold;'>{rarity}</span> |
|
|
Score: {score:.2f}
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
if "discovered_on_disk" in info and info["discovered_on_disk"]:
|
|
st.info("Found on disk (not in session state)")
|
|
|
|
|
|
if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"):
|
|
folder_path = os.path.dirname(info["path"])
|
|
try:
|
|
|
|
if os.name == 'nt':
|
|
os.startfile(folder_path)
|
|
elif os.name == 'posix':
|
|
import subprocess
|
|
if 'darwin' in os.sys.platform:
|
|
subprocess.call(['open', folder_path])
|
|
else:
|
|
subprocess.call(['xdg-open', folder_path])
|
|
st.success(f"Opened folder: {folder_path}")
|
|
except Exception as e:
|
|
st.error(f"Could not open folder: {str(e)}")
|
|
|
|
st.code(folder_path)
|
|
else:
|
|
|
|
st.warning(f"Image file not found: {info.get('path', 'No path available')}")
|
|
|
|
|
|
st.markdown(f"""
|
|
<span style='color:{color};font-weight:bold;'>{quality}</span> | {tag}
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
if "rarity" in info and "score" in info:
|
|
if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"):
|
|
|
|
safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
|
|
score = info.get("score", 0)
|
|
quality_dir = os.path.join(essence_dir, quality)
|
|
|
|
|
|
os.makedirs(quality_dir, exist_ok=True)
|
|
|
|
|
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
|
filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
|
|
info["path"] = os.path.join(quality_dir, filename)
|
|
|
|
st.info(f"Please save your image to this location: {info['path']}")
|
|
st.session_state.generated_essences[tag] = info
|
|
tag_storage.save_essence_state(session_state=st.session_state)
|
|
st.rerun()
|
|
|
|
except Exception as e:
|
|
st.write(f"Error loading {tag}: {str(e)}")
|
|
|
|
|
|
st.divider()
|
|
if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
|
|
|
|
to_remove = []
|
|
for tag, info in st.session_state.generated_essences.items():
|
|
if "path" in info and not os.path.exists(info["path"]):
|
|
to_remove.append(tag)
|
|
|
|
|
|
for tag in to_remove:
|
|
del st.session_state.generated_essences[tag]
|
|
|
|
|
|
tag_storage.save_essence_state(session_state=st.session_state)
|
|
|
|
if to_remove:
|
|
st.success(f"Removed {len(to_remove)} entries with missing files")
|
|
else:
|
|
st.success("No missing files found")
|
|
|
|
st.rerun()
|
|
|
|
def display_essence_generation_interface(model_available):
|
|
"""Display the interface for generating new essences"""
|
|
|
|
initialize_manual_tags()
|
|
|
|
st.subheader("Generate Tag Essence")
|
|
st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")
|
|
|
|
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
|
|
st.write("Generation Settings:")
|
|
|
|
|
|
scales = st.slider("Scales", 1, 5, DEFAULT_ESSENCE_SETTINGS["scales"],
|
|
help="More scales produce more detailed essences")
|
|
|
|
iterations = st.slider("Iterations", 64, 2048, DEFAULT_ESSENCE_SETTINGS["iterations"], 64,
|
|
help="More iterations improve quality")
|
|
|
|
|
|
layer_emphasis = st.selectbox(
|
|
"Feature Targeting",
|
|
options=["auto", "balanced", "high", "mid", "low", "compare", "custom"],
|
|
index=0,
|
|
format_func=lambda x: {
|
|
"auto": "Auto-detect (best for each tag)",
|
|
"balanced": "Balanced (mix of features)",
|
|
"high": "High-level (characters, objects)",
|
|
"mid": "Mid-level (parts, components)",
|
|
"low": "Low-level (textures, patterns)",
|
|
"compare": "Compare different approaches",
|
|
"custom": "Custom layer selection"
|
|
}.get(x, x),
|
|
help="Controls which model features to emphasize in the essence"
|
|
)
|
|
|
|
|
|
custom_layers = []
|
|
if layer_emphasis == "custom" and model_available:
|
|
st.write("Select Custom Layers:")
|
|
|
|
|
|
key_layers = get_key_layers(st.session_state.model, max_layers=15)
|
|
|
|
|
|
for category, layers in key_layers.items():
|
|
if layers:
|
|
category_name = {
|
|
"early": "Early Layers (textures, colors)",
|
|
"middle": "Middle Layers (parts, components)",
|
|
"late": "Late Layers (objects, characters)",
|
|
"classifier": "Classifier (final recognition)"
|
|
}.get(category, category.capitalize())
|
|
|
|
with st.expander(f"{category_name}", expanded=category in ["late", "classifier"]):
|
|
select_all = st.checkbox(f"Select all {category} layers",
|
|
key=f"select_all_{category}")
|
|
|
|
for layer in layers:
|
|
|
|
parts = layer.split(".")
|
|
display_name = f"...{parts[-2]}.{parts[-1]}" if len(parts) > 3 else layer
|
|
|
|
if select_all or st.checkbox(display_name, key=f"layer_{layer}"):
|
|
custom_layers.append(layer)
|
|
|
|
|
|
if custom_layers:
|
|
st.success(f"Selected {len(custom_layers)} layers")
|
|
else:
|
|
st.warning("Please select at least one layer")
|
|
|
|
|
|
st.session_state.essence_custom_settings = {
|
|
"scales": scales,
|
|
"iterations": iterations,
|
|
"image_size": 512,
|
|
"lr": 0.03,
|
|
"layer_emphasis": layer_emphasis,
|
|
"custom_layers": custom_layers
|
|
}
|
|
|
|
with col2:
|
|
|
|
st.write("Quality Levels:")
|
|
for level, info in ESSENCE_QUALITY_LEVELS.items():
|
|
st.markdown(f"""
|
|
<div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({int(info['color'][1:3], 16)},{int(info['color'][3:5], 16)},{int(info['color'][5:7], 16)},0.1);border-left:3px solid {info['color']}">
|
|
<span style="color:{info['color']};font-weight:bold;">{level}</span> ({info['threshold']:.0f} Score+): {info['description']}
|
|
</div>
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
st.write("Feature Targeting Explanation:")
|
|
st.markdown("""
|
|
ℹ️ **Feature targeting affects what the visualization emphasizes:**
|
|
""")
|
|
|
|
|
|
st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}")
|
|
st.divider()
|
|
|
|
|
|
st.markdown("""
|
|
<style>
|
|
@keyframes rainbow-text {
|
|
0% { color: red; }
|
|
14% { color: orange; }
|
|
28% { color: yellow; }
|
|
42% { color: green; }
|
|
57% { color: blue; }
|
|
71% { color: indigo; }
|
|
85% { color: violet; }
|
|
100% { color: red; }
|
|
}
|
|
|
|
.impuritas-text {
|
|
font-weight: bold;
|
|
animation: rainbow-text 4s linear infinite;
|
|
}
|
|
|
|
@keyframes glow-text {
|
|
0% { text-shadow: 0 0 2px gold; }
|
|
50% { text-shadow: 0 0 6px gold; }
|
|
100% { text-shadow: 0 0 2px gold; }
|
|
}
|
|
|
|
.star-text {
|
|
color: #FFEB3B;
|
|
text-shadow: 0 0 3px gold;
|
|
animation: glow-text 2s infinite;
|
|
font-weight: bold;
|
|
}
|
|
|
|
@keyframes pulse-text {
|
|
0% { opacity: 0.8; }
|
|
50% { opacity: 1; }
|
|
100% { opacity: 0.8; }
|
|
}
|
|
|
|
.nightmare-text {
|
|
color: #FF9800;
|
|
text-shadow: 0 0 1px #FF5722;
|
|
animation: pulse-text 3s infinite;
|
|
font-weight: bold;
|
|
}
|
|
|
|
.plague-text {
|
|
color: #9C27B0;
|
|
text-shadow: 0 0 1px #9C27B0;
|
|
font-weight: bold;
|
|
}
|
|
|
|
.category-section {
|
|
margin-top: 20px;
|
|
margin-bottom: 30px;
|
|
padding: 10px;
|
|
border-radius: 5px;
|
|
border-left: 5px solid;
|
|
}
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
all_tags = []
|
|
|
|
|
|
if hasattr(st.session_state, 'discovered_tags'):
|
|
for tag, info in st.session_state.discovered_tags.items():
|
|
tag_info = {
|
|
"tag": tag,
|
|
"rarity": info.get("rarity", "Unknown"),
|
|
"category": info.get("category", "unknown"),
|
|
"source": "discovered",
|
|
"library_floor": info.get("library_floor", ""),
|
|
"discovery_time": info.get("discovery_time", "")
|
|
}
|
|
all_tags.append(tag_info)
|
|
|
|
|
|
if hasattr(st.session_state, 'manual_tags'):
|
|
for tag, info in st.session_state.manual_tags.items():
|
|
tag_info = {
|
|
"tag": tag,
|
|
"rarity": info.get("rarity", "Special"),
|
|
"category": info.get("category", "special"),
|
|
"source": "manual",
|
|
"description": info.get("description", "")
|
|
}
|
|
all_tags.append(tag_info)
|
|
|
|
|
|
rarity_counts = {}
|
|
for info in all_tags:
|
|
rarity = info["rarity"]
|
|
if rarity not in rarity_counts:
|
|
rarity_counts[rarity] = 0
|
|
rarity_counts[rarity] += 1
|
|
|
|
|
|
st.subheader("Available Tags for Essence Generation")
|
|
st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!")
|
|
|
|
|
|
rarity_cols = st.columns(len(rarity_counts))
|
|
for i, (rarity, count) in enumerate(sorted(rarity_counts.items(),
|
|
key=lambda x: list(RARITY_LEVELS.keys()).index(x[0]) if x[0] in RARITY_LEVELS else 999)):
|
|
with rarity_cols[i]:
|
|
|
|
color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888")
|
|
|
|
|
|
style = f"color:{color};font-weight:bold;"
|
|
class_name = ""
|
|
|
|
if rarity == "Impuritas Civitas":
|
|
class_name = "grid-impuritas"
|
|
elif rarity == "Star of the City":
|
|
class_name = "grid-star"
|
|
elif rarity == "Urban Nightmare":
|
|
class_name = "grid-nightmare"
|
|
elif rarity == "Urban Plague":
|
|
class_name = "grid-plague"
|
|
|
|
if class_name:
|
|
st.markdown(
|
|
f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity.capitalize()}</span><br>{count}</div>",
|
|
unsafe_allow_html=True
|
|
)
|
|
else:
|
|
st.markdown(
|
|
f"<div style='text-align:center;'><span style='{style}'>{rarity.capitalize()}</span><br>{count}</div>",
|
|
unsafe_allow_html=True
|
|
)
|
|
|
|
|
|
search_term = st.text_input("Search tags", "", key="essence_search_tags")
|
|
|
|
|
|
sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"]
|
|
selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort")
|
|
|
|
|
|
if search_term:
|
|
all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()]
|
|
|
|
selected_tag = None
|
|
|
|
|
|
if selected_sort == "Category (rarest first)":
|
|
|
|
categories = {}
|
|
for info in all_tags:
|
|
category = info["category"]
|
|
if category not in categories:
|
|
categories[category] = []
|
|
categories[category].append(info)
|
|
|
|
|
|
for category, tags in sorted(categories.items()):
|
|
|
|
rarity_order = list(reversed(RARITY_LEVELS.keys()))
|
|
|
|
|
|
def get_rarity_index(info):
|
|
rarity = info["rarity"]
|
|
if rarity in rarity_order:
|
|
return len(rarity_order) - rarity_order.index(rarity)
|
|
return 0
|
|
|
|
sorted_tags = sorted(tags, key=get_rarity_index, reverse=True)

            has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"]
                                for info in sorted_tags)

            category_display = category.capitalize()
            if category in TAG_CATEGORIES:
                category_info = TAG_CATEGORIES[category]
                icon = category_info.get("icon", "")
                color = category_info.get("color", "#888888")
                category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>"

            header = f"{category_display} ({len(tags)} tags)"
            if has_rare_tags:
                header += " ✨ Contains rare tags!"

            st.markdown(header, unsafe_allow_html=True)
            with st.expander("Show/Hide", expanded=has_rare_tags):
                cols = st.columns(3)
                for i, info in enumerate(sorted_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        source = info["source"]

                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost

                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'

                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )

                        if source == "discovered" and "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>',
                                        unsafe_allow_html=True)
                        elif source == "manual" and "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>',
                                        unsafe_allow_html=True)

                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                            selected_tag = tag
    elif selected_sort == "Rarity":
        rarity_groups = {}
        for info in all_tags:
            rarity = info["rarity"]
            if rarity not in rarity_groups:
                rarity_groups[rarity] = []
            rarity_groups[rarity].append(info)

        # Display rarity groups in reverse of their RARITY_LEVELS order
        ordered_rarities = list(RARITY_LEVELS.keys())
        ordered_rarities.reverse()

        # Keep any rarities not present in RARITY_LEVELS at the end
        for rarity in rarity_groups.keys():
            if rarity not in ordered_rarities:
                ordered_rarities.append(rarity)

        for rarity in ordered_rarities:
            if rarity in rarity_groups:
                tags = rarity_groups[rarity]
                color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

                rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity.capitalize()}</span>"
                if rarity == "Impuritas Civitas":
                    rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity.capitalize()}</span>"
                elif rarity == "Star of the City":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity.capitalize()}</span>"
                elif rarity == "Urban Nightmare":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity.capitalize()}</span>"

                st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True)
                with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]):
                    cols = st.columns(3)
                    for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])):
                        with cols[i % 3]:
                            tag = info["tag"]
                            source = info["source"]

                            has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

                            cost = get_essence_cost(rarity)
                            can_afford = st.session_state.enkephalin >= cost

                            st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")

                            if source == "discovered" and "library_floor" in info and info["library_floor"]:
                                st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>',
                                            unsafe_allow_html=True)
                            elif source == "manual" and "description" in info and info["description"]:
                                st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>',
                                            unsafe_allow_html=True)

                            button_label = "Generate" if not has_essence else "Regenerate ✓"
                            if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                                selected_tag = tag
    elif selected_sort == "Discovery Time":
        discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info]

        # Sort by discovery time, newest first (assumes sortable timestamp strings)
        sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True)

        # Group tags by the date portion of their discovery time
        date_groups = {}
        for info in sorted_tags:
            time_str = info["discovery_time"]
            date = time_str.split()[0] if " " in time_str else time_str
            if date not in date_groups:
                date_groups[date] = []
            date_groups[date].append(info)
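        # Example (discovery_time format assumed): "2025-01-15 13:37" is grouped
        # under the key "2025-01-15"; a value without a space is used as-is.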

        for date, tags in date_groups.items():
            date_display = date if date else "Unknown date"
            st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)")

            with st.expander("Show/Hide", expanded=date == list(date_groups.keys())[0]):
                cols = st.columns(3)
                for i, info in enumerate(tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]

                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost

                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'

                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )

                        if "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>',
                                        unsafe_allow_html=True)

                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_disc", disabled=not model_available or not can_afford):
                            selected_tag = tag
        # Manual tags have no discovery timestamp, so list them in their own section
        manual_tags = [info for info in all_tags if info["source"] == "manual"]
        if manual_tags:
            st.markdown("### Manual Tags")
            with st.expander("Show/Hide"):
                cols = st.columns(3)
                for i, info in enumerate(manual_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]

                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost

                        st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")

                        if "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>',
                                        unsafe_allow_html=True)

                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_manual", disabled=not model_available or not can_afford):
                            selected_tag = tag

    return selected_tag

def generate_essence_with_emphasis(model, tag_idx, tag_name=None, image_size=512,
                                   iterations=256, scales=3, progress_callback=None,
                                   layer_emphasis="mid", color_boost=1.5, tv_weight=5e-4):
    """
    Generate an essence visualization with specific layer emphasis and enhancements.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (for logging)
        image_size: Size of the output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 3)
        progress_callback: Optional callback for progress updates
        layer_emphasis: Which layers to emphasize ("auto", "balanced", "high", "mid", "low")
        color_boost: Factor for boosting color saturation (default: 1.5)
        tv_weight: Total variation weight (default: 5e-4)

    Returns:
        Tuple of (PIL Image of the generated essence, activation score)
    """

    tag_to_name = {tag_idx: tag_name} if tag_name else None

    layers_to_hook = None
    layer_weights = None

    if layer_emphasis != "auto":
        # Pick a set of layers matching the requested emphasis level
        layers_to_hook = get_suggested_layers(model, layer_emphasis)

        # Weight later layers more heavily, ramping linearly from 0.5 to 1.0
        layer_weights = {}
        for i, layer in enumerate(layers_to_hook):
            weight = 0.5 + 0.5 * (i / max(1, len(layers_to_hook) - 1))

            # Classifier-style layers get an additional boost
            if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits"]):
                weight *= 1.5

            layer_weights[layer] = weight
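        # Worked example of the ramp above: with three hooked layers the weights
        # come out as 0.5, 0.75 and 1.0; a layer named e.g. "classifier" is then
        # scaled by 1.5 (1.0 -> 1.5).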

        print(f"Using {len(layers_to_hook)} {layer_emphasis}-level layers")
    else:
        print("Using auto layer detection")

    generator = EssenceGenerator(
        model=model,
        tag_to_name=tag_to_name,
        iterations=iterations,
        scales=scales,
        learning_rate=0.05,
        decay_power=1.0,
        tv_weight=tv_weight,
        layers_to_hook=layers_to_hook,
        layer_weights=layer_weights,
        color_boost=color_boost
    )

    print(f"Generating essence for tag {tag_name or tag_idx} with {layer_emphasis} emphasis...")
    image, score = generator.generate_essence(
        tag_idx=tag_idx,
        image_size=image_size,
        return_score=True,
        progress_callback=progress_callback
    )

    print(f"Essence generation complete. Score: {score:.4f}")
    return image, score
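

# Example usage (a sketch, not executed on import; `model` and `tag_idx` are
# placeholders for a loaded classifier and the output index of the chosen tag):
#
#     image, score = generate_essence_with_emphasis(
#         model=model,
#         tag_idx=tag_idx,
#         tag_name="example_tag",
#         layer_emphasis="mid",
#     )
#     image.save("example_tag_essence.png")
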
def try_different_layer_emphasis(model, tag_idx, tag_name=None, image_size=512,
                                 iterations=256, scales=4, progress_callback=None):
    """
    Generate multiple essences with different layer emphasis types and return them all.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (for logging)
        image_size: Size of the output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 4)
        progress_callback: Optional callback for progress updates

    Returns:
        Dictionary mapping each emphasis type ("low", "mid", "high") to a dict
        containing the generated PIL Image and its activation score
    """
    emphasis_types = [
        {"name": "low", "color_boost": 1.3, "tv_weight": 2e-4},
        {"name": "mid", "color_boost": 1.5, "tv_weight": 5e-4},
        {"name": "high", "color_boost": 1.7, "tv_weight": 8e-4},
    ]
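    # Note on the presets: a larger tv_weight penalizes high-frequency variation
    # more strongly (smoother output), while a larger color_boost increases
    # saturation in the final image.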

    results = {}

    for emphasis in emphasis_types:
        print(f"\n=== Trying {emphasis['name']} layer emphasis ===")

        image, score = generate_essence_with_emphasis(
            model=model,
            tag_idx=tag_idx,
            tag_name=tag_name,
            image_size=image_size,
            iterations=iterations,
            scales=scales,
            progress_callback=progress_callback,
            layer_emphasis=emphasis["name"],
            color_boost=emphasis["color_boost"],
            tv_weight=emphasis["tv_weight"]
        )

        results[emphasis["name"]] = {
            "image": image,
            "score": score
        }

        print(f"=== Completed {emphasis['name']} layer emphasis with score {score:.4f} ===")

    return results
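

# Example usage (a sketch; `model` and `tag_idx` are placeholders): compare all
# three presets and keep the highest-scoring essence.
#
#     results = try_different_layer_emphasis(model, tag_idx, tag_name="example_tag")
#     best = max(results, key=lambda name: results[name]["score"])
#     results[best]["image"].save(f"essence_{best}.png")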