File size: 104,370 Bytes
29b445b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 |
#!/usr/bin/env python3
"""
Enhanced Essence Generator for Tag Collector Game
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import numpy as np
import os
import re
import math
import json
import streamlit as st
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from functools import wraps
import time
import tag_storage # Import for saving game state
from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON
from tag_categories import TAG_CATEGORIES
# Essence quality tiers, listed in ascending order of their score threshold.
# Each tier records a score threshold, a hex display color and a short
# human-readable description.
ESSENCE_QUALITY_LEVELS = {
    "ZAYIN": {
        "threshold": 0.0,
        "color": "#1CFC00",
        "description": "Basic representation with minimal details.",
    },
    "TETH": {
        "threshold": 3.0,
        "color": "#389DDF",
        "description": "Clear representation with recognizable features.",
    },
    "HE": {
        "threshold": 5.0,
        "color": "#FEF900",
        "description": "Refined representation with distinctive elements.",
    },
    "WAW": {
        "threshold": 10.0,
        "color": "#7930F1",
        "description": "Advanced representation with precise details.",
    },
    "ALEPH": {
        "threshold": 12.0,
        "color": "#FF0000",
        "description": "Perfect representation with extraordinary precision.",
    },
}
# Enkephalin price of generating an essence, keyed by the tag's rarity name.
ESSENCE_COSTS = {
    "Special": 0,
    # Common tags
    "Canard": 100,
    # Uncommon tags
    "Urban Myth": 125,
    # Rare tags
    "Urban Legend": 150,
    # Very rare tags
    "Urban Plague": 200,
    # Extremely rare tags
    "Urban Nightmare": 250,
    # Nearly mythical tags
    "Star of the City": 300,
    # Legendary tags
    "Impuritas Civitas": 400,
}
# Settings used for essence generation when nothing has been persisted yet.
DEFAULT_ESSENCE_SETTINGS = {
    # Number of scales for multiscale optimization
    "scales": 1,
    # Iterations per scale
    "iterations": 256,
    # Always use 512x512 resolution
    "image_size": 512,
    # Learning rate
    "lr": 0.1,
    # Default to auto-detection
    "layer_emphasis": "auto",
}
def initialize_essence_settings():
    """Ensure ``st.session_state.essence_custom_settings`` exists.

    Does nothing when the key is already present in the session state.
    Otherwise the value comes from ``tag_storage.load_essence_state()`` if
    that state contains an ``'essence_custom_settings'`` entry, and from a
    copy of ``DEFAULT_ESSENCE_SETTINGS`` if it does not.
    """
    if 'essence_custom_settings' in st.session_state:
        return

    # Persisted state takes precedence over the built-in defaults.
    saved_state = tag_storage.load_essence_state()
    if saved_state and 'essence_custom_settings' in saved_state:
        settings = saved_state['essence_custom_settings']
    else:
        settings = DEFAULT_ESSENCE_SETTINGS.copy()
    st.session_state.essence_custom_settings = settings
# Replace initialize_manual_tags with:
def initialize_manual_tags():
    """Ensure ``st.session_state.manual_tags`` exists.

    Does nothing when the key is already present in the session state.
    Otherwise the value comes from ``tag_storage.load_essence_state()`` if
    that state contains a ``'manual_tags'`` entry; if not, a single built-in
    tag is seeded.
    """
    if 'manual_tags' in st.session_state:
        return

    # Persisted state takes precedence over the built-in seed entry.
    saved_state = tag_storage.load_essence_state()
    if saved_state and 'manual_tags' in saved_state:
        manual_tags = saved_state['manual_tags']
    else:
        manual_tags = {
            "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
        }
    st.session_state.manual_tags = manual_tags
def timeout(seconds, fallback_value=None):
    """Build a decorator that warns when a call runs longer than expected.

    The wrapped function is never interrupted: it always runs to completion
    and its result is returned unchanged. If the call took more than
    ``seconds``, a warning line is printed afterwards. If the wrapped
    function raises, the exception propagates and nothing is printed.

    Args:
        seconds: Expected maximum duration of a call, in seconds.
        fallback_value: Unused; accepted only for API compatibility.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so the measured duration cannot go
            # negative or jump when the system wall clock is adjusted.
            started = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - started
            if elapsed > seconds:
                print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
            return result
        return wrapper
    return decorator
