#!/usr/bin/env python3
"""

Enhanced Essence Generator for Tag Collector Game

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import numpy as np
import os
import re
import math
import json
import streamlit as st
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from functools import wraps
import time
import tag_storage  # Import for saving game state

from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON
from tag_categories import TAG_CATEGORIES

# Define essence quality levels with thresholds and styles
ESSENCE_QUALITY_LEVELS = {
    "ZAYIN": {"threshold": 0.0, "color": "#1CFC00", "description": "Basic representation with minimal details."},
    "TETH": {"threshold": 3.0, "color": "#389DDF", "description": "Clear representation with recognizable features."},
    "HE": {"threshold": 5.0, "color": "#FEF900", "description": "Refined representation with distinctive elements."},
    "WAW": {"threshold": 10.0, "color": "#7930F1", "description": "Advanced representation with precise details."},
    "ALEPH": {"threshold": 12.0, "color": "#FF0000", "description": "Perfect representation with extraordinary precision."}
}
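
# Illustrative helper (an assumption — not part of the original game logic):
# pick the highest quality tier whose threshold a given essence score meets.
# Relies on the levels dict being ordered by ascending threshold, as
# ESSENCE_QUALITY_LEVELS above is.
def quality_level_for_score(score, levels=None):
    """Return the name of the best quality level the score qualifies for."""
    levels = levels if levels is not None else ESSENCE_QUALITY_LEVELS
    best = None
    for name, info in levels.items():
        if score >= info["threshold"]:
            best = name
    return best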

# Essence generation costs in enkephalin based on tag rarity
ESSENCE_COSTS = {
    "Special": 0,
    "Canard": 100,                 # Common tags
    "Urban Myth": 125,             # Uncommon tags
    "Urban Legend": 150,           # Rare tags
    "Urban Plague": 200,          # Very rare tags
    "Urban Nightmare": 250,       # Extremely rare tags
    "Star of the City": 300,      # Nearly mythical tags
    "Impuritas Civitas": 400     # Legendary tags
}

# Default essence generation settings
DEFAULT_ESSENCE_SETTINGS = {
    "scales": 1,            # Number of scales for multiscale optimization
    "iterations": 256,      # Iterations per scale
    "image_size": 512,      # Always use 512x512 resolution
    "lr": 0.1,              # Learning rate
    "layer_emphasis": "auto"  # Default to auto-detection
}

def initialize_essence_settings():
    """Initialize essence generator settings if not already present"""
    if 'essence_custom_settings' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        
        if loaded_state and 'essence_custom_settings' in loaded_state:
            st.session_state.essence_custom_settings = loaded_state['essence_custom_settings']
        else:
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()

def initialize_manual_tags():
    """Initialize manual tags if not already present"""
    if 'manual_tags' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        
        if loaded_state and 'manual_tags' in loaded_state:
            st.session_state.manual_tags = loaded_state['manual_tags']
        else:
            st.session_state.manual_tags = {
                "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
            }


def timeout(seconds, fallback_value=None):
    """

    Simple timeout utility for functions.

    Warns if a function takes longer than expected but doesn't interrupt it.

    

    Args:

        seconds: Expected maximum seconds the function should take

        fallback_value: Not used, just for API compatibility

    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start_time
            
            if elapsed > seconds:
                print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
                
            return result
        return wrapper
    return decorator
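
# Usage sketch for the timeout decorator (`generate_preview` is a hypothetical
# name): the wrapped call always runs to completion; a warning is printed
# afterward only if it exceeded the budget.
#
#     @timeout(5.0)
#     def generate_preview(tag):
#         ...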

# Core Classes for Essence Generation

class LayerHook:
    """Helper class to store the outputs of a layer via forward hook."""
    def __init__(self, layer):
        self.layer = layer
        self.features = None
        self.hook = layer.register_forward_hook(self.hook_fn)
    
    def hook_fn(self, module, input, output):
        self.features = output
    
    def close(self):
        self.hook.remove()
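
# Typical LayerHook usage (a sketch; `model.features` is a hypothetical module
# path): attach the hook, run a forward pass, read the captured output, then
# detach to avoid leaking hooks.
#
#     hook = LayerHook(model.features)
#     _ = model(input_batch)
#     activations = hook.features
#     hook.close()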

class FullModelHook:
    """Hook all layers in a model and track their responses to inputs."""
    
    def __init__(self, model):
        self.model = model
        self.hooks = {}
        self.activations = {}
        self.layer_scores = {}
        
        # Recursively register hooks for all eligible layers
        self._register_hooks(model)
        print(f"FullModelHook initialized with {len(self.hooks)} hooks")
    
    def _register_hooks(self, module, prefix=''):
        """Recursively register hooks on all suitable layers."""
        for name, child in module.named_children():
            layer_name = f"{prefix}.{name}" if prefix else name
            
            # Only hook layers that produce activations
            # Avoid hooking containers like Sequential
            if isinstance(child, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.BatchNorm2d, torch.nn.LayerNorm)):
                self.hooks[layer_name] = child.register_forward_hook(
                    lambda m, inp, out, layer=layer_name: self._hook_fn(layer, out)
                )
            
            # Recurse into children
            self._register_hooks(child, layer_name)
    
    def _hook_fn(self, layer_name, output):
        """Store activations for each layer."""
        # For convolutional layers, compute channel-wise mean activations
        if len(output.shape) == 4:  # [batch, channels, height, width]
            # Store mean activation per channel
            self.activations[layer_name] = output.mean(dim=[2, 3]).detach()
        else:
            # For other layers, store as is
            self.activations[layer_name] = output.detach()

class EssenceGenerator:
    """

    Enhanced Essence Generator optimized for anime characters.

    Includes improvements for more vibrant colors and recognizable features.

    """
    
    def __init__(
        self,
        model,
        tag_to_name=None,
        iterations=256,
        scales=3,
        learning_rate=0.03,  # Lower learning rate for better convergence
        decay_power=1.5,     # Stronger emphasis on low frequencies
        tv_weight=5e-4,      # Stronger total variation for clearer structures
        layers_to_hook=None,
        layer_weights=None,
        color_boost=1.5      # Color boosting factor
    ):
        """Initialize the Enhanced Essence Generator"""
        self.model = model
        self.tag_to_name = tag_to_name
        self.iterations = iterations
        self.scales = scales
        self.lr = learning_rate
        self.decay_power = decay_power
        self.tv_weight = tv_weight
        self.layers_to_hook = layers_to_hook
        self.layer_weights = layer_weights
        self.color_boost = color_boost
        
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval().to(self.device)
        
        # Initialize hooks
        self.hooks = {}
        
        # Enhanced color correlation matrix for anime-style colors
        # More saturated colors with stronger correlations
        self.color_correlation_matrix = torch.tensor([
            [1.0000, 0.9522, 0.9156],
            [0.9522, 1.0000, 0.9708],
            [0.9156, 0.9708, 1.0000]], device=self.device)
        
        # Setup hooks if specified
        if self.layers_to_hook:
            self.setup_hooks(self.layers_to_hook)
    
    def setup_hooks(self, layers_to_hook):
        """Setup hooks on the specified layers."""
        # Close any existing hooks
        self.close_hooks()
        
        # Create new hooks
        for layer_name in layers_to_hook:
            try:
                # Try to get layer by navigating the model hierarchy
                parts = layer_name.split('.')
                layer = self.model
                for part in parts:
                    layer = getattr(layer, part)
                    
                self.hooks[layer_name] = LayerHook(layer)
                print(f"Setup hook for layer: {layer_name}")
            except Exception as e:
                print(f"Failed to setup hook for {layer_name}: {e}")
    
    def setup_auto_hooks(self, tag_idx):
        """

        Automatically detect the most responsive layers for a specific tag.

        This simplified version selects a few key layers based on model architecture.

        """
        # Close any existing hooks
        self.close_hooks()
        
        # If we already have layer weights from initialization, reuse them
        if self.layers_to_hook and self.layer_weights:
            self.setup_hooks(self.layers_to_hook)
            return self.layer_weights
        
        # Otherwise, detect layers automatically
        # Get all named modules
        all_layers = []
        for name, module in self.model.named_modules():
            if not name:  # Skip empty name (the model itself)
                continue
                
            # Only consider certain layer types that typically have meaningful features
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
                all_layers.append((name, module))
        
        # If the model is too large, select strategic layers
        selected_layers = []
        layer_weights = {}
        
        if len(all_layers) > 30:
            # For large models, select a subset of layers
            # 1. Try to find classifier/final layer
            classifier_layers = [(name, module) for name, module in all_layers 
                               if any(x in name.lower() for x in ["classifier", "fc", "linear", "output", "logits"])]
            if classifier_layers:
                selected_layers.append(classifier_layers[-1])
                layer_weights[classifier_layers[-1][0]] = 1.0  # Highest weight
            
            # 2. Find some mid to late convolutional layers
            conv_layers = [(name, module) for name, module in all_layers if isinstance(module, nn.Conv2d)]
            if conv_layers:
                # Take some layers from the second half
                half_idx = len(conv_layers) // 2
                selected_idx = [half_idx, 3*len(conv_layers)//4, -1]  # middle, 3/4, and last
                for idx in selected_idx:
                    if idx < len(conv_layers) and conv_layers[idx] not in selected_layers:
                        selected_layers.append(conv_layers[idx])
                        # Later layers get higher weights
                        pos = selected_idx.index(idx)
                        layer_weights[conv_layers[idx][0]] = 0.5 + 0.5 * (pos / max(1, len(selected_idx) - 1))
        else:
            # For smaller models, use more layers
            # Take a sample across the network depth
            step = max(1, len(all_layers) // 5)
            indices = list(range(0, len(all_layers), step))
            if len(all_layers) - 1 not in indices:
                indices.append(len(all_layers) - 1)  # Always include the last layer
                
            for idx in indices:
                selected_layers.append(all_layers[idx])
                # Later layers get higher weights
                layer_weights[all_layers[idx][0]] = 0.5 + 0.5 * (idx / max(1, len(all_layers) - 1))
        
        # Create hooks for selected layers
        print(f"Setting up {len(selected_layers)} auto-detected layers:")
        for name, module in selected_layers:
            self.hooks[name] = LayerHook(module)
            print(f"  - {name} (weight: {layer_weights.get(name, 0.5):.2f})")
        
        return layer_weights
    
    def close_hooks(self):
        """Clean up hooks to avoid memory leaks."""
        for hook in self.hooks.values():
            hook.close()
        self.hooks.clear()
    
    def total_variation_loss(self, img):
        """

        Total variation loss for smoother images but preserving edges.

        Modified version that better preserves strong edges.

        """
        diff_y = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :])
        diff_x = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])
        
        # Use a more gentle version of total variation that allows for edges
        # but still penalizes noise (using square root reduces the penalty on large differences)
        tv = torch.mean(torch.sqrt(diff_y + 1e-8)) + torch.mean(torch.sqrt(diff_x + 1e-8))
        return tv
    
    def create_fft_spectrum_initializer(self, size, batch_size=1):
        """Enhanced frequency domain initialization for better essence generation with more color and coverage"""
        fft_size = size // 2 + 1
        
        # Initialize frequency components with a natural image prior
        # This biases toward more natural-looking essences
        spectrum_scale = torch.zeros(batch_size, 3, size, fft_size, 2, device=self.device)
        
        # Use 1/f spectrum characteristic of natural images (pink noise)
        for h in range(size):
            for w in range(fft_size):
                # Calculate distance from DC component
                dist = np.sqrt((h/size)**2 + (w/fft_size)**2) + 1e-5
                # Pink noise falls off as 1/f, but with higher amplitude for better initial coverage
                weight = 1.0 / dist
                # Add random phase but weighted amplitude - increased amplitude for better initial values
                spectrum_scale[:, :, h, w, 0] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15  # Increased from 0.05
                spectrum_scale[:, :, h, w, 1] = torch.randn(batch_size, 3, device=self.device) * weight * 0.15  # Increased from 0.05
        
        # Initialize DC component (average color) with higher values for better color saturation
        # Give distinct colors to each channel for more vibrant initialization
        spectrum_scale[:, 0, 0, 0, 0] = 0.5  # Red channel
        spectrum_scale[:, 1, 0, 0, 0] = 0.4  # Green channel
        spectrum_scale[:, 2, 0, 0, 0] = 0.6  # Blue channel
        spectrum_scale[:, :, 0, 0, 1] = 0
        
        spectrum_scale.requires_grad_(True)
        
        # Phase component - increased for more variation
        spectrum_shift = torch.randn(batch_size, 3, size, fft_size, 2, device=self.device) * 0.05  # Increased from 0.02
        spectrum_shift.requires_grad_(True)
        
        return spectrum_scale, spectrum_shift
    
    def create_spectrum_weights(self, size, decay_power=1.0):
        """Create weights for the spectrum that emphasize lower frequencies but preserve more details."""
        freqs_x = torch.fft.rfftfreq(size).view(1, -1).to(self.device)
        freqs_y = torch.fft.fftfreq(size).view(-1, 1).to(self.device)
        
        dist_from_center = torch.sqrt(freqs_x**2 + freqs_y**2)
        
        # Modified weight calculation that allows for more mid-frequency details
        # and preserves more high frequencies for better detail
        weights = 1.0 / (dist_from_center + 1e-8) ** decay_power
        
        # Significantly increase mid-frequency weights for more details and coverage
        mid_freq_mask = (dist_from_center > 0.05) & (dist_from_center < 0.3)
        weights = weights * (1.0 + 1.0 * mid_freq_mask.float())  # Increased from 0.5
        
        # Add some weight to high frequencies for texture details
        high_freq_mask = (dist_from_center >= 0.3) & (dist_from_center < 0.7)
        weights = weights * (1.0 + 0.3 * high_freq_mask.float())  # New addition
        
        weights = weights / weights.max()
        weights[0, 0] = 0.8  # Higher DC component for better color coherence (increased from 0.7)
        
        return weights
    
    def fft_to_rgb(self, spectrum_scale, spectrum_shift, size, spectrum_weight=None):
        """Convert FFT spectrum parameters to an RGB image."""
        batch_size = spectrum_scale.shape[0]
        
        if spectrum_weight is not None:
            spectrum_scale = spectrum_scale * spectrum_weight.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        
        image = torch.zeros(batch_size, 3, size, size, device=self.device)
        
        spectrum_complex = torch.complex(
            spectrum_scale[..., 0], 
            spectrum_scale[..., 1]
        )
        
        # Unit-magnitude phase factor e^{i*theta}: the real and imaginary parts
        # must use the same angle, otherwise the factor also rescales the
        # spectrum magnitude instead of only shifting phase.
        phase_shift = torch.complex(
            torch.cos(spectrum_shift[..., 0]),
            torch.sin(spectrum_shift[..., 0])
        )
        spectrum_complex = spectrum_complex * phase_shift
        
        for b in range(batch_size):
            for c in range(3):
                channel_spectrum = spectrum_complex[b, c]
                channel_image = torch.fft.irfft2(channel_spectrum, s=(size, size))
                
                channel_min = channel_image.min()
                channel_max = channel_image.max()
                if channel_max > channel_min:
                    channel_image = (channel_image - channel_min) / (channel_max - channel_min)
                else:
                    channel_image = torch.zeros_like(channel_image)
                
                image[b, c] = channel_image
        
        return image
    
    def apply_color_correlation(self, image):
        """

        Apply color correlation to produce more vibrant, colorful images.

        """
        batch_size, _, height, width = image.shape
        
        # 1. Apply basic color correlation
        flat_image = image.view(batch_size, 3, -1)
        correlated = torch.matmul(self.color_correlation_matrix, flat_image)
        correlated_image = correlated.view(batch_size, 3, height, width)
        
        # 2. Apply stronger color boost (increase saturation)
        # Calculate luminance (0.3R + 0.59G + 0.11B)
        luminance = 0.3 * image[:, 0:1] + 0.59 * image[:, 1:2] + 0.11 * image[:, 2:3]
        
        # Boost colors away from luminance (enhances saturation)
        # Increase the boost factor for more vibrant colors
        boosted_image = luminance + (correlated_image - luminance) * self.color_boost * 1.5  # Apply additional 1.5x boost
        
        # 3. Apply a gentle S-curve for better contrast
        # This helps make the colors "pop" more
        boosted_image = 0.5 + torch.tanh((boosted_image - 0.5) * 2) * 0.5
        
        # Ensure values are in [0, 1] range
        boosted_image = torch.clamp(boosted_image, 0, 1)
        
        return boosted_image
    
    def apply_transforms(self, img):
        """Apply random transformations to the image for robustness."""
        batch_size, c, h, w = img.shape
        
        # 1. Padding
        pad = 16
        padded = F.pad(img, (pad, pad, pad, pad), mode='reflect')
        
        # 2. Random jitter
        jitter = 16
        h_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
        w_jitter = torch.randint(-jitter, jitter + 1, (batch_size,), device=self.device)
        
        # Create sampling grid
        rows = torch.arange(h, device=self.device).view(1, 1, -1, 1).repeat(batch_size, 1, 1, w)
        cols = torch.arange(w, device=self.device).view(1, 1, 1, -1).repeat(batch_size, 1, h, 1)
        
        rows = rows + h_jitter.view(-1, 1, 1, 1) + pad
        cols = cols + w_jitter.view(-1, 1, 1, 1) + pad
        
        # Get transformed image (simplified implementation)
        grid_h = torch.clamp(rows, 0, padded.shape[2] - 1).long()
        grid_w = torch.clamp(cols, 0, padded.shape[3] - 1).long()
        
        # Apply second jitter for more randomness
        jitter2 = 8
        h_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
        w_jitter2 = torch.randint(-jitter2, jitter2 + 1, (batch_size,), device=self.device)
        
        grid_h = torch.clamp(grid_h + h_jitter2.view(-1, 1, 1, 1), 0, padded.shape[2] - 1).long()
        grid_w = torch.clamp(grid_w + w_jitter2.view(-1, 1, 1, 1), 0, padded.shape[3] - 1).long()
        
        # Gather values
        transformed = torch.zeros_like(img)
        for b in range(batch_size):
            transformed[b] = padded[b, :, grid_h[b, 0], grid_w[b, 0]]
        
        return transformed
    
    def add_spatial_prior(self, img, strength=0.05):
        """

        Add a spatial prior to encourage character-like structures with better composition

        and fuller image coverage.

        """
        batch_size, c, h, w = img.shape
        
        # Create normalized coordinate grids
        y_indices = torch.arange(h, device=self.device).float()
        x_indices = torch.arange(w, device=self.device).float()
        
        y = (2.0 * y_indices / h) - 1.0  # Normalize to [-1, 1]
        x = (2.0 * x_indices / w) - 1.0  # Normalize to [-1, 1]
        
        # Expand to 2D grid
        y_grid = y.view(-1, 1).repeat(1, w)
        x_grid = x.view(1, -1).repeat(h, 1)
        
        # Center bias (gentler in the middle to allow more coverage)
        center_dist = torch.sqrt(x_grid.pow(2) + y_grid.pow(2))
        # Wider center bias for better coverage
        center_value = torch.exp(-0.8 * center_dist)  # Reduced from -1.5 for wider coverage
        
        # Full-image utilization bias (higher values further from edge)
        edge_dist_x = torch.min(torch.abs(x_grid - 1.0), torch.abs(x_grid + 1.0))
        edge_dist_y = torch.min(torch.abs(y_grid - 1.0), torch.abs(y_grid + 1.0))
        edge_dist = torch.min(edge_dist_x, edge_dist_y)
        edge_value = torch.clamp(edge_dist * 5.0, 0.2, 1.0)  # Higher values away from edges
        
        # Rule of thirds with wider peaks (subtle enhancement at thirds points)
        thirds_x = torch.exp(-10 * (x_grid - 1/3).pow(2)) + torch.exp(-10 * (x_grid + 1/3).pow(2))  # Reduced from -30
        thirds_y = torch.exp(-10 * (y_grid - 1/3).pow(2)) + torch.exp(-10 * (y_grid + 1/3).pow(2))  # Reduced from -30
        thirds_value = (thirds_x + thirds_y) / 2
        
        # Combine the different priors with rebalanced weights to favor coverage
        prior = 0.4 * center_value + 0.4 * edge_value + 0.2 * thirds_value  # More weight on edge_value for better coverage
        
        # Normalize the prior
        prior = prior / prior.max()
        
        # Expand to match the input dimensions
        prior = prior.unsqueeze(0).unsqueeze(0)
        prior = prior.repeat(batch_size, c, 1, 1)
        
        # Apply the prior with increased strength
        result = img * (1.0 - strength*1.5) + prior * strength*1.5  # Increased strength by 50%
        
        return result
 
    def get_layer_activations(self, tag_idx, layer_weights):
        """

        Get activations from all hooked layers for the target tag.

        Returns a weighted sum of activations based on layer weights.

        """
        activation_sum = 0.0
        
        for layer_name, hook in self.hooks.items():
            if hook.features is None:
                continue
                
            # Get weight for this layer
            weight = layer_weights.get(layer_name, 0.5)
            
            # Handle different layer types.
            # NOTE: keep activations as tensors (no .item()) so the loss built
            # from this sum stays differentiable w.r.t. the spectrum parameters.
            if len(hook.features.shape) <= 2:
                # For fully connected layers, focus on the target class logit
                if hook.features.size(1) > tag_idx:
                    activation = hook.features[0, tag_idx]
                    activation_sum = activation_sum + weight * activation
            else:
                # For convolutional layers, focus on overall activation strength
                channel_means = hook.features.mean(dim=[2, 3])
                
                # Make sure we don't request more channels than exist
                num_channels = min(5, channel_means.size(1))
                _, top_indices = torch.topk(channel_means, num_channels)
                
                # Process each top channel individually (use up to 3 top channels)
                for idx in range(min(3, top_indices.size(1))):
                    channel_idx = top_indices[0, idx]
                    channel_activation = hook.features[:, channel_idx].mean()
                    # Weight decreases for less important channels
                    channel_weight = 1.0 if idx == 0 else (0.5 if idx == 1 else 0.25)
                    activation_sum = activation_sum + weight * channel_activation * channel_weight
        
        return activation_sum
    
    def generate_essence(self, tag_idx, image_size=512, return_score=True, progress_callback=None):
        """Generate an essence visualization using enhanced techniques."""
        # Get tag name for logging (if available)
        tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
        print(f"Generating enhanced essence for '{tag_name}'...")
        
        # Auto-detect and set up hooks for responsive layers if not already set up
        layer_weights = self.layer_weights or self.setup_auto_hooks(tag_idx)
        
        # Determine scale sizes
        scale_sizes = []
        for s in range(self.scales):
            # Start small and progressively increase size
            scale_size = max(32, image_size // (2 ** (self.scales - s - 1)))
            scale_sizes.append(scale_size)
        
        print(f"Processing scales: {scale_sizes}")
        
        # Create frequency spectrum weights
        spectrum_weights = {}
        for size in scale_sizes:
            spectrum_weights[size] = self.create_spectrum_weights(size, decay_power=self.decay_power)
        
        # Track best result
        best_score = -float('inf')
        best_img = None
        
        # Process each scale independently
        for scale_idx, size in enumerate(scale_sizes):
            # Initialize parameters for this scale
            spectrum_scale, spectrum_shift = self.create_fft_spectrum_initializer(size)
            
            # Create optimizer
            optimizer = torch.optim.Adam([spectrum_scale, spectrum_shift], lr=self.lr)
            
            # Use learning rate scheduler for better convergence
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, 
                T_max=self.iterations,
                eta_min=self.lr * 0.1
            )
            
            # Current scale's spectrum weights
            current_weights = spectrum_weights[size]
            
            # Iterations for this scale 
            iterations = self.iterations
            
            # Plateau tracking for early stopping
            no_improvement_streak = 0
            plateau_threshold = 128  # Increased from 64 - give more iterations before stopping
            scale_best_score = -float('inf')
            scale_best_img = None
            
            for i in range(iterations):
                # Clear gradients
                optimizer.zero_grad()
                
                # Convert FFT parameters to RGB image
                img = self.fft_to_rgb(spectrum_scale, spectrum_shift, size, current_weights)
                
                # # Apply color correlation with boosted colors for anime-style
                # img = self.apply_color_correlation(img)
                
                # Add spatial prior to encourage character-like patterns with better coverage
                if size >= 32:
                    # Apply stronger spatial prior for better coverage
                    img = self.add_spatial_prior(img, strength=0.25)  # Increased from 0.15
                
                # Apply transformations for robustness
                if size >= 32:
                    img = self.apply_transforms(img)
                
                # Reset hooks
                for hook in self.hooks.values():
                    hook.features = None
                
                # Forward pass
                outputs = self.model(img)
                
                # Get target tag activation in final layer
                if isinstance(outputs, (list, tuple)):
                    predictions = outputs[0]
                else:
                    predictions = outputs
                
                # Get tag activation from final layer
                tag_activation = predictions[0, tag_idx]
                
                # Get activations from earlier layers
                layer_activation = self.get_layer_activations(tag_idx, layer_weights)
                
                # Combined loss to maximize activations, with the intermediate
                # layer term weighted 2x for better feature emphasis. (A constant
                # offset in the loss has no effect on the gradients, so none is added.)
                activation_loss = -(tag_activation + layer_activation * 2.0)
                
                # Add regularization term - total variation for smoothness
                # Reduce TV weight slightly to allow more details and coverage
                tv_loss = self.total_variation_loss(img) * (self.tv_weight * 0.7)  # Reduce to 70% of original
                
                # Total loss
                total_loss = activation_loss + tv_loss
                
                # Backpropagation
                total_loss.backward()
                
                # Update parameters
                optimizer.step()
                scheduler.step()
                
                # Track best result for this scale
                current_score = tag_activation.item()
                if current_score > scale_best_score + 1e-4:
                    scale_best_score = current_score
                    scale_best_img = img.detach().clone()
                    no_improvement_streak = 0
                else:
                    no_improvement_streak += 1
                
                # More patient early stopping
                if no_improvement_streak >= plateau_threshold:
                    print(f"Early stopping at iteration {i}/{iterations} due to plateau")
                    break
                
                # Report progress
                if progress_callback and i % max(1, iterations // 10) == 0:
                    progress_callback(
                        scale_idx=scale_idx, 
                        scale_count=len(scale_sizes),
                        iter_idx=i,
                        iter_count=iterations,
                        score=current_score
                    )
            
            print(f"Scale {scale_idx+1}/{len(scale_sizes)} completed. Score: {scale_best_score:.4f}")
            
            # Update overall best if this scale improved the score
            if scale_best_score > best_score:
                best_score = scale_best_score
                # For the final scale, keep the full resolution image
                if scale_idx == len(scale_sizes) - 1:
                    best_img = scale_best_img
                # Otherwise, upscale to the final size for the return value
                else:
                    with torch.no_grad():
                        best_img = F.interpolate(scale_best_img, size=(image_size, image_size), 
                                            mode='bilinear', align_corners=False)
        
        # In case all scales failed, create an empty image
        if best_img is None:
            final_img = torch.zeros((1, 3, image_size, image_size), device=self.device)
        else:
            final_img = best_img
        
        # Convert to PIL image (clamp in case transforms pushed values outside [0, 1])
        pil_img = to_pil_image(final_img[0].clamp(0, 1).cpu())
        
        # Clean up hooks
        self.close_hooks()
        
        if return_score:
            return pil_img, best_score
        else:
            return pil_img

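For reference, the spectrum-to-image decoding used by `fft_to_rgb` can be sketched in isolation: a real-FFT half-spectrum is turned into a complex tensor, decoded with `torch.fft.irfft2`, and each channel is min-max normalized to [0, 1]. This is a minimal standalone sketch with hypothetical shapes, not the class's exact code path:

```python
import torch

def decode_spectrum(spectrum_scale, size):
    """Decode a half-spectrum of shape (3, size, size//2 + 1, 2)
    into a (3, size, size) image with each channel scaled to [0, 1]."""
    spectrum = torch.complex(spectrum_scale[..., 0], spectrum_scale[..., 1])
    img = torch.fft.irfft2(spectrum, s=(size, size))  # (3, size, size)
    # Per-channel min-max normalization, mirroring fft_to_rgb
    flat = img.view(3, -1)
    mins = flat.min(dim=1).values.view(3, 1, 1)
    maxs = flat.max(dim=1).values.view(3, 1, 1)
    return (img - mins) / (maxs - mins + 1e-8)

size = 32
params = torch.randn(3, size, size // 2 + 1, 2) * 0.01
image = decode_spectrum(params, size)
```

Parameterizing the image in frequency space this way (rather than in pixel space) is what lets the spectrum weights bias optimization toward lower frequencies.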
# Utility Functions for Model Analysis and Layer Selection

def get_model_layers(model):
    """Utility function to get all available layers in a model."""
    layers = []
    for name, _ in model.named_modules():
        if name:  # Skip empty name (the model itself)
            layers.append(name)
    return layers

def get_key_layers(model, max_layers=15):
    """

    Get a curated list of the most relevant layers for visualization.

    """
    all_layers = get_model_layers(model)
    
    # For models with hundreds of layers, we need to be selective
    if len(all_layers) > 30:
        # Extract patterns to identify layer types
        block_patterns = {}
        
        # Find common patterns in layer names
        for layer in all_layers:
            # Extract the main component (e.g., "backbone.features")
            parts = layer.split(".")
            if len(parts) >= 2:
                prefix = ".".join(parts[:2])
                if prefix not in block_patterns:
                    block_patterns[prefix] = []
                block_patterns[prefix].append(layer)
        
        # Now select representative layers from each major block
        key_layers = {
            "early": [],
            "middle": [],
            "late": []
        }
        
        # For each major block, select layers at strategic positions
        for prefix, layers in block_patterns.items():
            if len(layers) > 3:  # Only process significant blocks
                # Sort by natural depth (assuming numerical components indicate depth)
                layers.sort(key=lambda x: [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', x)])
                
                # Get layers at strategic positions
                early = layers[0]
                middle = layers[len(layers) // 2]
                late = layers[-1]
                
                key_layers["early"].append(early)
                key_layers["middle"].append(middle)
                key_layers["late"].append(late)
        
        # Ensure we don't have too many layers
        # If we need to reduce further, prioritize middle and late layers
        flattened = []
        for _, group_layers in key_layers.items():
            flattened.extend(group_layers)
        
        if len(flattened) > max_layers:
            # Calculate how many to keep from each group
            total = len(flattened)
            # Prioritize keeping late layers (for character recognition)
            late_count = min(len(key_layers["late"]), max_layers // 3)
            # Allocate remaining slots between early and middle
            remaining = max_layers - late_count
            middle_count = min(len(key_layers["middle"]), remaining // 2)
            early_count = min(len(key_layers["early"]), remaining - middle_count)
            
            # Take only the needed number from each category
            key_layers["early"] = key_layers["early"][:early_count]
            key_layers["middle"] = key_layers["middle"][:middle_count]
            key_layers["late"] = key_layers["late"][:late_count]
    else:
        # For simpler models, use standard distribution
        n = len(all_layers)
        key_layers = {
            "early": all_layers[:n//3][:3],  # First few layers
            "middle": all_layers[n//3:2*n//3][:4],  # Middle layers
            "late": all_layers[2*n//3:][:3]  # Last few layers
        }
    
    # Try to identify the classifier/final layer
    classifier_layers = [layer for layer in all_layers if any(x in layer.lower() 
                      for x in ["classifier", "fc", "linear", "output", "logits", "head"])]
    if classifier_layers:
        key_layers["classifier"] = [classifier_layers[-1]]
    
    return key_layers
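The depth sort in `get_key_layers` relies on a natural-sort key: each layer name is split into digit and non-digit runs, and digit runs compare numerically, so `blocks.10` sorts after `blocks.2`. A minimal illustration with hypothetical layer names:

```python
import re

def natural_key(name):
    # Split into digit / non-digit runs; digit runs compare as integers
    return [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', name)]

layers = ["backbone.blocks.10", "backbone.blocks.2", "backbone.blocks.1"]
ordered = sorted(layers, key=natural_key)
# A plain lexicographic sort would instead place "10" before "2"
```

Note the key assumes names within one block share the same shape (digits and text in the same positions); comparing an `int` with a `str` at the same position would raise `TypeError`.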

def get_suggested_layers(model, layer_type="balanced"):
    """

    Get suggested layers based on the desired feature type.

    """
    key_layers = get_key_layers(model)
    
    # Flatten all layers for reference
    all_key_layers = []
    for layers in key_layers.values():
        all_key_layers.extend(layers)
    
    # Choose layers based on the requested emphasis.
    # Copy the lists before appending so we don't mutate key_layers in place.
    if layer_type == "low":
        # Focus on early visual features (textures, patterns, colors)
        selected = list(key_layers.get("early", []))
        # Add one middle layer for stability
        if key_layers.get("middle"):
            selected.append(key_layers["middle"][0])
    
    elif layer_type == "mid":
        # Focus on mid-level features (parts, components)
        selected = list(key_layers.get("middle", []))
        # Add one early layer for context
        if key_layers.get("early"):
            selected.append(key_layers["early"][-1])
    
    elif layer_type == "high":
        # Focus on high-level semantic features (objects, characters)
        selected = list(key_layers.get("late", []))
        selected.extend(key_layers.get("classifier", []))
        # Add one middle layer for context
        if key_layers.get("middle"):
            selected.append(key_layers["middle"][-1])
    
    else:  # balanced
        # Use a mix of early, middle and late layers
        selected = []
        for category in ["early", "middle", "late", "classifier"]:
            if category in key_layers and key_layers[category]:
                # Take one from each category
                selected.append(key_layers[category][0])
                # For middle and late, also take the last one if different
                if category in ["middle", "late"] and len(key_layers[category]) > 1:
                    selected.append(key_layers[category][-1])
    
    # Ensure we have at least one layer
    if not selected and all_key_layers:
        selected = [all_key_layers[-1]]  # Use the last layer as fallback
    
    return selected

def get_quality_level(score):
    """

    Determine the quality level of an essence based on its score

    """
    # Scan levels from highest threshold to lowest so the best qualifying
    # level wins, without relying on dict insertion order
    for level in sorted(ESSENCE_QUALITY_LEVELS,
                        key=lambda lvl: ESSENCE_QUALITY_LEVELS[lvl]["threshold"],
                        reverse=True):
        if score >= ESSENCE_QUALITY_LEVELS[level]["threshold"]:
            return level
    return "ZAYIN"  # Default to lowest level

def get_essence_cost(rarity):
    """

    Calculate the cost to generate an essence image based on tag rarity

    """
    return ESSENCE_COSTS.get(rarity, 100)  # Default to 100 if rarity unknown

# Game UI and Integration Functions

def initialize_essence_settings():
    """Initialize essence generator settings if not already present"""
    if 'essence_custom_settings' not in st.session_state:
        st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()

def save_essence_to_game_folder(image, tag, score, quality_level):
    """

    Save the generated essence image to a persistent game folder

    

    Args:

        image: PIL Image of the essence

        tag: The tag name

        score: The generation score

        quality_level: The quality classification (ZAYIN, TETH, etc.)

        

    Returns:

        Path to the saved image

    """
    # Create game folder if it doesn't exist
    game_folder = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "game_data")
    essence_folder = os.path.join(game_folder, "essences")
    
    # Make directories if they don't exist
    os.makedirs(essence_folder, exist_ok=True)
    
    # Create filename with quality level and score
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{quality_level}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(essence_folder, filename)
    
    # Save the image
    image.save(filepath)
    
    return filepath
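The filename construction above can be factored into a small helper for clarity; this is a sketch that mirrors the sanitization and formatting logic, with the timestamp made injectable for testing:

```python
import time

def essence_filename(tag, quality_level, score, timestamp=None):
    """Build a filesystem-safe essence filename (mirrors the logic above)."""
    # Replace path separators and spaces so the tag is safe in a filename
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = timestamp or time.strftime("%Y%m%d_%H%M%S")
    return f"{safe_tag}_{quality_level}_{score:.2f}_{timestamp}.png"

name = essence_filename("long hair/blue", "ZAYIN", 7.3333, timestamp="20240101_000000")
```

Note this only handles `/`, `\` and spaces; tags containing other characters that are illegal on some filesystems (e.g. `:` on Windows) would need additional escaping.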

def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """

    Generate an essence image for a specific tag using the improved generator

    

    Args:

        tag: The tag name or index

        model: The model to use

        dataset: The dataset containing tag information

        custom_settings: Optional dictionary with custom generation settings

        

    Returns:

        PIL Image of the generated essence, score, quality level

    """
    
    print(f"\n=== Starting essence generation for tag '{tag}' ===")
    
    # Check if tag is discovered or a manual tag
    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
    
    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None
    
    # Get tag rarity and calculate cost
    if is_discovered:
        rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard")
    elif is_manual_tag:
        rarity = st.session_state.manual_tags[tag].get("rarity", "Canard")
    else:
        rarity = "Canard"
    
    # Calculate cost based on rarity
    cost = get_essence_cost(rarity)
    
    # Check if player has enough Enkephalin
    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None
    
    # Use provided settings or defaults
    settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
    print(f"Using settings: {settings}")
    
    # Extract settings
    iterations = settings.get("iterations", 256)
    scales = settings.get("scales", 5)
    layer_emphasis = settings.get("layer_emphasis", "auto")
    
    # UI containers for progress
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()
    
    # If multiple layer emphasis types are requested, show tabs
    if layer_emphasis == "compare":
        message_container.info("Generating essences with different layer emphasis types...")
        tabs_container = st.empty()
        
        tabs = None
        tab_images = {}
        best_score = -float('inf')
        best_image = None
        best_emphasis = None
    
    try:
        # Show generation information
        if layer_emphasis != "compare":
            message_container.info(f"Generating essence for '{tag}' with {layer_emphasis} layer emphasis...")
        
        # Progress callback function for essence generation
        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            # Update progress bar
            progress = ((scale_idx * iter_count) + iter_idx) / (scale_count * iter_count)
            progress_container.progress(progress, f"Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")
            
            # Print status to console too
            if iter_idx % 20 == 0:
                print(f"Progress: Scale {scale_idx+1}/{scale_count}, Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")
        
        # Convert tag name to index if needed
        tag_idx = None
        
        # Try to find tag in various places
        if isinstance(tag, str):
            print(f"Converting tag name '{tag}' to index...")
            # Standard lookup methods
            if hasattr(dataset, 'tag_to_idx') and tag in dataset.tag_to_idx:
                tag_idx = dataset.tag_to_idx[tag]
                print(f"Found tag index from dataset.tag_to_idx: {tag_idx}")
            
            # Session state metadata lookup
            if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                tag_idx = st.session_state.metadata['tag_to_idx'].get(tag)
                if tag_idx is not None:
                    print(f"Found tag index from session_state.metadata['tag_to_idx']: {tag_idx}")
            
            # Lookup from idx_to_tag
            if tag_idx is None and hasattr(st.session_state, 'metadata') and 'idx_to_tag' in st.session_state.metadata:
                idx_to_tag = st.session_state.metadata['idx_to_tag']
                tag_to_idx = {v: int(k) for k, v in idx_to_tag.items()}
                tag_idx = tag_to_idx.get(tag)
                if tag_idx is not None:
                    print(f"Found tag index from inverted idx_to_tag: {tag_idx}")
                
                # Try case-insensitive
                if tag_idx is None:
                    tag_lower = tag.lower()
                    for t, idx in tag_to_idx.items():
                        if t.lower() == tag_lower:
                            tag_idx = idx
                            print(f"Found tag index using case-insensitive match: {tag_idx}")
                            break
            
            # For manual tags that aren't in the model's tag list, 
            # we might need to find a semantically similar tag or use a generic index
            if tag_idx is None and is_manual_tag:
                # For demonstration, we could map manual tags to known similar tags
                manual_tag_mapping = {
                    "hatsune_miku": "hatsune_miku", # Try to find this in the dataset
                    "lamp": "lamp",                 # Try to find this in the dataset
                    "blue_gloves": "gloves",        # Fallback to a more generic tag
                }
                
                fallback_tag = manual_tag_mapping.get(tag)
                if fallback_tag:
                    # Try to find the fallback tag
                    if hasattr(dataset, 'tag_to_idx') and fallback_tag in dataset.tag_to_idx:
                        tag_idx = dataset.tag_to_idx[fallback_tag]
                        print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
                    
                    # Try session state metadata
                    if tag_idx is None and hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                        tag_idx = st.session_state.metadata['tag_to_idx'].get(fallback_tag)
                        if tag_idx is not None:
                            print(f"Using fallback tag '{fallback_tag}' with index: {tag_idx}")
                
                # If still not found, use a generic index (this is a last resort)
                if tag_idx is None:
                    # Try to use a category-specific generic tag
                    if "hair" in tag.lower():
                        generic_tag = "blue_hair"  # A common tag that might be in the model
                    elif "gloves" in tag.lower():
                        generic_tag = "gloves"
                    elif "miku" in tag.lower():
                        generic_tag = "twintails"  # A feature of Hatsune Miku
                    else:
                        generic_tag = "1girl"  # A very common tag
                    
                    # Try to find this generic tag
                    if hasattr(dataset, 'tag_to_idx') and generic_tag in dataset.tag_to_idx:
                        tag_idx = dataset.tag_to_idx[generic_tag]
                        print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
                    elif hasattr(st.session_state, 'metadata') and 'tag_to_idx' in st.session_state.metadata:
                        tag_idx = st.session_state.metadata['tag_to_idx'].get(generic_tag)
                        if tag_idx is not None:
                            print(f"Using generic tag '{generic_tag}' with index: {tag_idx}")
            
            # If still not found, show error
            if tag_idx is None:
                st.error(f"Tag '{tag}' index not found. Cannot generate essence.")
                print(f"ERROR: Tag '{tag}' index not found")
                return None, 0, None
        else:
            # Tag is already an index
            tag_idx = tag
            print(f"Using provided tag index: {tag_idx}")
            
        # Generate the essence - either one or multiple depending on settings
        if layer_emphasis == "compare":
            # Generate essences with different layer emphasis types
            results = try_different_layer_emphasis(
                model=model,
                tag_idx=tag_idx,
                tag_name=tag,
                image_size=512,  # Always use 512x512
                iterations=iterations,
                scales=scales,
                progress_callback=progress_callback
            )
            
            # Create tabs to display the results
            tab_names = []
            tab_contents = []
            
            for emphasis_type, result in results.items():
                image = result["image"]
                score = result["score"]
                
                # Store results
                tab_images[emphasis_type] = image
                
                # Track best score
                if score > best_score:
                    best_score = score
                    best_image = image
                    best_emphasis = emphasis_type
                
                # Add to tabs
                tab_names.append(f"{emphasis_type.capitalize()} ({score:.2f})")
                tab_contents.append(image)
            
            # Show tabs with results
            tabs = tabs_container.tabs(tab_names)
            for i, tab in enumerate(tabs):
                with tab:
                    st.image(tab_contents[i], caption=f"Essence with {list(results.keys())[i]} layer emphasis", use_container_width=True)
            
            # Use the best-scored image as the final result
            image = best_image
            score = best_score
            
            # Show which emphasis type worked best
            st.success(f"Best results achieved with {best_emphasis} layer emphasis (score: {best_score:.2f})")
            
        else:
            # Generate single essence with specified layer emphasis
            color_boost = 1.5  # Default color boost
            tv_weight = 5e-4   # Default TV weight
            
            # Adjust parameters based on layer emphasis
            if layer_emphasis == "low":
                color_boost = 1.3
                tv_weight = 2e-4
            elif layer_emphasis == "high":
                color_boost = 1.7
                tv_weight = 8e-4
                
            image, score = generate_essence_with_emphasis(
                model=model,
                tag_idx=tag_idx,
                tag_name=tag,
                image_size=512,  # Always use 512x512
                iterations=iterations,
                scales=scales,
                progress_callback=progress_callback,
                layer_emphasis=layer_emphasis,
                color_boost=color_boost,
                tv_weight=tv_weight
            )
        
        # Determine quality level
        quality_level = get_quality_level(score)
        
        # Deduct enkephalin cost
        st.session_state.enkephalin -= cost
        st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost
        
        # Increment essence counter
        st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1
        
        # Save to persistent location
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)
        print(f"Saved essence to: {filepath}")
        
        # Update UI with final result if not showing comparison tabs
        if layer_emphasis != "compare":
            preview_container.image(image, caption=f"Essence of '{tag}' - Quality: {quality_level}", width=512)
        
        # Clear progress elements
        progress_container.empty()
        message_container.empty()
        
        # Store in session state
        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}
        
        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            "settings": settings,
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        
        # Show success message
        st.success(f"Successfully generated {quality_level} essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== Essence generation complete for '{tag}' ===\n")

        # Persist essence state before returning
        tag_storage.save_essence_state(session_state=st.session_state)
        
        return image, score, quality_level
    
    except Exception as e:
        st.error(f"Error generating essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None

def display_essence_generator():
    """

    Display the essence generator interface

    """
    # Initialize settings
    initialize_essence_settings()
    
    st.title("🎨 Tag Essence Generator")
    st.write("Generate visual representations of what the AI model recognizes for specific tags.")
    
    # Add detailed explanation of what essences are for
    with st.expander("What are Tag Essences & How to Use Them", expanded=True):
        st.markdown("""

        ### 💡 Understanding Tag Essences

        

        Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!

        

        **How to use Tag Essences:**

        1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)

        2. **Save the essence image** to your computer

        3. **Upload the essence image** back into the tagger

        4. The tagger will **almost always detect the original tag**

        5. It will often also **detect related rare tags** from the same category

        

        **Strategic Value:**

        - Character essences can help unlock other tags associated with that character

        - Category essences can help discover rare tags within that category

        - High-quality essences (WAW, ALEPH) have the strongest effect

        

        **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!

        """)
    
    
    # Check for model availability
    model_available = hasattr(st.session_state, 'model')
    if not model_available:
        st.warning("Model not available. You can browse your tags but cannot generate essences.")
    
    # Create tabs for the different sections
    tabs = st.tabs(["Generate Essence", "My Essences"])
    
    with tabs[0]:
        # Check for pending generation from previous interaction
        if hasattr(st.session_state, 'selected_tag') and st.session_state.selected_tag:
            tag = st.session_state.selected_tag
            
            st.subheader(f"Generating Essence for '{tag}'")
            
            # Generate the essence
            image, score, quality = generate_essence_for_tag(
                tag, 
                st.session_state.model, 
                st.session_state.model.dataset,
                st.session_state.essence_custom_settings
            )
            
            # Show usage tips if successful
            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""

                    💡 **Tag Essence Usage Tips:**

                    1. Look for similar patterns, colors, and elements in real images

                    2. The essence reveals what features the AI model recognizes for this tag

                    3. Use this as inspiration when creating or finding images to get this tag

                    """)
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")
            
            # Clear selected tag
            st.session_state.selected_tag = None
        else:
            # Show the interface to select a tag
            selected_tag = display_essence_generation_interface(model_available)
            
            # If a tag was selected, store it for the next run and rerun
            if selected_tag:
                st.session_state.selected_tag = selected_tag
                st.rerun()
    
    with tabs[1]:
        display_saved_essences()
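
# Minimal sketch (plain Python, no Streamlit) of the two-phase pattern used in
# display_essence_generator above: one script run stores the selection and
# triggers a rerun, the next run consumes it. `FakeState` and `run_once` are
# hypothetical names standing in for st.session_state and a script run; this
# is illustrative only and is not called by the app.
class FakeState:
    selected_tag = None

def run_once(state, picked=None):
    """One 'script run': consume a pending tag, or record a new pick."""
    if state.selected_tag:
        # Consume the pending selection and clear it, mirroring
        # `st.session_state.selected_tag = None` after generation
        tag, state.selected_tag = state.selected_tag, None
        return f"generating {tag}"
    if picked:
        state.selected_tag = picked  # store for the next rerun
        return "rerun"
    return "idle"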

def save_essence_to_game_folder(image, tag, score, quality_level):
    """

    Save the generated essence image to a persistent game folder

    

    Args:

        image: PIL Image of the essence

        tag: The tag name

        score: The generation score

        quality_level: The quality classification (ZAYIN, TETH, etc.)

        

    Returns:

        Path to the saved image

    """
    # Create game folder paths with better structure
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")
    
    # Make sure all parent directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)
    
    # Organize essences by quality level for easier browsing
    quality_folder = os.path.join(essence_folder, quality_level)
    os.makedirs(quality_folder, exist_ok=True)
    
    # Create filename with more details and better organization
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(quality_folder, filename)
    
    # Save the image
    image.save(filepath)
    
    print(f"Saved essence to: {filepath}")
    return filepath

def essence_folder_path():
    """Get the path to the essence folder, creating it if necessary"""
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")
    
    # Make sure all directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)
    
    return essence_folder
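
# Illustrative sketch (not called by the app): how a saved essence filename
# produced by save_essence_to_game_folder, which follows the pattern
# "{safe_tag}_{score:.2f}_{YYYYMMDD}_{HHMMSS}.png", can be parsed back into
# its parts. The name `parse_essence_filename` is hypothetical and exists
# only to document the naming convention; note that safe_tag may itself
# contain underscores, so only the last three parts are metadata.
def parse_essence_filename(filename):
    """Return (tag, score, timestamp) parsed from an essence filename."""
    stem = filename[:-4] if filename.lower().endswith('.png') else filename
    parts = stem.split('_')
    if len(parts) < 4:
        # Not enough parts for score/date/time; treat the whole stem as a tag
        return stem.replace('_', ' '), None, None
    tag = ' '.join(parts[:-3])
    try:
        score = float(parts[-3])
    except ValueError:
        score = None
    timestamp = f"{parts[-2]}_{parts[-1]}"  # YYYYMMDD_HHMMSS
    return tag, score, timestamp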

def display_saved_essences():
    """Display the user's saved essence images"""
    st.subheader("My Generated Essences")
    
    if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
        st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
        return
        
    # Add usage instructions at the top
    st.markdown("""

    ### How to Use Your Essences

    

    1. **Click on any essence image** to open it in full size

    2. **Save the image** to your computer (right-click → Save image)

    3. **Go to the Scan Images tab** and upload the saved essence

    4. The tagger will likely detect the original tag and potentially related rare tags!

    

    Higher quality essences (WAW, ALEPH) generally produce the best results.

    """)
    
    # Get the essence folder path
    essence_dir = essence_folder_path()
    
    # Try to locate any missing files
    for tag, info in st.session_state.generated_essences.items():
        if "path" in info and not os.path.exists(info["path"]):
            # Try to find the file in the essence directory
            quality = info.get("quality", "ZAYIN")
            quality_dir = os.path.join(essence_dir, quality)
            
            if os.path.exists(quality_dir):
                # Check for files with this tag name
                safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
                matching_files = [f for f in os.listdir(quality_dir) if f.startswith(safe_tag)]
                
                if matching_files:
                    # Use the most recent file if there are multiple
                    matching_files.sort(reverse=True)
                    info["path"] = os.path.join(quality_dir, matching_files[0])
                    print(f"Reconnected essence for {tag} to {info['path']}")
    
    # List essences by quality level
    essences_by_quality = {}
    for tag, info in st.session_state.generated_essences.items():
        quality = info.get("quality", "ZAYIN")  # Default to lowest if not set
        if quality not in essences_by_quality:
            essences_by_quality[quality] = []
        essences_by_quality[quality].append((tag, info))
    
    # Check if any essences exist on disk but are not tracked in session state
    try:
        untracked_essences = {}
        
        for quality in ESSENCE_QUALITY_LEVELS.keys():
            quality_dir = os.path.join(essence_dir, quality)
            if os.path.exists(quality_dir):
                essence_files = os.listdir(quality_dir)
                
                # Filter to only show PNG files
                essence_files = [f for f in essence_files if f.lower().endswith('.png')]
                
                if essence_files:
                    # Check if any of these files aren't in our tracked essences
                    for filename in essence_files:
                        # Extract tag name from the filename, which follows
                        # the pattern {safe_tag}_{score}_{YYYYMMDD}_{HHMMSS}.png
                        # (safe_tag may itself contain underscores)
                        parts = os.path.splitext(filename)[0].split('_')
                        if len(parts) >= 4:
                            # Drop the trailing score, date, and time parts
                            tag = ' '.join(parts[:-3])
                            
                            # Check if file is already tracked
                            is_tracked = any(
                                os.path.basename(tracked_info.get("path", "")) == filename
                                for tracked_info in st.session_state.generated_essences.values()
                            )
                            
                            if not is_tracked:
                                if quality not in untracked_essences:
                                    untracked_essences[quality] = []
                                untracked_essences[quality].append((tag, {
                                    "path": os.path.join(quality_dir, filename),
                                    "quality": quality,
                                    "discovered_on_disk": True
                                }))
    except Exception as e:
        print(f"Error checking for untracked essences: {e}")
    
    # Combine tracked and untracked essences
    for quality, essences in untracked_essences.items():
        if quality not in essences_by_quality:
            essences_by_quality[quality] = []
        for tag, info in essences:
            # Only add if we don't already have this tag in this quality level
            if not any(tracked_tag == tag for tracked_tag, _ in essences_by_quality[quality]):
                essences_by_quality[quality].append((tag, info))
    
    # Show essences from highest to lowest quality
    for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]:
        if quality in essences_by_quality:
            essences = essences_by_quality[quality]
            color = ESSENCE_QUALITY_LEVELS[quality]["color"]
            
            with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]):
                # Create grid layout
                cols = st.columns(3)
                for i, (tag, info) in enumerate(sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)):
                    col_idx = i % 3
                    with cols[col_idx]:
                        try:
                            # Try to load the image from path
                            if "path" in info and os.path.exists(info["path"]):
                                image = Image.open(info["path"])
                                rarity = info.get("rarity", "Canard")
                                score = info.get("score", 0)
                                
                                # Get color for rarity
                                rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                                
                                # Display the image with metadata
                                st.image(image, caption=tag, use_container_width=True)
                                
                                # Pick rarity-specific styling for the label
                                if rarity == "Impuritas Civitas":
                                    rarity_span = f"<span style='animation: rainbow-text 4s linear infinite;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Star of the City":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Urban Nightmare":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Urban Plague":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 1px #9C27B0;font-weight:bold;'>{rarity}</span>"
                                else:
                                    rarity_span = f"<span style='color:{rarity_color};font-weight:bold;'>{rarity}</span>"
                                
                                st.markdown(
                                    f"<span style='color:{color};font-weight:bold;'>{quality}</span> | "
                                    f"{rarity_span} | Score: {score:.2f}",
                                    unsafe_allow_html=True
                                )
                                
                                # Add file info
                                if "discovered_on_disk" in info and info["discovered_on_disk"]:
                                    st.info("Found on disk (not in session state)")
                                
                                # Add button to open folder
                                if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"):
                                    folder_path = os.path.dirname(info["path"])
                                    try:
                                        # Try different methods to open folder based on platform
                                        if os.name == 'nt':  # Windows
                                            os.startfile(folder_path)
                                        elif os.name == 'posix':  # macOS or Linux
                                            import subprocess
                                            import sys
                                            if sys.platform == 'darwin':  # macOS
                                                subprocess.call(['open', folder_path])
                                            else:  # Linux
                                                subprocess.call(['xdg-open', folder_path])
                                        st.success(f"Opened folder: {folder_path}")
                                    except Exception as e:
                                        st.error(f"Could not open folder: {str(e)}")
                                        # Provide the path for manual navigation
                                        st.code(folder_path)
                            else:
                                # Could not find image
                                st.warning(f"Image file not found: {info.get('path', 'No path available')}")
                                
                                # Show quality and tag name
                                st.markdown(f"""

                                <span style='color:{color};font-weight:bold;'>{quality}</span> | {tag}

                                """, unsafe_allow_html=True)
                                
                                # Only add reconnect button if we have some metadata
                                if "rarity" in info and "score" in info:
                                    if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"):
                                        # Update path in session state
                                        safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
                                        score = info.get("score", 0)
                                        quality_dir = os.path.join(essence_dir, quality)
                                        
                                        # Create directory if it doesn't exist
                                        os.makedirs(quality_dir, exist_ok=True)
                                        
                                        # Set a path - user will need to manually add the image
                                        timestamp = time.strftime("%Y%m%d_%H%M%S")
                                        filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
                                        info["path"] = os.path.join(quality_dir, filename)
                                        
                                        st.info(f"Please save your image to this location: {info['path']}")
                                        st.session_state.generated_essences[tag] = info
                                        tag_storage.save_essence_state(session_state=st.session_state)
                                        st.rerun()
                        
                        except Exception as e:
                            st.write(f"Error loading {tag}: {str(e)}")
    
    # Add option to clean up missing files
    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        # Find all entries with missing files
        to_remove = []
        for tag, info in st.session_state.generated_essences.items():
            if "path" in info and not os.path.exists(info["path"]):
                to_remove.append(tag)
        
        # Remove them
        for tag in to_remove:
            del st.session_state.generated_essences[tag]
        
        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)
        
        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
        else:
            st.success("No missing files found")
        
        st.rerun()
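
# Illustrative sketch (not called by the app): the inline hex-to-RGB math used
# when rendering the quality-level cards below, factored into a hypothetical
# `hex_to_rgb` helper to make the channel extraction explicit.
def hex_to_rgb(hex_color):
    """Convert '#RRGGBB' to an (r, g, b) tuple of ints."""
    hex_color = hex_color.lstrip('#')
    # Each pair of hex digits is one 0-255 channel
    return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))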

def display_essence_generation_interface(model_available):
    """Display the interface for generating new essences"""
    # Initialize manual tags
    initialize_manual_tags()
    
    st.subheader("Generate Tag Essence")
    st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")
    
    # Settings column
    col1, col2 = st.columns(2)
    
    with col1:
        # Simple settings
        st.write("Generation Settings:")
        
        # Basic settings
        scales = st.slider("Scales", 1, 5, DEFAULT_ESSENCE_SETTINGS["scales"], 
                         help="More scales produce more detailed essences")
        
        iterations = st.slider("Iterations", 64, 2048, DEFAULT_ESSENCE_SETTINGS["iterations"], 64, 
                             help="More iterations improve quality")
        
        # Layer emphasis selection - all options available including comparison
        layer_emphasis = st.selectbox(
            "Feature Targeting", 
            options=["auto", "balanced", "high", "mid", "low", "compare", "custom"],
            index=0,  # Default to auto
            format_func=lambda x: {
                "auto": "Auto-detect (best for each tag)",
                "balanced": "Balanced (mix of features)",
                "high": "High-level (characters, objects)",
                "mid": "Mid-level (parts, components)",
                "low": "Low-level (textures, patterns)",
                "compare": "Compare different approaches",
                "custom": "Custom layer selection"
            }.get(x, x),
            help="Controls which model features to emphasize in the essence"
        )
        
        # Custom layer selection if needed
        custom_layers = []
        if layer_emphasis == "custom" and model_available:
            st.write("Select Custom Layers:")
            
            # Get key layers (simplified approach)
            key_layers = get_key_layers(st.session_state.model, max_layers=15)
            
            # Show categories (early, middle, late, classifier)
            for category, layers in key_layers.items():
                if layers:
                    category_name = {
                        "early": "Early Layers (textures, colors)",
                        "middle": "Middle Layers (parts, components)",
                        "late": "Late Layers (objects, characters)",
                        "classifier": "Classifier (final recognition)"
                    }.get(category, category.capitalize())
                    
                    with st.expander(f"{category_name}", expanded=category in ["late", "classifier"]):
                        select_all = st.checkbox(f"Select all {category} layers", 
                                              key=f"select_all_{category}")
                        
                        for layer in layers:
                            # Create a shortened display name
                            parts = layer.split(".")
                            display_name = f"...{parts[-2]}.{parts[-1]}" if len(parts) > 3 else layer
                                
                            if select_all or st.checkbox(display_name, key=f"layer_{layer}"):
                                custom_layers.append(layer)
            
            # Show selected layers
            if custom_layers:
                st.success(f"Selected {len(custom_layers)} layers")
            else:
                st.warning("Please select at least one layer")
        
        # Save settings
        st.session_state.essence_custom_settings = {
            "scales": scales,
            "iterations": iterations,
            "image_size": 512,  # Fixed
            "lr": 0.03,  # Lower learning rate for better results
            "layer_emphasis": layer_emphasis,
            "custom_layers": custom_layers
        }
    
    with col2:
        # Show quality level descriptions
        st.write("Quality Levels:")
        for level, info in ESSENCE_QUALITY_LEVELS.items():
            st.markdown(f"""

            <div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({int(info['color'][1:3], 16)},{int(info['color'][3:5], 16)},{int(info['color'][5:7], 16)},0.1);border-left:3px solid {info['color']}">

                <span style="color:{info['color']};font-weight:bold;">{level}</span> ({info['threshold']:.0f} Score+): {info['description']}

            </div>

            """, unsafe_allow_html=True)
        
        # Feature targeting explanation
        st.write("Feature Targeting Explanation:")
        st.markdown("""

        ℹ️ **Feature targeting affects what the visualization emphasizes:**

        """)
    
    # Show current Enkephalin
    st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}")
    st.divider()
    
    # Add CSS for animations matching tag collection display
    st.markdown("""

    <style>

    @keyframes rainbow-text {

        0% { color: red; }

        14% { color: orange; }

        28% { color: yellow; }

        42% { color: green; }

        57% { color: blue; }

        71% { color: indigo; }

        85% { color: violet; }

        100% { color: red; }

    }

    

    .impuritas-text {

        font-weight: bold;

        animation: rainbow-text 4s linear infinite;

    }

    

    @keyframes glow-text {

        0% { text-shadow: 0 0 2px gold; }

        50% { text-shadow: 0 0 6px gold; }

        100% { text-shadow: 0 0 2px gold; }

    }

    

    .star-text {

        color: #FFEB3B;

        text-shadow: 0 0 3px gold;

        animation: glow-text 2s infinite;

        font-weight: bold;

    }

    

    @keyframes pulse-text {

        0% { opacity: 0.8; }

        50% { opacity: 1; }

        100% { opacity: 0.8; }

    }

    

    .nightmare-text {

        color: #FF9800;

        text-shadow: 0 0 1px #FF5722;

        animation: pulse-text 3s infinite;

        font-weight: bold;

    }

    

    .plague-text {

        color: #9C27B0;

        text-shadow: 0 0 1px #9C27B0;

        font-weight: bold;

    }

    

    .category-section {

        margin-top: 20px;

        margin-bottom: 30px;

        padding: 10px;

        border-radius: 5px;

        border-left: 5px solid;

    }

    </style>

    """, unsafe_allow_html=True)
    
    # ----- NEW TAG COLLECTION DISPLAY -----
    
    # Gather all tags for essence generation
    all_tags = []
    
    # Process discovered tags
    if hasattr(st.session_state, 'discovered_tags'):
        for tag, info in st.session_state.discovered_tags.items():
            tag_info = {
                "tag": tag,
                "rarity": info.get("rarity", "Unknown"),
                "category": info.get("category", "unknown"),
                "source": "discovered",
                "library_floor": info.get("library_floor", ""),
                "discovery_time": info.get("discovery_time", "")
            }
            all_tags.append(tag_info)
    
    # Process manual tags
    if hasattr(st.session_state, 'manual_tags'):
        for tag, info in st.session_state.manual_tags.items():
            tag_info = {
                "tag": tag,
                "rarity": info.get("rarity", "Special"),
                "category": info.get("category", "special"),
                "source": "manual",
                "description": info.get("description", "")
            }
            all_tags.append(tag_info)
    
    # Count tags by rarity
    rarity_counts = {}
    for info in all_tags:
        rarity = info["rarity"]
        if rarity not in rarity_counts:
            rarity_counts[rarity] = 0
        rarity_counts[rarity] += 1
    
    # Display rarity counts at the top
    st.subheader("Available Tags for Essence Generation")
    st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!")
    
    # Display rarity distribution
    rarity_cols = st.columns(max(len(rarity_counts), 1))  # st.columns(0) raises when no tags exist yet
    for i, (rarity, count) in enumerate(sorted(rarity_counts.items(), 
                                      key=lambda x: list(RARITY_LEVELS.keys()).index(x[0]) if x[0] in RARITY_LEVELS else 999)):
        with rarity_cols[i]:
            # Get color with fallback
            color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888")
            
            # Apply special styling based on rarity
            style = f"color:{color};font-weight:bold;"
            class_name = ""
            
            if rarity == "Impuritas Civitas":
                class_name = "grid-impuritas"
            elif rarity == "Star of the City":
                class_name = "grid-star"
            elif rarity == "Urban Nightmare":
                class_name = "grid-nightmare"
            elif rarity == "Urban Plague":
                class_name = "grid-plague"
            
            # Rarity names are already proper-cased; .capitalize() would lowercase
            # everything after the first letter (e.g. "Impuritas civitas")
            if class_name:
                st.markdown(
                    f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )
            else:
                st.markdown(
                    f"<div style='text-align:center;'><span style='{style}'>{rarity}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )
    
    # Search box for all tags
    search_term = st.text_input("Search tags", "", key="essence_search_tags")
    
    # Sort options
    sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"]
    selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort")
    
    # Filter tags by search term if provided
    if search_term:
        all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()]
    
    selected_tag = None
    
    # Sort and group tags based on selection
    if selected_sort == "Category (rarest first)":
        # Group tags by category
        categories = {}
        for info in all_tags:
            category = info["category"]
            if category not in categories:
                categories[category] = []
            categories[category].append(info)
        
        # Display tags by category in expanders
        for category, tags in sorted(categories.items()):
            # Sort tags by rarity (rarest first); unknown rarities sort last
            rarity_order = list(RARITY_LEVELS.keys())

            def get_rarity_index(info):
                rarity = info["rarity"]
                return rarity_order.index(rarity) if rarity in rarity_order else -1

            sorted_tags = sorted(tags, key=get_rarity_index, reverse=True)
            
            # Check if category has any rare tags
            has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"] 
                               for info in sorted_tags)
            
            # Get category info if available
            category_display = category.capitalize()
            if category in TAG_CATEGORIES:
                category_info = TAG_CATEGORIES[category]
                icon = category_info.get("icon", "")
                color = category_info.get("color", "#888888")
                category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>"
            
            # Create header with information about rare tags if present
            header = f"{category_display} ({len(tags)} tags)"
            if has_rare_tags:
                header += " ✨ Contains rare tags!"
                
            # Display category header and expander
            st.markdown(header, unsafe_allow_html=True)
            with st.expander("Show/Hide", expanded=has_rare_tags):
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(sorted_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        source = info["source"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Format tag display with special styling
                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'
                        
                        # Show tag with rarity badge and cost
                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )
                        
                        # Show discovery details if available
                        if source == "discovered" and "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                      unsafe_allow_html=True)
                        elif source == "manual" and "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                            selected_tag = tag
                            
    elif selected_sort == "Rarity":
        # Group tags by rarity
        rarity_groups = {}
        for info in all_tags:
            rarity = info["rarity"]
            if rarity not in rarity_groups:
                rarity_groups[rarity] = []
            rarity_groups[rarity].append(info)
        
        # Get ordered rarities (rarest first)
        ordered_rarities = list(RARITY_LEVELS.keys())
        ordered_rarities.reverse()  # Reverse to show rarest first
        
        # Add any rarities not in RARITY_LEVELS
        for rarity in rarity_groups.keys():
            if rarity not in ordered_rarities:
                ordered_rarities.append(rarity)
        
        # Display tags by rarity
        for rarity in ordered_rarities:
            if rarity in rarity_groups:
                tags = rarity_groups[rarity]
                color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                
                # Add special styling for rare rarities (rarity names are already
                # proper-cased, so render them as-is)
                rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity}</span>"
                if rarity == "Impuritas Civitas":
                    rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity}</span>"
                elif rarity == "Star of the City":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity}</span>"
                elif rarity == "Urban Nightmare":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity}</span>"
                
                # First create the title with HTML, then use it in the expander
                st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True)
                with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]):
                    # Create grid layout for tags
                    cols = st.columns(3)
                    for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])):
                        with cols[i % 3]:
                            tag = info["tag"]
                            source = info["source"]
                            
                            # Check if this tag has an essence already
                            has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                            
                            # Get cost for this tag
                            cost = get_essence_cost(rarity)
                            can_afford = st.session_state.enkephalin >= cost
                            
                            # Show tag with cost
                            st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")
                            
                            # Show discovery details if available
                            if source == "discovered" and "library_floor" in info and info["library_floor"]:
                                st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                          unsafe_allow_html=True)
                            elif source == "manual" and "description" in info and info["description"]:
                                st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                          unsafe_allow_html=True)
                            
                            # Add generation button
                            button_label = "Generate" if not has_essence else "Regenerate ✓"
                            if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                                selected_tag = tag
                            
    elif selected_sort == "Discovery Time":
        # Filter to just discovered tags (manual tags don't have discovery time)
        discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info]
        
        # Sort all tags by discovery time (newest first)
        sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True)
        
        # Group by date
        date_groups = {}
        for info in sorted_tags:
            time_str = info["discovery_time"]
            # Extract just the date part if timestamp has date and time
            date = time_str.split()[0] if " " in time_str else time_str
            
            if date not in date_groups:
                date_groups[date] = []
            date_groups[date].append(info)
        
        # Display tags grouped by discovery date
        for date, tags in date_groups.items():
            date_display = date if date else "Unknown date"
            st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)")
            
            with st.expander("Show/Hide", expanded=date == list(date_groups.keys())[0]):  # Expand most recent by default
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Format tag display with special styling
                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'
                        
                        # Show tag with rarity badge and cost
                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )
                        
                        # Show discovery details
                        if "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_disc", disabled=not model_available or not can_afford):
                            selected_tag = tag
        
        # Show manual tags separately if we have any
        manual_tags = [info for info in all_tags if info["source"] == "manual"]
        if manual_tags:
            st.markdown("### Manual Tags")
            with st.expander("Show/Hide"):
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(manual_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Show tag with rarity badge and cost
                        st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")
                        
                        # Show description if available
                        if "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_manual", disabled=not model_available or not can_afford):
                            selected_tag = tag
    
    return selected_tag
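
# A minimal, self-contained sketch of the rarity-ordering idea used by the tag
# browser above: tags are ranked by their rarity's position in the rarity
# mapping, with unknown rarities pushed to the end. _SAMPLE_RARITY_LEVELS is a
# hypothetical stand-in for the real RARITY_LEVELS dict, purely for illustration.
def _rarity_sort_key(rarity, rarity_levels):
    """Return a sortable index: known rarities keep their order, unknown go last."""
    order = list(rarity_levels.keys())
    return order.index(rarity) if rarity in order else 999

_SAMPLE_RARITY_LEVELS = {"Canard": {}, "Urban Myth": {}, "Impuritas Civitas": {}}
# Example: sorting ["Mystery", "Impuritas Civitas", "Canard"] with this key
# places the unknown "Mystery" rarity after all known rarities.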

def generate_essence_with_emphasis(model, tag_idx, tag_name=None, image_size=512,
                            iterations=256, scales=3, progress_callback=None,
                            layer_emphasis="mid", color_boost=1.5, tv_weight=5e-4):
    """
    Generate an essence visualization with specific layer emphasis and enhancements.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (for logging)
        image_size: Size of the output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 3)
        progress_callback: Optional callback for progress updates
        layer_emphasis: Type of layers to use ("auto", "balanced", "high", "mid", "low")
        color_boost: Factor for boosting color saturation (default: 1.5)
        tv_weight: Total variation weight (default: 5e-4)

    Returns:
        Tuple of (PIL Image of the generated essence, activation score)
    """
    # Create a tag-to-name mapping with the provided name
    tag_to_name = {tag_idx: tag_name} if tag_name else None
    
    # Determine layers to use if not auto
    layers_to_hook = None
    layer_weights = None
    
    if layer_emphasis != "auto":
        # Get layers based on the emphasis type
        layers_to_hook = get_suggested_layers(model, layer_emphasis)
        
        # Set layer weights based on position
        layer_weights = {}
        for i, layer in enumerate(layers_to_hook):
            # Base weight from position
            weight = 0.5 + 0.5 * (i / max(1, len(layers_to_hook) - 1))
            
            # Boost weight for classifier layers
            if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits"]):
                weight *= 1.5
                
            layer_weights[layer] = weight
            
        print(f"Using {len(layers_to_hook)} {layer_emphasis}-level layers")
    else:
        print("Using auto layer detection")
    
    # Create instance of the improved generator
    generator = EssenceGenerator(
        model=model,
        tag_to_name=tag_to_name,
        iterations=iterations,
        scales=scales,
        learning_rate=0.05,  # Lower for better convergence
        decay_power=1.0,     # Stronger decay power for cleaner images
        tv_weight=tv_weight, # Customizable TV weight
        layers_to_hook=layers_to_hook,
        layer_weights=layer_weights,
        color_boost=color_boost  # Customizable color boost
    )
    
    # Generate the essence
    print(f"Generating essence for tag {tag_name or tag_idx} with {layer_emphasis} emphasis...")
    image, score = generator.generate_essence(
        tag_idx=tag_idx,
        image_size=image_size,
        return_score=True,
        progress_callback=progress_callback
    )
    
    print(f"Essence generation complete. Score: {score:.4f}")
    return image, score
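
# Standalone illustration of the layer-weighting scheme implemented above:
# weights ramp linearly from 0.5 (earliest hooked layer) to 1.0 (latest), and
# classifier-style layers receive a further 1.5x boost. The layer names used in
# the test below are hypothetical examples, not taken from a real model.
def _compute_layer_weights(layer_names):
    """Mirror of the weighting loop in generate_essence_with_emphasis."""
    weights = {}
    for i, layer in enumerate(layer_names):
        # Base weight from position in the layer list
        weight = 0.5 + 0.5 * (i / max(1, len(layer_names) - 1))
        # Boost weight for classifier layers
        if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits"]):
            weight *= 1.5
        weights[layer] = weight
    return weights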

def try_different_layer_emphasis(model, tag_idx, tag_name=None, image_size=512,
                                iterations=256, scales=4, progress_callback=None):
    """
    Generate multiple essences with different layer emphasis types and return them all.

    Args:
        model: Neural network model to visualize
        tag_idx: Index of the tag to visualize
        tag_name: Optional name of the tag (for logging)
        image_size: Size of the output image (default: 512)
        iterations: Number of iterations per scale (default: 256)
        scales: Number of scales to use (default: 4)
        progress_callback: Optional callback for progress updates

    Returns:
        Dictionary mapping each layer emphasis type to its PIL Image and score
    """
    emphasis_types = [
        {"name": "low", "color_boost": 1.3, "tv_weight": 2e-4},   # Low-level features (textures, colors)
        {"name": "mid", "color_boost": 1.5, "tv_weight": 5e-4},   # Mid-level features (parts, components)
        {"name": "high", "color_boost": 1.7, "tv_weight": 8e-4},  # High-level features (characters, objects)
    ]
    
    results = {}
    
    for emphasis in emphasis_types:
        print(f"\n=== Trying {emphasis['name']} layer emphasis ===")
        
        image, score = generate_essence_with_emphasis(
            model=model,
            tag_idx=tag_idx,
            tag_name=tag_name,
            image_size=image_size,
            iterations=iterations,
            scales=scales,
            progress_callback=progress_callback,
            layer_emphasis=emphasis["name"],
            color_boost=emphasis["color_boost"],
            tv_weight=emphasis["tv_weight"]
        )
        
        results[emphasis["name"]] = {
            "image": image,
            "score": score
        }
        
        print(f"=== Completed {emphasis['name']} layer emphasis with score {score:.4f} ===")
    
    return results
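
# Example of consuming the results of try_different_layer_emphasis: each entry
# maps an emphasis name to {"image": ..., "score": ...}, so the best variant can
# be picked by score. The _sample_results dict below is fabricated purely for
# illustration (no model run); real entries would hold PIL Images.
def _best_emphasis(results):
    """Return (emphasis_name, entry) with the highest activation score."""
    return max(results.items(), key=lambda kv: kv[1]["score"])

_sample_results = {
    "low": {"image": None, "score": 0.42},
    "mid": {"image": None, "score": 0.77},
    "high": {"image": None, "score": 0.63},
}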