# Core Classes for Essence Generation
class LayerHook:
    """Remember the latest forward-pass output of one layer.

    A forward hook is registered on ``layer`` at construction time; every
    forward pass through the layer overwrites ``features`` with its output.
    Call ``close()`` to detach the hook.
    """

    def __init__(self, layer):
        self.features = None  # stays None until the layer has been run
        self.layer = layer
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: keep a reference to the layer output."""
        self.features = output

    def close(self):
        """Detach the forward hook from the layer."""
        handle = self.hook
        handle.remove()
class FullModelHook:
    """Hook the leaf layers of a model and record their activations.

    Forward hooks are registered on every ``Conv2d``, ``Linear``,
    ``BatchNorm2d`` and ``LayerNorm`` found anywhere in the module tree.
    After a forward pass, ``activations`` maps each dotted layer name to a
    detached tensor: the per-channel spatial mean for 4-D outputs, the raw
    output otherwise.

    Attributes are plain dicts keyed by dotted layer name:
        hooks: the hook handles returned by ``register_forward_hook``.
        activations: the most recent recorded activation per layer.
        layer_scores: initialised empty and not written by this class.

    Call ``close()`` when finished so the hooks do not stay attached to the
    model.
    """

    def __init__(self, model):
        self.model = model
        self.hooks = {}
        self.activations = {}
        self.layer_scores = {}
        # Recursively register hooks for all eligible layers
        self._register_hooks(model)
        print(f"FullModelHook initialized with {len(self.hooks)} hooks")

    def _register_hooks(self, module, prefix=''):
        """Recursively register hooks on all suitable layers below ``module``."""
        for name, child in module.named_children():
            layer_name = f"{prefix}.{name}" if prefix else name
            # Only hook layers that produce activations; containers such as
            # Sequential are traversed but not hooked themselves.
            if isinstance(child, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.BatchNorm2d, torch.nn.LayerNorm)):
                # ``layer=layer_name`` binds the current name at definition
                # time; a plain closure would see only the last loop value.
                self.hooks[layer_name] = child.register_forward_hook(
                    lambda m, inp, out, layer=layer_name: self._hook_fn(layer, out)
                )
            # Recurse into children
            self._register_hooks(child, layer_name)

    def _hook_fn(self, layer_name, output):
        """Store a detached activation summary for ``layer_name``."""
        if output.dim() == 4:  # [batch, channels, height, width]
            # Keep only the mean activation per channel
            self.activations[layer_name] = output.mean(dim=[2, 3]).detach()
        else:
            # For other layers, store the output as is
            self.activations[layer_name] = output.detach()

    def close(self):
        """Remove every registered forward hook from the model.

        Recorded ``activations`` are left in place so they can still be
        read after the hooks are gone.
        """
        for handle in self.hooks.values():
            handle.remove()
        self.hooks.clear()
class EssenceGenerator:
"""
Enhanced Essence Generator optimized for anime characters.
Includes improvements for more vibrant colors and recognizable features.
"""
def __init__(
self,
model,
tag_to_name=None,
iterations=256,
scales=3,
learning_rate=0.03, # Lower learning rate for better convergence
decay_power=1.5, # Stronger emphasis on low frequencies
tv_weight=5e-4, # Stronger total variation for clearer structures
layers_to_hook=None,
layer_weights=None,
color_boost=1.5 # Color boosting factor
):
"""Initialize the Enhanced Essence Generator"""
self.model = model
self.tag_to_name = tag_to_name
self.iterations = iterations
self.scales = scales
self.lr = learning_rate
self.decay_power = decay_power
self.tv_weight = tv_weight
self.layers_to_hook = layers_to_hook
self.layer_weights = layer_weights
self.color_boost = color_boost
# Set device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.eval().to(self.device)
# Initialize hooks
self.hooks = {}
# Enhanced color correlation matrix for anime-style colors
# More saturated colors with stronger correlations
self.color_correlation_matrix = torch.tensor([
[1.0000, 0.9522, 0.9156],
[0.9522, 1.0000, 0.9708],
[0.9156, 0.9708, 1.0000]], device=self.device)
# Setup hooks if specified
if self.layers_to_hook:
self.setup_hooks(self.layers_to_hook)
def setup_hooks(self, layers_to_hook):
"""Setup hooks on the specified layers."""
# Close any existing hooks
self.close_hooks()
# Create new hooks
for layer_name in layers_to_hook:
try:
# Try to get layer by navigating the model hierarchy
parts = layer_name.split('.')
layer = self.model
for part in parts:
layer = getattr(layer, part)
self.hooks[layer_name] = LayerHook(layer)
print(f"Setup hook for layer: {layer_name}")
except Exception as e:
print(f"Failed to setup hook for {layer_name}: {e}")
def setup_auto_hooks(self, tag_idx):
"""
Automatically detect the most responsive layers for a specific tag.
This simplified version selects a few key layers based on model architecture.
"""
# Close any existing hooks
self.close_hooks()
# If we already have layer weights from initialization, use those
if self.layers_to_hook and self.layer_weights:
for layer_name in self.layers_to_hook:
try:
# Try to get layer by navigating the model hierarchy
parts = layer_name.split('.')
layer = self.model
for part in parts:
layer = getattr(layer, part)
self.hooks[layer_name] = LayerHook(layer)
print(f"Setup hook for layer: {layer_name}")
except Exception as e:
print(f"Failed to setup hook for {layer_name}: {e}")
return self.layer_weights
# Otherwise, detect layers automatically
# Get all named modules
all_layers = []
for name, module in self.model.named_modules():
if not name: # Skip empty name (the model itself)
continue
# Only consider certain layer types that typically have meaningful features
if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
all_layers.append((name, module))
# If the model is too large, select strategic layers
selected_layers = []
layer_weights = {}
if len(all_layers) > 30:
# For large models, select a subset of layers
# 1. Try to find classifier/final layer
classifier_layers = [(name, module) for name, module in all_layers
if any(x in name.lower() for x in ["classifier", "fc", "linear", "output", "logits"])]
if classifier_layers:
selected_layers.append(classifier_layers[-1])
layer_weights[classifier_layers[-1][0]] = 1.0 # Highest weight
# 2. Find some mid to late convolutional layers
conv_layers = [(name, module) for name, module in all_layers if isinstance(module, nn.Conv2d)]
if conv_layers:
# Take some layers from the second half
half_idx = len(conv_layers) // 2
selected_idx = [half_idx, 3*len(conv_layers)//4, -1] # middle, 3/4, and last
for idx in selected_idx:
if idx < len(conv_layers) and conv_layers[idx] not in selected_layers:
selected_layers.append(conv_layers[idx])
# Later layers get higher weights
pos = selected_idx.index(idx)
layer_weights[conv_layers[idx][0]] = 0.5 + 0.5 * (pos / max(1, len(selected_idx) - 1))
else:
# For smaller models, use more layers
# Take a sample across the network depth
step = max(1, len(all_layers) // 5)
indices = list(range(0, len(all_layers), step))
if len(all_layers) - 1 not in indices:
indices.append(len(all_layers) - 1) # Always include the last layer
for idx in indices:
selected_layers.append(all_layers[idx])
# Later layers get higher weights
layer_weights[all_layers[idx][0]] = 0.5 + 0.5 * (idx / max(1, len(all_layers) - 1))
# Create hooks for selected layers
print(f"Setting up {len(selected_layers)} auto-detected layers:")
for name, module in selected_layers:
self.hooks[name] = LayerHook(module)
print(f" - {name} (weight: {layer_weights.get(name, 0.5):.2f})")
return layer_weights
def close_hooks(self):
"""Clean up hooks to avoid memory leaks."""
for hook in self.hooks.values():
hook.close()
self.hooks.clear()
def total_variation_loss(self, img):
    """Return an edge-tolerant total-variation penalty for ``img``.

    The absolute differences between vertically and horizontally adjacent
    pixels are taken along dims 2 and 3. Each difference is passed through
    a square root before averaging, which penalises large jumps (edges)
    less than a plain L1 total variation would while still penalising
    noise. The result is the sum of the two directional means.
    """
    eps = 1e-8  # keeps the square root differentiable at zero
    vertical = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs()
    horizontal = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs()
    return (vertical + eps).sqrt().mean() + (horizontal + eps).sqrt().mean()
def create_fft_spectrum_initializer(self, size, batch_size=1):
"""Enhanced frequency domain initialization for better essence generation with more color and coverage"""
fft_size = size // 2 + 1
# Initialize frequency components with a natural image prior
# This biases toward more natural-looking essences
spectrum_scale = torch.zeros(batch_size, 3, size, fft_size, 2, device=self.device)
# Use 1/f spectrum characteristic of natural images (pink noise)
for h in range(size):
for w in range(fft_size):
# Calculate distance from DC component
dist = np.sqrt((h/size)**2 + (w/fft_size)**2) + 1e-5
# Pink noise falls off as 1/f, but with higher amplitude for better initial coverage
weight = 1.0 / dist
# Add random phase but weighted amplitude - increased amplitude for better initial values
spectrum_scale[:, :, h, w, 0] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15 # Increased from 0.05
spectrum_scale[:, :, h, w, 1] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15 # Increased from 0.05
# Initialize DC component (average color) with higher values for better color saturation
# Give distinct colors to each channel for more vibrant initialization
spectrum_scale[:, 0, 0, 0, 0] = 0.5 # Red channel
spectrum_scale[:, 1, 0, 0, 0] = 0.4 # Green channel
spectrum_scale[:, 2, 0, 0, 0] = 0.6 # Blue channel
spectrum_scale[:, :, 0, 0, 1] = 0
spectrum_scale.requires_grad_(True)
# Phase component - increased for more variation
spectrum_shift = torch.randn(batch_size, 3, size, fft_size, 2, device=self.device) * 0.05 # Increased from 0.02
spectrum_shift.requires_grad_(True)
return spectrum_scale, spectrum_shift
def create_spectrum_weights(self, size, decay_power=1.0):
"""Create weights for the spectrum that emphasize lower frequencies but preserve more details."""
freqs_x = torch.fft.rfftfreq(size).view(1, -1).to(self.device)
freqs_y = torch.fft.fftfreq(size).view(-1, 1).to(self.device)
dist_from_center = torch.sqrt(freqs_x**2 + freqs_y**2)
# Modified weight calculation that allows for more mid-frequency details
# and preserves more high frequencies for better detail
weights = 1.0 / (dist_from_center + 1e-8) ** decay_power
# Significantly increase mid-frequency weights for more details and coverage
mid_freq_mask = (dist_from_center > 0.05) & (dist_from_center < 0.3)
weights = weights * (1.0 + 1.0 * mid_freq_mask.float()) # Increased from 0.5
# Add some weight to high frequencies for texture details
high_freq_mask = (dist_from_center >= 0.3) & (dist_from_center < 0.7)
weights = weights * (1.0 + 0.3 * high_freq_mask.float()) # New addition
weights = weights / weights.max()
weights[0, 0] = 0.8 # Higher DC component for better color coherence (increased from 0.7)
return weights
def fft_to_rgb(self, spectrum_scale, spectrum_shift, size, spectrum_weight=None):
"""Convert FFT spectrum parameters to an RGB image."""
batch_size = spectrum_scale.shape[0]
if spectrum_weight is not None:
spectrum_scale = spectrum_scale * spectrum_weight.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
image = torch.zeros(batch_size, 3, size, size, device=self.device)
spectrum_complex = torch.complex(
spectrum_scale[..., 0],
spectrum_scale[..., 1]
)
phase_shift = torch.complex(
torch.cos(spectrum_shift[..., 0]),
torch.sin(spectrum_shift[..., 1])
)
spectrum_complex = spectrum_complex * phase_shift
for b in range(batch_size):
for c in range(3):
channel_spectrum = spectrum_complex[b, c]
channel_image = torch.fft.irfft2(channel_spectrum, s=(size, size))
channel_min = channel_image.min()
channel_max = channel_image.max()
if channel_max > channel_min:
channel_image = (channel_image - channel_min) / (channel_max - channel_min)
else:
channel_image = torch.zeros_like(channel_image)
image[b, c] = channel_image
return image
def apply_color_correlation(self, image):
"""
Apply color correlation to produce more vibrant, colorful images.
"""
batch_size, _, height, width = image.shape
# 1. Apply basic color correlation
flat_image = image.view(batch_size, 3, -1)
correlated = torch.matmul(self.color_correlation_matrix, flat_image)
correlated_image = correlated.view(batch_size, 3, height, width)
# 2. Apply stronger color boost (increase saturation)
# Calculate luminance (0.3R + 0.59G + 0.11B)
luminance = 0.3 * image[:, 0:1] + 0.59 * image[:, 1:2] + 0.11 * image[:, 2:3]
# Boost colors away from luminance (enhances saturation)
# Increase the boost factor for more vibrant colors
boosted_image = luminance + (correlated_image - luminance) * self.color_boost * 1.5 # Apply additional 1.5x boost
# 3. Apply a gentle S-curve for better contrast
# This helps make the colors "pop" more
boosted_image = 0.5 + torch.tanh((boosted_image - 0.5) * 2) * 0.5
# Ensure values are in [0, 1] range
boosted_image = torch.clamp(boosted_image, 0, 1)
return boosted_image
def apply_transforms(self, img):
"""Apply random transformations to the image for robustness."""
batch_size, c, h, w = img.shape
# 1. Padding
pad = 16
padded = F.pad(img, (pad, pad, pad, pad), mode='reflect')
# 2. Random jitter
jitter = 16
h_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
w_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
# Create sampling grid
rows = torch.arange(h, device=self.device).view(1, 1, -1, 1).repeat(batch_size, 1, 1, w)
cols = torch.arange(w, device=self.device).view(1, 1, 1, -1).repeat(batch_size, 1, h, 1)
rows = rows + h_jitter.view(-1, 1, 1, 1) + pad
cols = cols + w_jitter.view(-1, 1, 1, 1) + pad
# Get transformed image (simplified implementation)
grid_h = torch.clamp(rows, 0, padded.shape[2] - 1).long()
grid_w = torch.clamp(cols, 0, padded.shape[3] - 1).long()
# Apply second jitter for more randomness
jitter2 = 8
h_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
w_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
grid_h = torch.clamp(grid_h + h_jitter2.view(-1, 1, 1, 1), 0, padded.shape[2] - 1).long()
grid_w = torch.clamp(grid_w + w_jitter2.view(-1, 1, 1, 1), 0, padded.shape[3] - 1).long()
# Gather values
transformed = torch.zeros_like(img)
for b in range(batch_size):
transformed[b] = padded[b, :, grid_h[b, 0], grid_w[b, 0]]
return transformed
def add_spatial_prior(self, img, strength=0.05):
"""
Add a spatial prior to encourage character-like structures with better composition
and fuller image coverage.
"""
batch_size, c, h, w = img.shape
# Create normalized coordinate grids
y_indices = torch.arange(h, device=self.device).float()
x_indices = torch.arange(w, device=self.device).float()
y = (2.0 * y_indices / h) - 1.0 # Normalize to [-1, 1]
x = (2.0 * x_indices / w) - 1.0 # Normalize to [-1, 1]
# Expand to 2D grid
y_grid = y.view(-1, 1).repeat(1, w)
x_grid = x.view(1, -1).repeat(h, 1)
# Center bias (gentler in the middle to allow more coverage)
center_dist = torch.sqrt(x_grid.pow(2) + y_grid.pow(2))
# Wider center bias for better coverage
center_value = torch.exp(-0.8 * center_dist) # Reduced from -1.5 for wider coverage
# Full-image utilization bias (higher values further from edge)
edge_dist_x = torch.min(torch.abs(x_grid - 1.0), torch.abs(x_grid + 1.0))
edge_dist_y = torch.min(torch.abs(y_grid - 1.0), torch.abs(y_grid + 1.0))
edge_dist = torch.min(edge_dist_x, edge_dist_y)
edge_value = torch.clamp(edge_dist * 5.0, 0.2, 1.0) # Higher values away from edges
# Rule of thirds with wider peaks (subtle enhancement at thirds points)
thirds_x = torch.exp(-10 * (x_grid - 1/3).pow(2)) + torch.exp(-10 * (x_grid + 1/3).pow(2)) # Reduced from -30
thirds_y = torch.exp(-10 * (y_grid - 1/3).pow(2)) + torch.exp(-10 * (y_grid + 1/3).pow(2)) # Reduced from -30
thirds_value = (thirds_x + thirds_y) / 2
# Combine the different priors with rebalanced weights to favor coverage
prior = 0.4 * center_value + 0.4 * edge_value + 0.2 * thirds_value # More weight on edge_value for better coverage
# Normalize the prior
prior = prior / prior.max()
# Expand to match the input dimensions
prior = prior.unsqueeze(0).unsqueeze(0)
prior = prior.repeat(batch_size, c, 1, 1)
# Apply the prior with increased strength
result = img * (1.0 - strength*1.5) + prior * strength*1.5 # Increased strength by 50%
return result
def get_layer_activations(self, tag_idx, layer_weights):
"""
Get activations from all hooked layers for the target tag.
Returns a weighted sum of activations based on layer weights.
"""
activation_sum = 0.0
for layer_name, hook in self.hooks.items():
if hook.features is None:
continue
# Get weight for this layer
weight = layer_weights.get(layer_name, 0.5)
# Handle different layer types
if len(hook.features.shape) <= 2:
# For fully connected layers, focus on the target class logit
if hook.features.size(1) > tag_idx:
activation = hook.features[0, tag_idx].item()
activation_sum += weight * activation
else:
# For convolutional layers, focus on overall activation strength
channel_means = hook.features.mean(dim=[2, 3])
# Make sure we don't request more channels than exist
num_channels = min(5, channel_means.size(1))
_, top_indices = torch.topk(channel_means, num_channels)
# Process each top channel individually
for idx in range(min(3, len(top_indices[0]))): # Use up to 3 top channels
channel_idx = top_indices[0, idx]
channel_activation = hook.features[:, channel_idx].mean().item()
# Weight decreases for less important channels
channel_weight = 1.0 if idx == 0 else (0.5 if idx == 1 else 0.25)
activation_sum += weight * channel_activation * channel_weight
return activation_sum
def generate_essence(self, tag_idx, image_size=512, return_score=True, progress_callback=None):
    """Generate an essence visualization using enhanced techniques.

    An FFT-parameterized image is optimized independently at each of
    ``self.scales`` resolutions (each half the next one, never below 32 px,
    the last one equal to ``image_size``) so that the model's prediction for
    ``tag_idx`` grows. The image kept is the one from the scale that reached
    the highest tag activation; images from smaller scales are bilinearly
    upsampled to ``image_size``. The stored image is the tensor that was fed
    to the model, i.e. after the spatial prior and ``apply_transforms``.

    Args:
        tag_idx: Column of the model's prediction tensor to maximize.
        image_size: Side length in pixels of the final square image.
        return_score: When true return ``(image, score)``, otherwise only
            the image.
        progress_callback: Optional callable invoked about ten times per
            scale with the keyword arguments ``scale_idx``, ``scale_count``,
            ``iter_idx``, ``iter_count`` and ``score``.

    Returns:
        The PIL image, or ``(PIL image, best tag activation)`` when
        ``return_score`` is true. If no scale produced an image, a black
        image is returned and the score stays at ``-inf``.
    """
    # Get tag name for logging (if available)
    tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
    print(f"Generating enhanced essence for '{tag_name}'...")

    # Auto-detect and set up hooks for responsive layers if not already set up
    layer_weights = self.layer_weights or self.setup_auto_hooks(tag_idx)

    # From here on hooks are attached to the model. The finally clause
    # detaches them even when the optimization raises, so a failed run does
    # not leave stale hooks behind.
    try:
        # Determine scale sizes: start small and progressively increase size
        scale_sizes = [
            max(32, image_size // (2 ** (self.scales - s - 1)))
            for s in range(self.scales)
        ]
        print(f"Processing scales: {scale_sizes}")

        # Create frequency spectrum weights
        spectrum_weights = {}
        for size in scale_sizes:
            spectrum_weights[size] = self.create_spectrum_weights(size, decay_power=self.decay_power)

        # Track best result
        best_score = -float('inf')
        best_img = None

        # Process each scale independently
        for scale_idx, size in enumerate(scale_sizes):
            # Initialize parameters for this scale
            spectrum_scale, spectrum_shift = self.create_fft_spectrum_initializer(size)

            # Create optimizer
            optimizer = torch.optim.Adam([spectrum_scale, spectrum_shift], lr=self.lr)

            # Use learning rate scheduler for better convergence
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.iterations,
                eta_min=self.lr * 0.1
            )

            # Current scale's spectrum weights
            current_weights = spectrum_weights[size]

            # Iterations for this scale
            iterations = self.iterations

            # Plateau tracking for early stopping
            no_improvement_streak = 0
            plateau_threshold = 128
            scale_best_score = -float('inf')
            scale_best_img = None

            for i in range(iterations):
                # Clear gradients
                optimizer.zero_grad()

                # Convert FFT parameters to RGB image
                img = self.fft_to_rgb(spectrum_scale, spectrum_shift, size, current_weights)

                # Add spatial prior to encourage character-like patterns with better coverage
                if size >= 32:
                    img = self.add_spatial_prior(img, strength=0.25)

                # Apply transformations for robustness
                if size >= 32:
                    img = self.apply_transforms(img)

                # Reset hooks so stale features from the previous step are never reused
                for hook in self.hooks.values():
                    hook.features = None

                # Forward pass
                outputs = self.model(img)

                # Get target tag activation in final layer
                if isinstance(outputs, (list, tuple)):
                    predictions = outputs[0]
                else:
                    predictions = outputs
                tag_activation = predictions[0, tag_idx]

                # Get activations from earlier layers
                layer_activation = self.get_layer_activations(tag_idx, layer_weights)

                # Combined loss to maximize activations.
                # NOTE(review): the constant 1.5 only offsets the loss value,
                # and layer_activation is assembled from .item() values in
                # get_layer_activations, so neither term contributes a
                # gradient - confirm whether that is intended.
                activation_loss = -(tag_activation + 1.5 + layer_activation * 2.0)

                # Total variation regularization at 70% of the configured weight
                tv_loss = self.total_variation_loss(img) * (self.tv_weight * 0.7)

                # Total loss
                total_loss = activation_loss + tv_loss

                # Backpropagation and parameter update
                total_loss.backward()
                optimizer.step()
                scheduler.step()

                # Track best result for this scale
                current_score = tag_activation.item()
                if current_score > scale_best_score + 1e-4:
                    scale_best_score = current_score
                    scale_best_img = img.detach().clone()
                    no_improvement_streak = 0
                else:
                    no_improvement_streak += 1

                # Patient early stopping
                if no_improvement_streak >= plateau_threshold:
                    print(f"Early stopping at iteration {i}/{iterations} due to plateau")
                    break

                # Report progress
                if progress_callback and i % max(1, iterations // 10) == 0:
                    progress_callback(
                        scale_idx=scale_idx,
                        scale_count=len(scale_sizes),
                        iter_idx=i,
                        iter_count=iterations,
                        score=current_score
                    )

            print(f"Scale {scale_idx+1}/{len(scale_sizes)} completed. Score: {scale_best_score:.4f}")

            # Update overall best if this scale improved the score
            if scale_best_score > best_score:
                best_score = scale_best_score
                if scale_idx == len(scale_sizes) - 1:
                    # The final scale already has the full resolution
                    best_img = scale_best_img
                else:
                    # Otherwise, upscale to the final size for the return value
                    with torch.no_grad():
                        best_img = F.interpolate(scale_best_img, size=(image_size, image_size),
                                                 mode='bilinear', align_corners=False)

        # In case all scales failed, create an empty image
        if best_img is None:
            final_img = torch.zeros((1, 3, image_size, image_size), device=self.device)
        else:
            final_img = best_img

        # Convert to PIL image
        pil_img = to_pil_image(final_img[0].cpu())
    finally:
        # Clean up hooks on success and on failure alike
        self.close_hooks()

    if return_score:
        return pil_img, best_score
    return pil_img
# Utility Functions for Model Analysis and Layer Selection
def get_model_layers(model):
    """Return the names yielded by ``model.named_modules()``, in order.

    Entries with an empty name (the model itself) are left out.
    """
    return [module_name for module_name, _module in model.named_modules() if module_name]
def get_key_layers(model, max_layers=15):
    """
    Pick a small, representative set of layer names for visualization.

    Models with at most 30 named layers are cut into thirds by position and
    the first 3 / 4 / 3 names of each third become the "early" / "middle" /
    "late" groups. For larger models, layers are grouped by the first two
    components of their dotted name; every group with more than three
    members contributes its first, middle and last layer (in natural sort
    order). If that yields more than ``max_layers`` names, the groups are
    truncated, reserving a third of the budget for "late" first.

    When any layer name contains a classifier-like keyword, the last such
    layer is added under the "classifier" key.

    Returns:
        dict mapping group name to a list of layer names.
    """
    layer_names = get_model_layers(model)

    if len(layer_names) <= 30:
        # Simple model: plain positional thirds.
        count = len(layer_names)
        key_layers = {
            "early": layer_names[:count // 3][:3],
            "middle": layer_names[count // 3:2 * count // 3][:4],
            "late": layer_names[2 * count // 3:][:3],
        }
    else:
        # Group layers by their two leading name components
        # (e.g. "backbone.features").
        groups = {}
        for layer_name in layer_names:
            pieces = layer_name.split(".")
            if len(pieces) < 2:
                continue
            groups.setdefault(".".join(pieces[:2]), []).append(layer_name)

        def depth_key(layer_name):
            # Natural ordering: digit runs compare as integers.
            return [int(tok) if tok.isdigit() else tok
                    for tok in re.findall(r'\d+|\D+', layer_name)]

        key_layers = {"early": [], "middle": [], "late": []}
        for members in groups.values():
            if len(members) <= 3:
                # Only significant blocks are sampled.
                continue
            members.sort(key=depth_key)
            key_layers["early"].append(members[0])
            key_layers["middle"].append(members[len(members) // 2])
            key_layers["late"].append(members[-1])

        picked = sum(len(names) for names in key_layers.values())
        if picked > max_layers:
            # Late layers get their share first; the rest is split between
            # middle and early.
            keep_late = min(len(key_layers["late"]), max_layers // 3)
            slots_left = max_layers - keep_late
            keep_middle = min(len(key_layers["middle"]), slots_left // 2)
            keep_early = min(len(key_layers["early"]), slots_left - keep_middle)

            key_layers["early"] = key_layers["early"][:keep_early]
            key_layers["middle"] = key_layers["middle"][:keep_middle]
            key_layers["late"] = key_layers["late"][:keep_late]

    # Try to identify the classifier/final layer by name.
    classifier_hints = ["classifier", "fc", "linear", "output", "logits", "head"]
    candidates = [
        name for name in layer_names
        if any(hint in name.lower() for hint in classifier_hints)
    ]
    if candidates:
        key_layers["classifier"] = [candidates[-1]]

    return key_layers
def get_suggested_layers(model, layer_type="balanced"):
    """
    Choose layer names from ``get_key_layers(model)`` for a feature emphasis.

    ``layer_type`` selects the mix:
      * "low"  - all early layers plus the first middle layer
      * "mid"  - all middle layers plus the last early layer
      * "high" - all late layers, the classifier layer(s), plus the last
        middle layer
      * anything else ("balanced") - the first layer of each group, and
        additionally the last layer of the middle and late groups when
        those hold more than one entry

    If the selection ends up empty, the last key layer overall is returned
    as a single-element list (or an empty list when there are none).
    """
    key_layers = get_key_layers(model)

    # Snapshot of every key layer, taken before any selection is built.
    fallback_pool = [name for group in key_layers.values() for name in group]

    early = key_layers.get("early", [])
    middle = key_layers.get("middle", [])
    late = key_layers.get("late", [])

    if layer_type == "low":
        # Early visual features, with one middle layer for stability.
        selected = list(early)
        if middle:
            selected.append(middle[0])
    elif layer_type == "mid":
        # Mid-level features, with one early layer for context.
        selected = list(middle)
        if early:
            selected.append(early[-1])
    elif layer_type == "high":
        # High-level semantic features, with one middle layer for context.
        selected = list(late)
        selected.extend(key_layers.get("classifier", []))
        if middle:
            selected.append(middle[-1])
    else:
        # Balanced mix across all groups.
        selected = []
        for category in ("early", "middle", "late", "classifier"):
            group = key_layers.get(category)
            if not group:
                continue
            selected.append(group[0])
            if category in ("middle", "late") and len(group) > 1:
                selected.append(group[-1])

    # Ensure we have at least one layer whenever any key layer exists.
    if not selected and fallback_pool:
        selected = [fallback_pool[-1]]

    return selected
def get_quality_level(score):
    """
    Map an essence score to a quality level name.

    The levels in ``ESSENCE_QUALITY_LEVELS`` are visited from the last
    defined entry to the first, and the first one whose "threshold" is not
    above ``score`` is returned. "ZAYIN" is returned when none qualifies.
    """
    level_names = list(ESSENCE_QUALITY_LEVELS.keys())
    return next(
        (
            name
            for name in level_names[::-1]
            if score >= ESSENCE_QUALITY_LEVELS[name]["threshold"]
        ),
        "ZAYIN",  # Default to lowest level
    )
def get_essence_cost(rarity):
    """
    Look up the cost of generating an essence for a tag of the given rarity.

    The price comes from ``ESSENCE_COSTS``; a rarity that is not listed
    there costs 100.
    """
    fallback_cost = 100
    return ESSENCE_COSTS.get(rarity, fallback_cost)
# Game UI and Integration Functions
def initialize_essence_settings():
    """Seed ``st.session_state.essence_custom_settings`` with a copy of the defaults.

    Does nothing when the session already holds custom settings.
    """
    if 'essence_custom_settings' in st.session_state:
        return
    st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
def save_essence_to_game_folder(image, tag, score, quality_level):
    """
    Save the generated essence image to a persistent game folder.

    NOTE(review): this module defines ``save_essence_to_game_folder`` a second
    time further down, and that later definition is the one in effect once
    the module has been imported. This copy uses the same on-disk layout
    (``game_data/essences/<quality_level>/<tag>_<score>_<timestamp>.png``),
    which is also the layout ``display_saved_essences`` scans, so that
    removing either definition does not change where files end up.
    Consider deleting one of the two.

    Args:
        image: PIL Image of the essence
        tag: The tag name
        score: The generation score
        quality_level: The quality classification (ZAYIN, TETH, etc.)

    Returns:
        Path to the saved image
    """
    # game_data lives two directory levels above this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    essence_folder = os.path.join(base_dir, "game_data", "essences")

    # One sub-folder per quality level; makedirs also creates the parents.
    quality_folder = os.path.join(essence_folder, quality_level)
    os.makedirs(quality_folder, exist_ok=True)

    # Path separators and spaces must not end up in the file name.
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(quality_folder, filename)

    # Save the image
    image.save(filepath)
    return filepath
def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """
    Generate an essence image for a specific tag using the improved generator

    Flow: verify the tag is discovered (or a manual tag), price the
    generation by tag rarity, check the player's balance, resolve the tag to
    a model index, run the generator (single emphasis or "compare" mode),
    then deduct the cost, save the image, and record the result in
    ``st.session_state.generated_essences``. The cost is deducted only after
    a successful generation.

    Args:
        tag: The tag name or index
        model: The model to use
        dataset: The dataset containing tag information
        custom_settings: Optional dictionary with custom generation settings

    Returns:
        PIL Image of the generated essence, score, quality level.
        ``(None, 0, None)`` is returned when the tag is not available, the
        balance is too low, the tag index cannot be resolved, or any
        exception occurs during generation.
    """
    print(f"\n=== Starting essence generation for tag '{tag}' ===")
    # Check if tag is discovered or a manual tag
    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None
    # Get tag rarity and calculate cost
    if is_discovered:
        rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard")
    elif is_manual_tag:
        rarity = st.session_state.manual_tags[tag].get("rarity", "Canard")
    else:
        # Unreachable after the guard above; kept as a defensive default.
        rarity = "Canard"
    # Calculate cost based on rarity
    cost = get_essence_cost(rarity)
    # Check if player has enough Enkephalin
    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None
    # Use provided settings or defaults
    settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
    print(f"Using settings: {settings}")
    # Extract settings
    iterations = settings.get("iterations", 256)
    scales = settings.get("scales", 5)
    layer_emphasis = settings.get("layer_emphasis", "auto")
    # UI containers for progress (created in this order: preview, progress bar, status message)
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()
    # If multiple layer emphasis types are requested, show tabs
    if layer_emphasis == "compare":
        message_container.info("Generating essences with different layer emphasis types...")
        tabs_container = st.empty()
        tabs = None
        tab_images = {}
        best_score = -float('inf')
        best_image = None
        best_emphasis = None
    try:
        # Show generation information
        if layer_emphasis != "compare":
            message_container.info(f"Generating essence for '{tag}' with {layer_emphasis} layer emphasis...")
        # Progress callback function for essence generation
        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            # Update progress bar: fraction of all (scale, iteration) steps done so far
            progress = ((scale_idx * iter_count) + iter_idx) / (scale_count * iter_count)
            progress_container.progress(progress, f"Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")
            # Print status to console too
            if iter_idx % 20 == 0:
                print(f"Progress: Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")
        # Convert tag name to index if needed
        tag_idx = None
        # Try to find tag in various places, from most to least authoritative
        if isinstance(tag, str):
            print(f"Converting tag name '{tag}' to index...")
            # Standard lookup methods
            if hasattr(dataset, 'tag_to_idx') and tag in dataset.tag_to_idx:
                tag_idx = dataset.tag_to_idx[tag]
                print(f"Found tag index from dataset.tag_to_idx: {tag_idx}")
            # Session state metadata lookup
            if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                tag_idx = st.session_state.metadata['tag_to_idx'].get(tag)
                if tag_idx is not None:
                    print(f"Found tag index from session_state.metadata['tag_to_idx']: {tag_idx}")
            # Lookup from idx_to_tag
            if tag_idx is None and hasattr(st.session_state, 'metadata') and 'idx_to_tag' in st.session_state.metadata:
                idx_to_tag = st.session_state.metadata['idx_to_tag']
                # Invert the mapping; keys are converted with int(), so they
                # are presumably numeric strings - TODO confirm
                tag_to_idx = {v: int(k) for k, v in idx_to_tag.items()}
                tag_idx = tag_to_idx.get(tag)
                if tag_idx is not None:
                    print(f"Found tag index from inverted idx_to_tag: {tag_idx}")
                # Try case-insensitive (must stay inside this branch: it
                # relies on the tag_to_idx mapping built just above)
                if tag_idx is None:
                    tag_lower = tag.lower()
                    for t, idx in tag_to_idx.items():
                        if t.lower() == tag_lower:
                            tag_idx = idx
                            print(f"Found tag index using case-insensitive match: {tag_idx}")
                            break
            # For manual tags that aren't in the model's tag list,
            # we might need to find a semantically similar tag or use a generic index
            if tag_idx is None and is_manual_tag:
                # For demonstration, we could map manual tags to known similar tags
                manual_tag_mapping = {
                    "hatsune_miku": "hatsune_miku", # Try to find this in the dataset
                    "lamp": "lamp", # Try to find this in the dataset
                    "blue_gloves": "gloves", # Fallback to a more generic tag
                }
                fallback_tag = manual_tag_mapping.get(tag)
                if fallback_tag:
                    # Try to find the fallback tag
                    if hasattr(dataset, 'tag_to_idx') and fallback_tag in dataset.tag_to_idx:
                        tag_idx = dataset.tag_to_idx[fallback_tag]
                        print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
                    # Try session state metadata
                    if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                        tag_idx = st.session_state.metadata['tag_to_idx'].get(fallback_tag)
                        if tag_idx is not None:
                            print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
                # If still not found, use a generic index (this is a last resort)
                # NOTE(review): the essence is then generated for the generic
                # tag but saved and recorded under the manual tag's name.
                if tag_idx is None:
                    # Try to use a category-specific generic tag
                    if "hair" in tag.lower():
                        generic_tag = "blue_hair" # A common tag that might be in the model
                    elif "gloves" in tag.lower():
                        generic_tag = "gloves"
                    elif "miku" in tag.lower():
                        generic_tag = "twintails" # A feature of Hatsune Miku
                    else:
                        generic_tag = "1girl" # A very common tag
                    # Try to find this generic tag
                    if hasattr(dataset, 'tag_to_idx') and generic_tag in dataset.tag_to_idx:
                        tag_idx = dataset.tag_to_idx[generic_tag]
                        print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
                    elif hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                        tag_idx = st.session_state.metadata['tag_to_idx'].get(generic_tag)
                        if tag_idx is not None:
                            print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
            # If still not found, show error
            # NOTE(review): this early return leaves the progress/message
            # placeholders populated; only the success path empties them.
            if tag_idx is None:
                st.error(f"Tag '{tag}' index not found. Cannot generate essence.")
                print(f"ERROR: Tag '{tag}' index not found")
                return None, 0, None
        else:
            # Tag is already an index
            # NOTE(review): save_essence_to_game_folder later calls
            # tag.replace(...), which a non-string tag does not support -
            # confirm whether index input is actually used by callers.
            tag_idx = tag
            print(f"Using provided tag index: {tag_idx}")
        # Generate the essence - either one or multiple depending on settings
        if layer_emphasis == "compare":
            # Generate essences with different layer emphasis types
            results = try_different_layer_emphasis(
                model=model,
                tag_idx=tag_idx,
                tag_name=tag,
                image_size=512, # Always use 512x512
                iterations=iterations,
                scales=scales,
                progress_callback=progress_callback
            )
            # Create tabs to display the results
            tab_names = []
            tab_contents = []
            for emphasis_type, result in results.items():
                image = result["image"]
                score = result["score"]
                # Store results
                tab_images[emphasis_type] = image
                # Track best score
                if score > best_score:
                    best_score = score
                    best_image = image
                    best_emphasis = emphasis_type
                # Add to tabs
                tab_names.append(f"{emphasis_type.capitalize()} ({score:.2f})")
                tab_contents.append(image)
            # Show tabs with results
            tabs = tabs_container.tabs(tab_names)
            for i, tab in enumerate(tabs):
                with tab:
                    st.image(tab_contents[i], caption=f"Essence with {list(results.keys())[i]} layer emphasis", use_container_width=True)
            # Use the best-scored image as the final result
            image = best_image
            score = best_score
            # Show which emphasis type worked best
            st.success(f"Best results achieved with {best_emphasis} layer emphasis (score: {best_score:.2f})")
        else:
            # Generate single essence with specified layer emphasis
            color_boost = 1.5 # Default color boost
            tv_weight = 5e-4 # Default TV weight
            # Adjust parameters based on layer emphasis
            if layer_emphasis == "low":
                color_boost = 1.3
                tv_weight = 2e-4
            elif layer_emphasis == "high":
                color_boost = 1.7
                tv_weight = 8e-4
            image, score = generate_essence_with_emphasis(
                model=model,
                tag_idx=tag_idx,
                tag_name=tag,
                image_size=512, # Always use 512x512
                iterations=iterations,
                scales=scales,
                progress_callback=progress_callback,
                layer_emphasis=layer_emphasis,
                color_boost=color_boost,
                tv_weight=tv_weight
            )
        # Determine quality level
        quality_level = get_quality_level(score)
        # Deduct enkephalin cost (only reached when generation succeeded)
        st.session_state.enkephalin -= cost
        st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost
        # Increment essence counter
        st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1
        # Save to persistent location
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)
        print(f"Saved essence to: {filepath}")
        # Update UI with final result if not showing comparison tabs
        if layer_emphasis != "compare":
            preview_container.image(image, caption=f"Essence of '{tag}' - Quality: {quality_level}", width=512)
        # Clear progress elements
        progress_container.empty()
        message_container.empty()
        # Store in session state (one entry per tag; a newer essence replaces the older record)
        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}
        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            "settings": settings,
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        # Show success message
        st.success(f"Successfully generated {quality_level} essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== Essence generation complete for '{tag}' ===\n")
        # Persist the updated essence state just before returning
        tag_storage.save_essence_state(session_state=st.session_state)
        return image, score, quality_level
    except Exception as e:
        # Top-level UI boundary: report the failure in the app and on the
        # console instead of letting the exception propagate.
        st.error(f"Error generating essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None
def display_essence_generator():
    """
    Display the essence generator interface

    Renders the title, an explanation of tag essences, and two tabs:
    "Generate Essence" and "My Essences". Generation is a two-run process:
    the selection interface stores the chosen tag in
    ``st.session_state.selected_tag`` and triggers a rerun; on the next run
    the pending tag is generated and then cleared.

    A pending tag is only processed when ``st.session_state.model`` exists.
    Without a model the pending tag is dropped with an error message and the
    selection interface is shown instead.
    """
    # Initialize settings
    initialize_essence_settings()
    st.title("🎨 Tag Essence Generator")
    st.write("Generate visual representations of what the AI model recognizes for specific tags.")

    # Add detailed explanation of what essences are for
    with st.expander("What are Tag Essences & How to Use Them", expanded=True):
        st.markdown("""
### 💡 Understanding Tag Essences
Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!
**How to use Tag Essences:**
1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)
2. **Save the essence image** to your computer
3. **Upload the essence image** back into the tagger
4. The tagger will **almost always detect the original tag**
5. It will often also **detect related rare tags** from the same category
**Strategic Value:**
- Character essences can help unlock other tags associated with that character
- Category essences can help discover rare tags within that category
- High-quality essences (WAW, ALEPH) have the strongest effect
**This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!
""")

    # Check for model availability
    model_available = hasattr(st.session_state, 'model')
    if not model_available:
        st.warning("Model not available. You can browse your tags but cannot generate essences.")

    # Create tabs for the different sections
    tabs = st.tabs(["Generate Essence", "My Essences"])

    with tabs[0]:
        # Check for pending generation from previous interaction
        pending_tag = getattr(st.session_state, 'selected_tag', None)

        if pending_tag and not model_available:
            # Without this guard the generation call below would fail with an
            # AttributeError on st.session_state.model.
            st.error(f"Model not available. Cannot generate an essence for '{pending_tag}'.")
            st.session_state.selected_tag = None
            pending_tag = None

        if pending_tag:
            tag = pending_tag
            st.subheader(f"Generating Essence for '{tag}'")
            # Generate the essence
            image, score, quality = generate_essence_for_tag(
                tag,
                st.session_state.model,
                st.session_state.model.dataset,
                st.session_state.essence_custom_settings
            )
            # Show usage tips if successful
            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""
💡 **Tag Essence Usage Tips:**
1. Look for similar patterns, colors, and elements in real images
2. The essence reveals what features the AI model recognizes for this tag
3. Use this as inspiration when creating or finding images to get this tag
""")
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")
            # Clear selected tag so the next run shows the selection interface again
            st.session_state.selected_tag = None
        else:
            # Show the interface to select a tag
            selected_tag = display_essence_generation_interface(model_available)
            # If a tag was selected, store it for the next run and rerun
            if selected_tag:
                st.session_state.selected_tag = selected_tag
                st.rerun()

    with tabs[1]:
        display_saved_essences()
def save_essence_to_game_folder(image, tag, score, quality_level):
    """
    Write an essence image into the per-quality game folder and return its path.

    The file is stored as
    ``<project>/game_data/essences/<quality_level>/<tag>_<score>_<timestamp>.png``
    where ``<project>`` is two directory levels above this file, slashes,
    backslashes and spaces in the tag are replaced by underscores, and the
    score is formatted with two decimals. Missing folders are created.

    Args:
        image: PIL Image of the essence
        tag: The tag name
        score: The generation score
        quality_level: The quality classification (ZAYIN, TETH, etc.)

    Returns:
        Path to the saved image
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    essence_folder = os.path.join(project_root, "game_data", "essences")
    # makedirs also creates the intermediate game_data folder.
    os.makedirs(essence_folder, exist_ok=True)

    # One sub-folder per quality level for easier browsing.
    quality_folder = os.path.join(essence_folder, quality_level)
    os.makedirs(quality_folder, exist_ok=True)

    # Keep path separators and spaces out of the file name.
    safe_tag = tag
    for unsafe_char in ('/', '\\', ' '):
        safe_tag = safe_tag.replace(unsafe_char, '_')

    stamp = time.strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(quality_folder, f"{safe_tag}_{score:.2f}_{stamp}.png")

    image.save(filepath)
    print(f"Saved essence to: {filepath}")
    return filepath
def essence_folder_path():
    """Return ``<project>/game_data/essences``, creating the folders when missing.

    ``<project>`` is the directory two levels above this file.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    target = os.path.join(project_root, "game_data", "essences")
    # A single makedirs call creates game_data and essences alike.
    os.makedirs(target, exist_ok=True)
    return target
def _safe_essence_tag(tag):
    """Return ``tag`` in the form used in essence file names (slashes, backslashes and spaces become underscores)."""
    return tag.replace('/', '_').replace('\\', '_').replace(' ', '_')


def _parse_essence_filename(filename):
    """Split ``<safe_tag>_<score>_<YYYYmmdd>_<HHMMSS>.png`` into ``(safe_tag, score, timestamp)``.

    The tag itself may contain underscores, so the name is split from the
    right. Returns ``None`` when the name does not follow the scheme.
    """
    stem, ext = os.path.splitext(filename)
    if ext.lower() != ".png":
        return None
    parts = stem.rsplit("_", 3)
    if len(parts) != 4 or not parts[0]:
        return None
    safe_tag, score_text, day, clock = parts
    try:
        score = float(score_text)
    except ValueError:
        return None
    return safe_tag, score, f"{day}_{clock}"


def display_saved_essences():
    """Display the user's saved essence images

    Steps:
      1. Re-link tracked essences whose file is missing to the newest file
         of the same tag in the matching quality folder.
      2. Group tracked essences by quality level.
      3. Add PNG files found in the quality folders that no tracked entry
         points to. Their tag (and score, when the name follows the save
         scheme) is taken from the file name; the tag is kept in its
         file-name form, i.e. with underscores.
      4. Render one expander per quality level, highest level first, with a
         three-column grid sorted by score.
      5. Offer a button that drops tracked entries whose file is missing.
    """
    import sys  # local import: only needed for the platform check below

    st.subheader("My Generated Essences")
    if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
        st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
        return

    # Add usage instructions at the top
    st.markdown("""
### How to Use Your Essences
1. **Click on any essence image** to open it in full size
2. **Save the image** to your computer (right-click → Save image)
3. **Go to the Scan Images tab** and upload the saved essence
4. The tagger will likely detect the original tag and potentially related rare tags!
Higher quality essences (WAW, ALEPH) generally produce the best results.
""")

    # Get the essence folder path
    essence_dir = essence_folder_path()

    # Try to locate any missing files
    for tag, info in st.session_state.generated_essences.items():
        if "path" not in info or os.path.exists(info["path"]):
            continue
        quality_dir = os.path.join(essence_dir, info.get("quality", "ZAYIN"))
        if not os.path.exists(quality_dir):
            continue
        # Match on the exact tag part of the file name; a plain prefix test
        # would also pick up files of other tags that share the prefix.
        safe_tag = _safe_essence_tag(tag)
        candidates = []
        for name in os.listdir(quality_dir):
            parsed = _parse_essence_filename(name)
            if parsed is not None and parsed[0] == safe_tag:
                candidates.append((parsed[2], name))
        if candidates:
            # The timestamp sorts chronologically, so max() is the newest file.
            _, newest = max(candidates)
            info["path"] = os.path.join(quality_dir, newest)
            print(f"Reconnected essence for {tag} to {info['path']}")

    # List essences by quality level
    essences_by_quality = {}
    for tag, info in st.session_state.generated_essences.items():
        quality = info.get("quality", "ZAYIN")  # Default to lowest if not set
        essences_by_quality.setdefault(quality, []).append((tag, info))

    # Check if any essences exist on disk but are not tracked in session state
    untracked_essences = {}
    try:
        tracked_names = {
            os.path.basename(tracked_info["path"])
            for tracked_info in st.session_state.generated_essences.values()
            if "path" in tracked_info
        }
        for quality in ESSENCE_QUALITY_LEVELS.keys():
            quality_dir = os.path.join(essence_dir, quality)
            if not os.path.exists(quality_dir):
                continue
            for filename in os.listdir(quality_dir):
                # Only PNG files that no tracked entry points to
                if not filename.lower().endswith('.png'):
                    continue
                if filename in tracked_names:
                    continue
                entry = {
                    "path": os.path.join(quality_dir, filename),
                    "quality": quality,
                    "discovered_on_disk": True
                }
                parsed = _parse_essence_filename(filename)
                if parsed is not None:
                    tag = parsed[0]
                    entry["score"] = parsed[1]
                elif '_' in filename:
                    # Name does not follow the save scheme: fall back to the
                    # text before the first underscore.
                    tag = filename.split('_')[0]
                else:
                    continue
                untracked_essences.setdefault(quality, []).append((tag, entry))
    except Exception as e:
        # Best effort: the tracked essences are still shown when the scan fails.
        print(f"Error checking for untracked essences: {e}")

    # Combine tracked and untracked essences
    for quality, essences in untracked_essences.items():
        combined = essences_by_quality.setdefault(quality, [])
        for tag, info in essences:
            # Only add if we don't already have this tag in this quality
            # level. Tags are compared in their file-name form.
            if not any(_safe_essence_tag(known_tag) == tag for known_tag, _ in combined):
                combined.append((tag, info))

    # Extra CSS for the rarer rarity labels; other rarities get none.
    rarity_effects = {
        "Star of the City": "text-shadow:0 0 3px gold;",
        "Urban Nightmare": "text-shadow:0 0 1px #FF5722;",
        "Urban Plague": "text-shadow:0 0 1px #9C27B0;",
    }

    # Show essences from highest to lowest quality
    for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]:
        if quality not in essences_by_quality:
            continue
        essences = essences_by_quality[quality]
        color = ESSENCE_QUALITY_LEVELS[quality]["color"]
        with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]):
            # Create grid layout
            cols = st.columns(3)
            ranked = sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)
            for i, (tag, info) in enumerate(ranked):
                with cols[i % 3]:
                    try:
                        if "path" in info and os.path.exists(info["path"]):
                            rarity = info.get("rarity", "Canard")
                            score = info.get("score", 0)
                            # Get color for rarity
                            rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

                            # Display the image; the context manager closes
                            # the file handle once it has been rendered.
                            with Image.open(info["path"]) as image:
                                st.image(image, caption=tag, use_container_width=True)

                            # Use special styling for rare tags
                            if rarity == "Impuritas Civitas":
                                rarity_style = "animation: rainbow-text 4s linear infinite;font-weight:bold;"
                            else:
                                rarity_style = f"color:{rarity_color};{rarity_effects.get(rarity, '')}font-weight:bold;"
                            st.markdown(f"""
<span style='color:{color};font-weight:bold;'>{quality}</span> |
<span style='{rarity_style}'>{rarity}</span> |
Score: {score:.2f}
""", unsafe_allow_html=True)

                            # Add file info
                            if "discovered_on_disk" in info and info["discovered_on_disk"]:
                                st.info("Found on disk (not in session state)")

                            # Add button to open folder
                            if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"):
                                folder_path = os.path.dirname(info["path"])
                                try:
                                    # Try different methods to open folder based on platform
                                    if os.name == 'nt':  # Windows
                                        os.startfile(folder_path)
                                    elif os.name == 'posix':  # macOS or Linux
                                        import subprocess
                                        if 'darwin' in sys.platform:  # macOS
                                            subprocess.call(['open', folder_path])
                                        else:  # Linux
                                            subprocess.call(['xdg-open', folder_path])
                                    st.success(f"Opened folder: {folder_path}")
                                except Exception as e:
                                    st.error(f"Could not open folder: {str(e)}")
                                    # Provide the path for manual navigation
                                    st.code(folder_path)
                        else:
                            # Could not find image
                            st.warning(f"Image file not found: {info.get('path', 'No path available')}")
                            # Show quality and tag name
                            st.markdown(f"""
<span style='color:{color};font-weight:bold;'>{quality}</span> | {tag}
""", unsafe_allow_html=True)
                            # Only add reconnect button if we have some metadata
                            if "rarity" in info and "score" in info:
                                if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"):
                                    # Update path in session state
                                    safe_tag = _safe_essence_tag(tag)
                                    score = info.get("score", 0)
                                    quality_dir = os.path.join(essence_dir, quality)
                                    # Create directory if it doesn't exist
                                    os.makedirs(quality_dir, exist_ok=True)
                                    # Set a path - user will need to manually add the image
                                    timestamp = time.strftime("%Y%m%d_%H%M%S")
                                    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
                                    info["path"] = os.path.join(quality_dir, filename)
                                    st.info(f"Please save your image to this location: {info['path']}")
                                    st.session_state.generated_essences[tag] = info
                                    tag_storage.save_essence_state(session_state=st.session_state)
                                    st.rerun()
                    except Exception as e:
                        # One broken entry must not take down the whole grid.
                        st.write(f"Error loading {tag}: {str(e)}")

    # Add option to clean up missing files
    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        # Find all entries with missing files
        to_remove = [
            tag
            for tag, info in st.session_state.generated_essences.items()
            if "path" in info and not os.path.exists(info["path"])
        ]
        # Remove them
        for tag in to_remove:
            del st.session_state.generated_essences[tag]
        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)
        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
else:
st.success("No missing files found")
st.rerun()
def display_essence_generation_interface(model_available):
    """Display the interface for generating new essences.

    Renders, in order: the generation settings (left column), the quality
    level legend (right column), the current currency balance, the CSS used
    by the tag grid, a per-rarity tag count summary, and finally the tag
    grid itself, grouped according to the selected sort mode.

    Side effects:
        Stores the chosen settings in ``st.session_state.essence_custom_settings``.

    Args:
        model_available: Truthy when a model is loaded. When falsy the
            custom layer picker is hidden and every generate button is
            disabled.

    Returns:
        The tag whose generate/regenerate button was clicked during this
        run, or ``None`` when no button was clicked.
    """
    # Initialize manual tags
    initialize_manual_tags()

    st.subheader("Generate Tag Essence")
    st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")

    # Settings on the left, explanatory legend on the right
    col1, col2 = st.columns(2)
    with col1:
        _essence_render_settings(model_available)
    with col2:
        _essence_render_quality_legend()

    # Show current Enkephalin
    st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}")
    st.divider()

    # Add CSS for animations matching tag collection display
    _essence_inject_tag_css()

    # Gather all tags for essence generation
    all_tags = _essence_collect_tags()

    st.subheader("Available Tags for Essence Generation")
    st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!")

    # The summary is computed before the search filter so it always reflects
    # the whole collection.
    _essence_render_rarity_summary(all_tags)

    # Search box for all tags
    search_term = st.text_input("Search tags", "", key="essence_search_tags")

    # Sort options
    sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"]
    selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort")

    # Filter tags by search term if provided (case-insensitive substring match)
    if search_term:
        needle = search_term.lower()
        all_tags = [info for info in all_tags if needle in info["tag"].lower()]

    if selected_sort == "Category (rarest first)":
        return _essence_render_by_category(all_tags, model_available)
    if selected_sort == "Rarity":
        return _essence_render_by_rarity(all_tags, model_available)
    if selected_sort == "Discovery Time":
        return _essence_render_by_discovery_time(all_tags, model_available)
    return None


# CSS classes (defined by _essence_inject_tag_css) used for tag names of the rarest tiers.
_ESSENCE_RARE_TEXT_CLASSES = {
    "Impuritas Civitas": "impuritas-text",
    "Star of the City": "star-text",
    "Urban Nightmare": "nightmare-text",
    "Urban Plague": "plague-text",
}

# CSS classes used in the rarity count summary.
# NOTE(review): these grid-* classes are not defined by the CSS injected in this
# block; presumably they come from the tag collection display - confirm.
_ESSENCE_RARE_GRID_CLASSES = {
    "Impuritas Civitas": "grid-impuritas",
    "Star of the City": "grid-star",
    "Urban Nightmare": "grid-nightmare",
    "Urban Plague": "grid-plague",
}


def _essence_render_settings(model_available):
    """Render the generation settings widgets and save the chosen values to session state."""
    st.write("Generation Settings:")

    # Basic settings
    scales = st.slider("Scales", 1, 5, DEFAULT_ESSENCE_SETTINGS["scales"],
                       help="More scales produce more detailed essences")
    iterations = st.slider("Iterations", 64, 2048, DEFAULT_ESSENCE_SETTINGS["iterations"], 64,
                           help="More iterations improve quality")

    # Layer emphasis selection - all options available including comparison
    emphasis_labels = {
        "auto": "Auto-detect (best for each tag)",
        "balanced": "Balanced (mix of features)",
        "high": "High-level (characters, objects)",
        "mid": "Mid-level (parts, components)",
        "low": "Low-level (textures, patterns)",
        "compare": "Compare different approaches",
        "custom": "Custom layer selection"
    }
    layer_emphasis = st.selectbox(
        "Feature Targeting",
        options=["auto", "balanced", "high", "mid", "low", "compare", "custom"],
        index=0,  # Default to auto
        format_func=lambda x: emphasis_labels.get(x, x),
        help="Controls which model features to emphasize in the essence"
    )

    # Custom layer selection needs the model to enumerate its layers
    custom_layers = []
    if layer_emphasis == "custom" and model_available:
        custom_layers = _essence_pick_custom_layers()

    # Save settings
    st.session_state.essence_custom_settings = {
        "scales": scales,
        "iterations": iterations,
        "image_size": 512,  # Fixed
        "lr": 0.03,  # Lower learning rate for better results
        "layer_emphasis": layer_emphasis,
        "custom_layers": custom_layers
    }


def _essence_pick_custom_layers():
    """Render per-category layer checkboxes and return the list of selected layer names."""
    st.write("Select Custom Layers:")

    # Get key layers (simplified approach)
    key_layers = get_key_layers(st.session_state.model, max_layers=15)

    category_titles = {
        "early": "Early Layers (textures, colors)",
        "middle": "Middle Layers (parts, components)",
        "late": "Late Layers (objects, characters)",
        "classifier": "Classifier (final recognition)"
    }

    custom_layers = []
    for category, layers in key_layers.items():
        if not layers:
            continue
        category_name = category_titles.get(category, category.capitalize())
        with st.expander(f"{category_name}", expanded=category in ["late", "classifier"]):
            select_all = st.checkbox(f"Select all {category} layers",
                                     key=f"select_all_{category}")
            for layer in layers:
                # Create a shortened display name
                parts = layer.split(".")
                display_name = f"...{parts[-2]}.{parts[-1]}" if len(parts) > 3 else layer
                # Short-circuit on purpose: individual checkboxes are only
                # rendered while "select all" is unchecked.
                if select_all or st.checkbox(display_name, key=f"layer_{layer}"):
                    custom_layers.append(layer)

    # Show selected layers
    if custom_layers:
        st.success(f"Selected {len(custom_layers)} layers")
    else:
        st.warning("Please select at least one layer")
    return custom_layers


def _essence_render_quality_legend():
    """Render the quality level descriptions and the feature targeting note."""
    st.write("Quality Levels:")
    for level, info in ESSENCE_QUALITY_LEVELS.items():
        hex_color = info['color']
        # '#RRGGBB' -> integer channels for the translucent background
        red, green, blue = (int(hex_color[start:start + 2], 16) for start in (1, 3, 5))
        st.markdown(f"""
<div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({red},{green},{blue},0.1);border-left:3px solid {hex_color}">
<span style="color:{hex_color};font-weight:bold;">{level}</span> ({info['threshold']:.0f} Score+): {info['description']}
</div>
""", unsafe_allow_html=True)

    # Feature targeting explanation
    st.write("Feature Targeting Explanation:")
    st.markdown("""
ℹ️ **Feature targeting affects what the visualization emphasizes:**
""")


def _essence_inject_tag_css():
    """Inject the CSS animations and classes used to style rare tag names."""
    st.markdown("""
<style>
@keyframes rainbow-text {
0% { color: red; }
14% { color: orange; }
28% { color: yellow; }
42% { color: green; }
57% { color: blue; }
71% { color: indigo; }
85% { color: violet; }
100% { color: red; }
}
.impuritas-text {
font-weight: bold;
animation: rainbow-text 4s linear infinite;
}
@keyframes glow-text {
0% { text-shadow: 0 0 2px gold; }
50% { text-shadow: 0 0 6px gold; }
100% { text-shadow: 0 0 2px gold; }
}
.star-text {
color: #FFEB3B;
text-shadow: 0 0 3px gold;
animation: glow-text 2s infinite;
font-weight: bold;
}
@keyframes pulse-text {
0% { opacity: 0.8; }
50% { opacity: 1; }
100% { opacity: 0.8; }
}
.nightmare-text {
color: #FF9800;
text-shadow: 0 0 1px #FF5722;
animation: pulse-text 3s infinite;
font-weight: bold;
}
.plague-text {
color: #9C27B0;
text-shadow: 0 0 1px #9C27B0;
font-weight: bold;
}
.category-section {
margin-top: 20px;
margin-bottom: 30px;
padding: 10px;
border-radius: 5px;
border-left: 5px solid;
}
</style>
""", unsafe_allow_html=True)


def _essence_collect_tags():
    """Build one flat list of tag records from the discovered and manual tags in session state."""
    all_tags = []

    # Process discovered tags
    if hasattr(st.session_state, 'discovered_tags'):
        for tag, info in st.session_state.discovered_tags.items():
            all_tags.append({
                "tag": tag,
                "rarity": info.get("rarity", "Unknown"),
                "category": info.get("category", "unknown"),
                "source": "discovered",
                "library_floor": info.get("library_floor", ""),
                "discovery_time": info.get("discovery_time", "")
            })

    # Process manual tags
    if hasattr(st.session_state, 'manual_tags'):
        for tag, info in st.session_state.manual_tags.items():
            all_tags.append({
                "tag": tag,
                "rarity": info.get("rarity", "Special"),
                "category": info.get("category", "special"),
                "source": "manual",
                "description": info.get("description", "")
            })

    return all_tags


def _essence_render_rarity_summary(all_tags):
    """Render one column per rarity showing how many tags have that rarity."""
    rarity_counts = {}
    for info in all_tags:
        rarity_counts[info["rarity"]] = rarity_counts.get(info["rarity"], 0) + 1

    # st.columns() rejects a column count of 0, so skip the summary entirely
    # when there are no tags at all.
    if not rarity_counts:
        return

    # Known rarities follow the order of RARITY_LEVELS; unknown ones go last.
    level_position = {name: position for position, name in enumerate(RARITY_LEVELS)}
    ordered_counts = sorted(rarity_counts.items(),
                            key=lambda item: level_position.get(item[0], 999))

    rarity_cols = st.columns(len(rarity_counts))
    for column, (rarity, count) in zip(rarity_cols, ordered_counts):
        with column:
            class_name = _ESSENCE_RARE_GRID_CLASSES.get(rarity, "")
            if class_name:
                st.markdown(
                    f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity.capitalize()}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )
            else:
                # Get color with fallback
                color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888")
                style = f"color:{color};font-weight:bold;"
                st.markdown(
                    f"<div style='text-align:center;'><span style='{style}'>{rarity.capitalize()}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )


def _essence_render_tag_cell(info, model_available, styled, key_suffix):
    """Render one tag (label, detail line, button) and return True when its button was clicked."""
    tag = info["tag"]
    rarity = info["rarity"]
    source = info["source"]

    # Check if this tag has an essence already
    has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences

    # Get cost for this tag
    cost = get_essence_cost(rarity)
    can_afford = st.session_state.enkephalin >= cost

    if styled:
        # Rarity-colored name plus a rarity badge
        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
        css_class = _ESSENCE_RARE_TEXT_CLASSES.get(rarity)
        if css_class:
            tag_display = f'<span class="{css_class}">{tag}</span>'
        else:
            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'
        st.markdown(
            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
            unsafe_allow_html=True
        )
    else:
        st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")

    # Show discovery details / description if available
    if source == "discovered" and info.get("library_floor"):
        st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>',
                    unsafe_allow_html=True)
    elif source == "manual" and info.get("description"):
        st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>',
                    unsafe_allow_html=True)

    # Add generation button
    button_label = "Generate" if not has_essence else "Regenerate ✓"
    return st.button(button_label, key=f"gen_{tag}_{key_suffix}",
                     disabled=not model_available or not can_afford)


def _essence_render_tag_grid(tags, model_available, styled, key_suffix=None):
    """Lay tags out in a 3-column grid and return the clicked tag, or None."""
    cols = st.columns(3)
    clicked_tag = None
    for i, info in enumerate(tags):
        with cols[i % 3]:
            # Widget keys default to the tag's source so a tag present as both
            # discovered and manual still gets distinct keys.
            suffix = key_suffix if key_suffix is not None else info["source"]
            if _essence_render_tag_cell(info, model_available, styled, suffix):
                clicked_tag = info["tag"]
    return clicked_tag


def _essence_render_by_category(all_tags, model_available):
    """Render tags grouped by category, rarest first inside each group; return the clicked tag or None."""
    categories = {}
    for info in all_tags:
        categories.setdefault(info["category"], []).append(info)

    # Higher rank = rarer (RARITY_LEVELS is iterated in its declared order);
    # rarities missing from RARITY_LEVELS rank 0 and therefore sort last.
    rarity_rank = {name: position + 1 for position, name in enumerate(RARITY_LEVELS)}

    selected_tag = None
    for category, tags in sorted(categories.items()):
        sorted_tags = sorted(tags, key=lambda info: rarity_rank.get(info["rarity"], 0), reverse=True)

        # Check if category has any rare tags
        has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"]
                            for info in sorted_tags)

        # Get category info if available
        category_display = category.capitalize()
        if category in TAG_CATEGORIES:
            category_info = TAG_CATEGORIES[category]
            icon = category_info.get("icon", "")
            color = category_info.get("color", "#888888")
            category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>"

        # Create header with information about rare tags if present
        header = f"{category_display} ({len(tags)} tags)"
        if has_rare_tags:
            header += " ✨ Contains rare tags!"

        # The header is rendered outside the expander because it contains HTML
        st.markdown(header, unsafe_allow_html=True)
        with st.expander("Show/Hide", expanded=has_rare_tags):
            clicked = _essence_render_tag_grid(sorted_tags, model_available, styled=True)
        if clicked is not None:
            selected_tag = clicked
    return selected_tag


def _essence_render_by_rarity(all_tags, model_available):
    """Render tags grouped by rarity (reverse RARITY_LEVELS order); return the clicked tag or None."""
    rarity_groups = {}
    for info in all_tags:
        rarity_groups.setdefault(info["rarity"], []).append(info)

    # Get ordered rarities (rarest first), then any rarities not in RARITY_LEVELS
    ordered_rarities = list(RARITY_LEVELS.keys())[::-1]
    for rarity in rarity_groups:
        if rarity not in ordered_rarities:
            ordered_rarities.append(rarity)

    selected_tag = None
    for rarity in ordered_rarities:
        if rarity not in rarity_groups:
            continue
        tags = rarity_groups[rarity]
        color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

        # Add special styling for rare rarities
        if rarity == "Impuritas Civitas":
            rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity.capitalize()}</span>"
        elif rarity == "Star of the City":
            rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity.capitalize()}</span>"
        elif rarity == "Urban Nightmare":
            rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity.capitalize()}</span>"
        else:
            rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity.capitalize()}</span>"

        # First create the title with HTML, then use it in the expander
        st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True)
        with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]):
            alphabetical = sorted(tags, key=lambda x: x["tag"])
            clicked = _essence_render_tag_grid(alphabetical, model_available, styled=False)
        if clicked is not None:
            selected_tag = clicked
    return selected_tag


def _essence_render_by_discovery_time(all_tags, model_available):
    """Render discovered tags grouped by discovery date, then manual tags; return the clicked tag or None."""
    # Filter to just discovered tags (manual tags don't have discovery time)
    discovered_tags = [info for info in all_tags
                       if info["source"] == "discovered" and "discovery_time" in info]

    # Sort all tags by discovery time (newest first)
    newest_first = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True)

    # Group by date; dict insertion order keeps the newest date first
    date_groups = {}
    for info in newest_first:
        time_str = info["discovery_time"]
        # Extract just the date part if timestamp has date and time
        date = time_str.split()[0] if " " in time_str else time_str
        date_groups.setdefault(date, []).append(info)

    selected_tag = None
    for position, (date, tags) in enumerate(date_groups.items()):
        date_display = date if date else "Unknown date"
        st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)")
        # Expand most recent by default
        with st.expander("Show/Hide", expanded=position == 0):
            clicked = _essence_render_tag_grid(tags, model_available, styled=True, key_suffix="disc")
        if clicked is not None:
            selected_tag = clicked

    # Show manual tags separately if we have any
    manual_tags = [info for info in all_tags if info["source"] == "manual"]
    if manual_tags:
        st.markdown("### Manual Tags")
        with st.expander("Show/Hide"):
            clicked = _essence_render_tag_grid(manual_tags, model_available, styled=False, key_suffix="manual")
        if clicked is not None:
            selected_tag = clicked
    return selected_tag
def generate_essence_with_emphasis(model, tag_idx, tag_name=None, image_size=512,
                                   iterations=256, scales=3, progress_callback=None,
                                   layer_emphasis="mid", color_boost=1.5, tv_weight=5e-4):
    """
    Produce an essence visualization for one tag using a chosen layer emphasis.

    Unless ``layer_emphasis`` is "auto", the layers returned by
    ``get_suggested_layers`` are hooked and weighted linearly by position
    (0.5 for the first, 1.0 for the last); layers whose name looks like a
    classifier head get an extra 1.5x factor.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (used for logging and name mapping)
        image_size: Size of output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 3)
        progress_callback: Optional callback forwarded to the generator
        layer_emphasis: Emphasis passed to get_suggested_layers, or "auto" to
            let the generator pick layers itself (default: "mid")
        color_boost: Factor for boosting color saturation (default: 1.5)
        tv_weight: Total variation weight (default: 5e-4)

    Returns:
        Tuple of (generated essence image, activation score)
    """
    # Name mapping is only built when a (truthy) tag name was supplied
    name_lookup = {tag_idx: tag_name} if tag_name else None

    if layer_emphasis == "auto":
        # Leave layer choice and weighting to the generator
        hooked_layers = None
        weights_by_layer = None
        print("Using auto layer detection")
    else:
        hooked_layers = get_suggested_layers(model, layer_emphasis)
        classifier_markers = ("classifier", "fc", "linear", "output", "logits")
        last_position = max(1, len(hooked_layers) - 1)  # avoids division by zero for a single layer

        weights_by_layer = {}
        for position, layer_name in enumerate(hooked_layers):
            # Later layers weigh more: 0.5 at the start up to 1.0 at the end
            weight = 0.5 + 0.5 * (position / last_position)
            lowered = layer_name.lower()
            if any(marker in lowered for marker in classifier_markers):
                weight *= 1.5
            weights_by_layer[layer_name] = weight

        print(f"Using {len(hooked_layers)} {layer_emphasis}-level layers")

    # Configure the generator
    generator = EssenceGenerator(
        model=model,
        tag_to_name=name_lookup,
        iterations=iterations,
        scales=scales,
        learning_rate=0.05,
        decay_power=1.0,
        tv_weight=tv_weight,
        layers_to_hook=hooked_layers,
        layer_weights=weights_by_layer,
        color_boost=color_boost
    )

    # Run the generation and report the outcome
    print(f"Generating essence for tag {tag_name or tag_idx} with {layer_emphasis} emphasis...")
    image, score = generator.generate_essence(
        tag_idx=tag_idx,
        image_size=image_size,
        return_score=True,
        progress_callback=progress_callback
    )
    print(f"Essence generation complete. Score: {score:.4f}")
    return image, score
def try_different_layer_emphasis(model, tag_idx, tag_name=None, image_size=512,
                                 iterations=256, scales=4, progress_callback=None):
    """
    Generate multiple essences with different layer emphasis types and return them all.

    Runs generate_essence_with_emphasis once per emphasis ("low", "mid",
    "high"), each with its own color boost and total-variation weight. The
    same progress_callback is passed to every run.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (for logging)
        image_size: Size of output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 4)
        progress_callback: Optional callback for progress updates

    Returns:
        Dictionary keyed by emphasis name ("low", "mid", "high"); each value
        is a dict with the generated "image" and its "score".
    """
    emphasis_types = [
        {"name": "low", "color_boost": 1.3, "tv_weight": 2e-4},   # Low-level features (textures, colors)
        {"name": "mid", "color_boost": 1.5, "tv_weight": 5e-4},   # Mid-level features (parts, components)
        {"name": "high", "color_boost": 1.7, "tv_weight": 8e-4},  # High-level features (characters, objects)
    ]

    results = {}
    for emphasis in emphasis_types:
        name = emphasis["name"]
        print(f"\n=== Trying {name} layer emphasis ===")
        image, score = generate_essence_with_emphasis(
            model=model,
            tag_idx=tag_idx,
            tag_name=tag_name,
            image_size=image_size,
            iterations=iterations,
            scales=scales,
            progress_callback=progress_callback,
            layer_emphasis=name,
            color_boost=emphasis["color_boost"],
            tv_weight=emphasis["tv_weight"]
        )
        results[name] = {
            "image": image,
            "score": score
        }
        print(f"=== Completed {name} layer emphasis with score {score:.4f} ===")

    return results