Spaces:
Paused
Paused
apolinario
commited on
Commit
•
a0bd9cc
1
Parent(s):
7f1aa40
Initial attempt Hypetron v2
Browse files- app.py +2346 -8
- requirements.txt +20 -1
app.py
CHANGED
@@ -1,11 +1,2349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
else:
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import math
|
4 |
+
from pathlib import Path
|
5 |
+
import sys
|
6 |
+
import pandas as pd
|
7 |
+
from base64 import b64encode
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
from PIL import Image
|
10 |
+
from taming.models import cond_transformer, vqgan
|
11 |
+
import torch
|
12 |
+
from os.path import exists as path_exists
|
13 |
+
|
14 |
+
torch.cuda.empty_cache()
|
15 |
+
from torch import nn
|
16 |
+
import torch.optim as optim
|
17 |
+
from torch import optim
|
18 |
+
from torch.nn import functional as F
|
19 |
+
from torchvision import transforms
|
20 |
+
from torchvision.transforms import functional as TF
|
21 |
+
import torchvision.transforms as T
|
22 |
+
|
23 |
+
from CLIP import clip
|
24 |
import gradio as gr
|
25 |
+
import kornia.augmentation as K
|
26 |
+
import numpy as np
|
27 |
+
import subprocess
|
28 |
+
import imageio
|
29 |
+
from PIL import ImageFile, Image
|
30 |
+
import time
|
31 |
+
|
32 |
+
import hashlib
|
33 |
+
from PIL.PngImagePlugin import PngImageFile, PngInfo
|
34 |
+
import json
|
35 |
+
import IPython
|
36 |
+
from IPython.display import Markdown, display, Image, clear_output
|
37 |
+
import urllib.request
|
38 |
+
import random
|
39 |
+
from random import randint
|
40 |
+
from pathvalidate import sanitize_filename
|
41 |
+
from huggingface_hub import hf_hub_download
|
42 |
+
|
43 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
44 |
+
print("Using device:", device)
|
45 |
+
|
46 |
+
vqgan_model = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="model.ckpt")
|
47 |
+
vqgan_config = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="config.yaml")
|
48 |
+
|
49 |
+
def load_vqgan_model(config_path, checkpoint_path):
|
50 |
+
config = OmegaConf.load(config_path)
|
51 |
+
if config.model.target == "taming.models.vqgan.VQModel":
|
52 |
+
model = vqgan.VQModel(**config.model.params)
|
53 |
+
model.eval().requires_grad_(False)
|
54 |
+
model.init_from_ckpt(checkpoint_path)
|
55 |
+
elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer":
|
56 |
+
parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
|
57 |
+
parent_model.eval().requires_grad_(False)
|
58 |
+
parent_model.init_from_ckpt(checkpoint_path)
|
59 |
+
model = parent_model.first_stage_model
|
60 |
+
elif config.model.target == "taming.models.vqgan.GumbelVQ":
|
61 |
+
model = vqgan.GumbelVQ(**config.model.params)
|
62 |
+
# print(config.model.params)
|
63 |
+
model.eval().requires_grad_(False)
|
64 |
+
model.init_from_ckpt(checkpoint_path)
|
65 |
+
else:
|
66 |
+
raise ValueError(f"unknown model type: {config.model.target}")
|
67 |
+
del model.loss
|
68 |
+
return model
|
69 |
+
model = load_vqgan_model(vqgan_config, vqgan_model).to(device)
|
70 |
+
perceptor = (
|
71 |
+
clip.load("ViT-B/32", jit=False)[0]
|
72 |
+
.eval()
|
73 |
+
.requires_grad_(False)
|
74 |
+
.to(device)
|
75 |
+
)
|
76 |
+
def run(user_input,num_steps, template, width,height):
|
77 |
+
#if uploaded_file is not None:
|
78 |
+
#uploaded_folder = f"{DefaultPaths.root_path}/uploaded"
|
79 |
+
#if not path_exists(uploaded_folder):
|
80 |
+
# os.makedirs(uploaded_folder)
|
81 |
+
#image_data = uploaded_file.read()
|
82 |
+
#f = open(f"{uploaded_folder}/{uploaded_file.name}", "wb")
|
83 |
+
#f.write(image_data)
|
84 |
+
#f.close()
|
85 |
+
#image_path = f"{uploaded_folder}/{uploaded_file.name}"
|
86 |
+
#pass
|
87 |
+
#else:
|
88 |
+
image_path = None
|
89 |
+
flavor = 'cumin'
|
90 |
+
|
91 |
+
args2 = argparse.Namespace(
|
92 |
+
prompt=user_input,
|
93 |
+
seed=int(seed),
|
94 |
+
sizex=width,
|
95 |
+
sizey=height,
|
96 |
+
flavor=flavor,
|
97 |
+
iterations=num_steps,
|
98 |
+
mse=True,
|
99 |
+
update=100,
|
100 |
+
template=template,
|
101 |
+
vqgan_model='ImageNet 16384',
|
102 |
+
seed_image=image_path,
|
103 |
+
image_file="progress.png",
|
104 |
+
#frame_dir=intermediary_folder,
|
105 |
+
)
|
106 |
+
if args2.seed is not None:
|
107 |
+
import torch
|
108 |
+
|
109 |
+
sys.stdout.write(f"Setting seed to {args2.seed} ...\n")
|
110 |
+
sys.stdout.flush()
|
111 |
+
import numpy as np
|
112 |
+
|
113 |
+
np.random.seed(args2.seed)
|
114 |
+
import random
|
115 |
+
|
116 |
+
random.seed(args2.seed)
|
117 |
+
# next line forces deterministic random values, but causes other issues with resampling (uncomment to see)
|
118 |
+
torch.manual_seed(args2.seed)
|
119 |
+
torch.cuda.manual_seed(args2.seed)
|
120 |
+
torch.cuda.manual_seed_all(args2.seed)
|
121 |
+
torch.backends.cudnn.deterministic = True
|
122 |
+
torch.backends.cudnn.benchmark = False
|
123 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
124 |
+
print("Using device:", device)
|
125 |
+
|
126 |
+
def noise_gen(shape, octaves=5):
|
127 |
+
n, c, h, w = shape
|
128 |
+
noise = torch.zeros([n, c, 1, 1])
|
129 |
+
max_octaves = min(octaves, math.log(h) / math.log(2), math.log(w) / math.log(2))
|
130 |
+
for i in reversed(range(max_octaves)):
|
131 |
+
h_cur, w_cur = h // 2**i, w // 2**i
|
132 |
+
noise = F.interpolate(
|
133 |
+
noise, (h_cur, w_cur), mode="bicubic", align_corners=False
|
134 |
+
)
|
135 |
+
noise += torch.randn([n, c, h_cur, w_cur]) / 5
|
136 |
+
return noise
|
137 |
+
|
138 |
+
def sinc(x):
|
139 |
+
return torch.where(
|
140 |
+
x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])
|
141 |
+
)
|
142 |
+
|
143 |
+
def lanczos(x, a):
|
144 |
+
cond = torch.logical_and(-a < x, x < a)
|
145 |
+
out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
|
146 |
+
return out / out.sum()
|
147 |
+
|
148 |
+
def ramp(ratio, width):
|
149 |
+
n = math.ceil(width / ratio + 1)
|
150 |
+
out = torch.empty([n])
|
151 |
+
cur = 0
|
152 |
+
for i in range(out.shape[0]):
|
153 |
+
out[i] = cur
|
154 |
+
cur += ratio
|
155 |
+
return torch.cat([-out[1:].flip([0]), out])[1:-1]
|
156 |
+
|
157 |
+
def resample(input, size, align_corners=True):
|
158 |
+
n, c, h, w = input.shape
|
159 |
+
dh, dw = size
|
160 |
+
|
161 |
+
input = input.view([n * c, 1, h, w])
|
162 |
+
|
163 |
+
if dh < h:
|
164 |
+
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
|
165 |
+
pad_h = (kernel_h.shape[0] - 1) // 2
|
166 |
+
input = F.pad(input, (0, 0, pad_h, pad_h), "reflect")
|
167 |
+
input = F.conv2d(input, kernel_h[None, None, :, None])
|
168 |
+
|
169 |
+
if dw < w:
|
170 |
+
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
|
171 |
+
pad_w = (kernel_w.shape[0] - 1) // 2
|
172 |
+
input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect")
|
173 |
+
input = F.conv2d(input, kernel_w[None, None, None, :])
|
174 |
+
|
175 |
+
input = input.view([n, c, h, w])
|
176 |
+
return F.interpolate(input, size, mode="bicubic", align_corners=align_corners)
|
177 |
+
|
178 |
+
def lerp(a, b, f):
|
179 |
+
return (a * (1.0 - f)) + (b * f)
|
180 |
+
|
181 |
+
class ReplaceGrad(torch.autograd.Function):
|
182 |
+
@staticmethod
|
183 |
+
def forward(ctx, x_forward, x_backward):
|
184 |
+
ctx.shape = x_backward.shape
|
185 |
+
return x_forward
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
def backward(ctx, grad_in):
|
189 |
+
return None, grad_in.sum_to_size(ctx.shape)
|
190 |
+
|
191 |
+
replace_grad = ReplaceGrad.apply
|
192 |
+
|
193 |
+
class ClampWithGrad(torch.autograd.Function):
|
194 |
+
@staticmethod
|
195 |
+
def forward(ctx, input, min, max):
|
196 |
+
ctx.min = min
|
197 |
+
ctx.max = max
|
198 |
+
ctx.save_for_backward(input)
|
199 |
+
return input.clamp(min, max)
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def backward(ctx, grad_in):
|
203 |
+
(input,) = ctx.saved_tensors
|
204 |
+
return (
|
205 |
+
grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
|
206 |
+
None,
|
207 |
+
None,
|
208 |
+
)
|
209 |
+
|
210 |
+
clamp_with_grad = ClampWithGrad.apply
|
211 |
+
|
212 |
+
def vector_quantize(x, codebook):
|
213 |
+
d = (
|
214 |
+
x.pow(2).sum(dim=-1, keepdim=True)
|
215 |
+
+ codebook.pow(2).sum(dim=1)
|
216 |
+
- 2 * x @ codebook.T
|
217 |
+
)
|
218 |
+
indices = d.argmin(-1)
|
219 |
+
x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
|
220 |
+
return replace_grad(x_q, x)
|
221 |
+
|
222 |
+
class Prompt(nn.Module):
|
223 |
+
def __init__(self, embed, weight=1.0, stop=float("-inf")):
|
224 |
+
super().__init__()
|
225 |
+
self.register_buffer("embed", embed)
|
226 |
+
self.register_buffer("weight", torch.as_tensor(weight))
|
227 |
+
self.register_buffer("stop", torch.as_tensor(stop))
|
228 |
+
|
229 |
+
def forward(self, input):
|
230 |
+
input_normed = F.normalize(input.unsqueeze(1), dim=2)
|
231 |
+
embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
|
232 |
+
dists = (
|
233 |
+
input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
|
234 |
+
)
|
235 |
+
dists = dists * self.weight.sign()
|
236 |
+
return (
|
237 |
+
self.weight.abs()
|
238 |
+
* replace_grad(dists, torch.maximum(dists, self.stop)).mean()
|
239 |
+
)
|
240 |
+
|
241 |
+
def parse_prompt(prompt):
|
242 |
+
if prompt.startswith("http://") or prompt.startswith("https://"):
|
243 |
+
vals = prompt.rsplit(":", 1)
|
244 |
+
vals = [vals[0] + ":" + vals[1], *vals[2:]]
|
245 |
+
else:
|
246 |
+
vals = prompt.rsplit(":", 1)
|
247 |
+
vals = vals + ["", "1", "-inf"][len(vals) :]
|
248 |
+
return vals[0], float(vals[1]), float(vals[2])
|
249 |
+
|
250 |
+
def one_sided_clip_loss(input, target, labels=None, logit_scale=100):
|
251 |
+
input_normed = F.normalize(input, dim=-1)
|
252 |
+
target_normed = F.normalize(target, dim=-1)
|
253 |
+
logits = input_normed @ target_normed.T * logit_scale
|
254 |
+
if labels is None:
|
255 |
+
labels = torch.arange(len(input), device=logits.device)
|
256 |
+
return F.cross_entropy(logits, labels)
|
257 |
+
|
258 |
+
class EMATensor(nn.Module):
|
259 |
+
"""implmeneted by Katherine Crowson"""
|
260 |
+
|
261 |
+
def __init__(self, tensor, decay):
|
262 |
+
super().__init__()
|
263 |
+
self.tensor = nn.Parameter(tensor)
|
264 |
+
self.register_buffer("biased", torch.zeros_like(tensor))
|
265 |
+
self.register_buffer("average", torch.zeros_like(tensor))
|
266 |
+
self.decay = decay
|
267 |
+
self.register_buffer("accum", torch.tensor(1.0))
|
268 |
+
self.update()
|
269 |
+
|
270 |
+
@torch.no_grad()
|
271 |
+
def update(self):
|
272 |
+
if not self.training:
|
273 |
+
raise RuntimeError("update() should only be called during training")
|
274 |
+
|
275 |
+
self.accum *= self.decay
|
276 |
+
self.biased.mul_(self.decay)
|
277 |
+
self.biased.add_((1 - self.decay) * self.tensor)
|
278 |
+
self.average.copy_(self.biased)
|
279 |
+
self.average.div_(1 - self.accum)
|
280 |
+
|
281 |
+
def forward(self):
|
282 |
+
if self.training:
|
283 |
+
return self.tensor
|
284 |
+
return self.average
|
285 |
+
|
286 |
+
class MakeCutoutsCustom(nn.Module):
|
287 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
288 |
+
super().__init__()
|
289 |
+
self.cut_size = cut_size
|
290 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
291 |
+
self.cutn = cutn
|
292 |
+
self.cut_pow = cut_pow
|
293 |
+
self.noise_fac = 0.1
|
294 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
295 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
296 |
+
self.augs = nn.Sequential(
|
297 |
+
K.RandomHorizontalFlip(p=Random_Horizontal_Flip),
|
298 |
+
K.RandomSharpness(Random_Sharpness, p=Random_Sharpness_P),
|
299 |
+
K.RandomGaussianBlur(
|
300 |
+
(Random_Gaussian_Blur),
|
301 |
+
(Random_Gaussian_Blur_W, Random_Gaussian_Blur_W),
|
302 |
+
p=Random_Gaussian_Blur_P,
|
303 |
+
),
|
304 |
+
K.RandomGaussianNoise(p=Random_Gaussian_Noise_P),
|
305 |
+
K.RandomElasticTransform(
|
306 |
+
kernel_size=(
|
307 |
+
Random_Elastic_Transform_Kernel_Size_W,
|
308 |
+
Random_Elastic_Transform_Kernel_Size_H,
|
309 |
+
),
|
310 |
+
sigma=(Random_Elastic_Transform_Sigma),
|
311 |
+
p=Random_Elastic_Transform_P,
|
312 |
+
),
|
313 |
+
K.RandomAffine(
|
314 |
+
degrees=Random_Affine_Degrees,
|
315 |
+
translate=Random_Affine_Translate,
|
316 |
+
p=Random_Affine_P,
|
317 |
+
padding_mode="border",
|
318 |
+
),
|
319 |
+
K.RandomPerspective(Random_Perspective, p=Random_Perspective_P),
|
320 |
+
K.ColorJitter(
|
321 |
+
hue=Color_Jitter_Hue,
|
322 |
+
saturation=Color_Jitter_Saturation,
|
323 |
+
p=Color_Jitter_P,
|
324 |
+
),
|
325 |
+
)
|
326 |
+
# K.RandomErasing((0.1, 0.7), (0.3, 1/0.4), same_on_batch=True, p=0.2),)
|
327 |
+
|
328 |
+
def set_cut_pow(self, cut_pow):
|
329 |
+
self.cut_pow = cut_pow
|
330 |
+
|
331 |
+
def forward(self, input):
|
332 |
+
sideY, sideX = input.shape[2:4]
|
333 |
+
max_size = min(sideX, sideY)
|
334 |
+
min_size = min(sideX, sideY, self.cut_size)
|
335 |
+
cutouts = []
|
336 |
+
cutouts_full = []
|
337 |
+
noise_fac = 0.1
|
338 |
+
|
339 |
+
min_size_width = min(sideX, sideY)
|
340 |
+
lower_bound = float(self.cut_size / min_size_width)
|
341 |
+
|
342 |
+
for ii in range(self.cutn):
|
343 |
+
|
344 |
+
# size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
345 |
+
randsize = (
|
346 |
+
torch.zeros(
|
347 |
+
1,
|
348 |
+
)
|
349 |
+
.normal_(mean=0.8, std=0.3)
|
350 |
+
.clip(lower_bound, 1.0)
|
351 |
+
)
|
352 |
+
size_mult = randsize**self.cut_pow
|
353 |
+
size = int(
|
354 |
+
min_size_width * (size_mult.clip(lower_bound, 1.0))
|
355 |
+
) # replace .5 with a result for 224 the default large size is .95
|
356 |
+
# size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
|
357 |
+
|
358 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
359 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
360 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
361 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
362 |
+
|
363 |
+
cutouts = torch.cat(cutouts, dim=0)
|
364 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
365 |
+
|
366 |
+
# if args.use_augs:
|
367 |
+
cutouts = self.augs(cutouts)
|
368 |
+
if self.noise_fac:
|
369 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
370 |
+
0, self.noise_fac
|
371 |
+
)
|
372 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
373 |
+
return cutouts
|
374 |
+
|
375 |
+
class MakeCutoutsJuu(nn.Module):
|
376 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
377 |
+
super().__init__()
|
378 |
+
self.cut_size = cut_size
|
379 |
+
self.cutn = cutn
|
380 |
+
self.cut_pow = cut_pow
|
381 |
+
self.augs = nn.Sequential(
|
382 |
+
# K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
|
383 |
+
K.RandomHorizontalFlip(p=0.5),
|
384 |
+
K.RandomSharpness(0.3, p=0.4),
|
385 |
+
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
|
386 |
+
K.RandomPerspective(0.2, p=0.4),
|
387 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
388 |
+
K.RandomGrayscale(p=0.1),
|
389 |
+
)
|
390 |
+
self.noise_fac = 0.1
|
391 |
+
|
392 |
+
def forward(self, input):
|
393 |
+
sideY, sideX = input.shape[2:4]
|
394 |
+
max_size = min(sideX, sideY)
|
395 |
+
min_size = min(sideX, sideY, self.cut_size)
|
396 |
+
cutouts = []
|
397 |
+
for _ in range(self.cutn):
|
398 |
+
size = int(
|
399 |
+
torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
|
400 |
+
)
|
401 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
402 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
403 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
404 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
405 |
+
batch = self.augs(torch.cat(cutouts, dim=0))
|
406 |
+
if self.noise_fac:
|
407 |
+
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
|
408 |
+
batch = batch + facs * torch.randn_like(batch)
|
409 |
+
return batch
|
410 |
+
|
411 |
+
class MakeCutoutsMoth(nn.Module):
|
412 |
+
def __init__(self, cut_size, cutn, cut_pow, augs, skip_augs=False):
|
413 |
+
super().__init__()
|
414 |
+
self.cut_size = cut_size
|
415 |
+
self.cutn = cutn
|
416 |
+
self.cut_pow = cut_pow
|
417 |
+
self.skip_augs = skip_augs
|
418 |
+
self.augs = T.Compose(
|
419 |
+
[
|
420 |
+
T.RandomHorizontalFlip(p=0.5),
|
421 |
+
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
|
422 |
+
T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
|
423 |
+
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
|
424 |
+
T.RandomPerspective(distortion_scale=0.4, p=0.7),
|
425 |
+
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
|
426 |
+
T.RandomGrayscale(p=0.15),
|
427 |
+
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
|
428 |
+
# T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
|
429 |
+
]
|
430 |
+
)
|
431 |
+
|
432 |
+
def forward(self, input):
|
433 |
+
input = T.Pad(input.shape[2] // 4, fill=0)(input)
|
434 |
+
sideY, sideX = input.shape[2:4]
|
435 |
+
max_size = min(sideX, sideY)
|
436 |
+
|
437 |
+
cutouts = []
|
438 |
+
for ch in range(cutn):
|
439 |
+
if ch > cutn - cutn // 4:
|
440 |
+
cutout = input.clone()
|
441 |
+
else:
|
442 |
+
size = int(
|
443 |
+
max_size
|
444 |
+
* torch.zeros(
|
445 |
+
1,
|
446 |
+
)
|
447 |
+
.normal_(mean=0.8, std=0.3)
|
448 |
+
.clip(float(self.cut_size / max_size), 1.0)
|
449 |
+
)
|
450 |
+
offsetx = torch.randint(0, abs(sideX - size + 1), ())
|
451 |
+
offsety = torch.randint(0, abs(sideY - size + 1), ())
|
452 |
+
cutout = input[
|
453 |
+
:, :, offsety : offsety + size, offsetx : offsetx + size
|
454 |
+
]
|
455 |
+
|
456 |
+
if not self.skip_augs:
|
457 |
+
cutout = self.augs(cutout)
|
458 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
459 |
+
del cutout
|
460 |
+
|
461 |
+
cutouts = torch.cat(cutouts, dim=0)
|
462 |
+
return cutouts
|
463 |
+
|
464 |
+
class MakeCutoutsAaron(nn.Module):
|
465 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
466 |
+
super().__init__()
|
467 |
+
self.cut_size = cut_size
|
468 |
+
self.cutn = cutn
|
469 |
+
self.cut_pow = cut_pow
|
470 |
+
self.augs = augs
|
471 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
472 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
473 |
+
|
474 |
+
def set_cut_pow(self, cut_pow):
|
475 |
+
self.cut_pow = cut_pow
|
476 |
+
|
477 |
+
def forward(self, input):
|
478 |
+
sideY, sideX = input.shape[2:4]
|
479 |
+
max_size = min(sideX, sideY)
|
480 |
+
min_size = min(sideX, sideY, self.cut_size)
|
481 |
+
cutouts = []
|
482 |
+
cutouts_full = []
|
483 |
+
|
484 |
+
min_size_width = min(sideX, sideY)
|
485 |
+
lower_bound = float(self.cut_size / min_size_width)
|
486 |
+
|
487 |
+
for ii in range(self.cutn):
|
488 |
+
size = int(
|
489 |
+
min_size_width
|
490 |
+
* torch.zeros(
|
491 |
+
1,
|
492 |
+
)
|
493 |
+
.normal_(mean=0.8, std=0.3)
|
494 |
+
.clip(lower_bound, 1.0)
|
495 |
+
) # replace .5 with a result for 224 the default large size is .95
|
496 |
+
|
497 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
498 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
499 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
500 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
501 |
+
|
502 |
+
cutouts = torch.cat(cutouts, dim=0)
|
503 |
+
|
504 |
+
return clamp_with_grad(cutouts, 0, 1)
|
505 |
+
|
506 |
+
class MakeCutoutsCumin(nn.Module):
|
507 |
+
# from https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ
|
508 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
509 |
+
super().__init__()
|
510 |
+
self.cut_size = cut_size
|
511 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
512 |
+
self.cutn = cutn
|
513 |
+
self.cut_pow = cut_pow
|
514 |
+
self.noise_fac = 0.1
|
515 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
516 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
517 |
+
self.augs = nn.Sequential(
|
518 |
+
# K.RandomHorizontalFlip(p=0.5),
|
519 |
+
# K.RandomSharpness(0.3,p=0.4),
|
520 |
+
# K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
|
521 |
+
# K.RandomGaussianNoise(p=0.5),
|
522 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
523 |
+
K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode="border"),
|
524 |
+
K.RandomPerspective(0.7, p=0.7),
|
525 |
+
K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
|
526 |
+
K.RandomErasing((0.1, 0.4), (0.3, 1 / 0.3), same_on_batch=True, p=0.7),
|
527 |
+
)
|
528 |
+
|
529 |
+
def set_cut_pow(self, cut_pow):
|
530 |
+
self.cut_pow = cut_pow
|
531 |
+
|
532 |
+
def forward(self, input):
|
533 |
+
sideY, sideX = input.shape[2:4]
|
534 |
+
max_size = min(sideX, sideY)
|
535 |
+
min_size = min(sideX, sideY, self.cut_size)
|
536 |
+
cutouts = []
|
537 |
+
cutouts_full = []
|
538 |
+
noise_fac = 0.1
|
539 |
+
|
540 |
+
min_size_width = min(sideX, sideY)
|
541 |
+
lower_bound = float(self.cut_size / min_size_width)
|
542 |
+
|
543 |
+
for ii in range(self.cutn):
|
544 |
+
|
545 |
+
# size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
546 |
+
randsize = (
|
547 |
+
torch.zeros(
|
548 |
+
1,
|
549 |
+
)
|
550 |
+
.normal_(mean=0.8, std=0.3)
|
551 |
+
.clip(lower_bound, 1.0)
|
552 |
+
)
|
553 |
+
size_mult = randsize**self.cut_pow
|
554 |
+
size = int(
|
555 |
+
min_size_width * (size_mult.clip(lower_bound, 1.0))
|
556 |
+
) # replace .5 with a result for 224 the default large size is .95
|
557 |
+
# size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
|
558 |
+
|
559 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
560 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
561 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
562 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
563 |
+
|
564 |
+
cutouts = torch.cat(cutouts, dim=0)
|
565 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
566 |
+
|
567 |
+
# if args.use_augs:
|
568 |
+
cutouts = self.augs(cutouts)
|
569 |
+
if self.noise_fac:
|
570 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
571 |
+
0, self.noise_fac
|
572 |
+
)
|
573 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
574 |
+
return cutouts
|
575 |
+
|
576 |
+
class MakeCutoutsHolywater(nn.Module):
|
577 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
578 |
+
super().__init__()
|
579 |
+
self.cut_size = cut_size
|
580 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
581 |
+
self.cutn = cutn
|
582 |
+
self.cut_pow = cut_pow
|
583 |
+
self.noise_fac = 0.1
|
584 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
585 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
586 |
+
self.augs = nn.Sequential(
|
587 |
+
# K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
|
588 |
+
K.RandomHorizontalFlip(p=0.5),
|
589 |
+
K.RandomSharpness(0.3, p=0.4),
|
590 |
+
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
|
591 |
+
K.RandomPerspective(0.2, p=0.4),
|
592 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
593 |
+
K.RandomGrayscale(p=0.1),
|
594 |
+
)
|
595 |
+
|
596 |
+
def set_cut_pow(self, cut_pow):
|
597 |
+
self.cut_pow = cut_pow
|
598 |
+
|
599 |
+
def forward(self, input):
|
600 |
+
sideY, sideX = input.shape[2:4]
|
601 |
+
max_size = min(sideX, sideY)
|
602 |
+
min_size = min(sideX, sideY, self.cut_size)
|
603 |
+
cutouts = []
|
604 |
+
cutouts_full = []
|
605 |
+
noise_fac = 0.1
|
606 |
+
min_size_width = min(sideX, sideY)
|
607 |
+
lower_bound = float(self.cut_size / min_size_width)
|
608 |
+
|
609 |
+
for ii in range(self.cutn):
|
610 |
+
size = int(
|
611 |
+
torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
|
612 |
+
)
|
613 |
+
randsize = (
|
614 |
+
torch.zeros(
|
615 |
+
1,
|
616 |
+
)
|
617 |
+
.normal_(mean=0.8, std=0.3)
|
618 |
+
.clip(lower_bound, 1.0)
|
619 |
+
)
|
620 |
+
size_mult = randsize**self.cut_pow * ii + size
|
621 |
+
size1 = int(
|
622 |
+
(min_size_width) * (size_mult.clip(lower_bound, 1.0))
|
623 |
+
) # replace .5 with a result for 224 the default large size is .95
|
624 |
+
size2 = int(
|
625 |
+
(min_size_width)
|
626 |
+
* torch.zeros(
|
627 |
+
1,
|
628 |
+
)
|
629 |
+
.normal_(mean=0.9, std=0.3)
|
630 |
+
.clip(lower_bound, 0.95)
|
631 |
+
) # replace .5 with a result for 224 the default large size is .95
|
632 |
+
offsetx = torch.randint(0, sideX - size1 + 1, ())
|
633 |
+
offsety = torch.randint(0, sideY - size2 + 1, ())
|
634 |
+
cutout = input[
|
635 |
+
:, :, offsety : offsety + size2 + ii, offsetx : offsetx + size1 + ii
|
636 |
+
]
|
637 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
638 |
+
|
639 |
+
cutouts = torch.cat(cutouts, dim=0)
|
640 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
641 |
+
cutouts = self.augs(cutouts)
|
642 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
643 |
+
0, self.noise_fac
|
644 |
+
)
|
645 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
646 |
+
return cutouts
|
647 |
+
|
648 |
+
class MakeCutoutsOldHolywater(nn.Module):
|
649 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
650 |
+
super().__init__()
|
651 |
+
self.cut_size = cut_size
|
652 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
653 |
+
self.cutn = cutn
|
654 |
+
self.cut_pow = cut_pow
|
655 |
+
self.noise_fac = 0.1
|
656 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
657 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
658 |
+
self.augs = nn.Sequential(
|
659 |
+
# K.RandomHorizontalFlip(p=0.5),
|
660 |
+
# K.RandomSharpness(0.3,p=0.4),
|
661 |
+
# K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
|
662 |
+
# K.RandomGaussianNoise(p=0.5),
|
663 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
664 |
+
K.RandomAffine(
|
665 |
+
degrees=180, translate=0.5, p=0.2, padding_mode="border"
|
666 |
+
),
|
667 |
+
K.RandomPerspective(0.6, p=0.9),
|
668 |
+
K.ColorJitter(hue=0.03, saturation=0.01, p=0.1),
|
669 |
+
K.RandomErasing((0.1, 0.7), (0.3, 1 / 0.4), same_on_batch=True, p=0.2),
|
670 |
+
)
|
671 |
+
|
672 |
+
def set_cut_pow(self, cut_pow):
|
673 |
+
self.cut_pow = cut_pow
|
674 |
+
|
675 |
+
def forward(self, input):
|
676 |
+
sideY, sideX = input.shape[2:4]
|
677 |
+
max_size = min(sideX, sideY)
|
678 |
+
min_size = min(sideX, sideY, self.cut_size)
|
679 |
+
cutouts = []
|
680 |
+
cutouts_full = []
|
681 |
+
noise_fac = 0.1
|
682 |
+
|
683 |
+
min_size_width = min(sideX, sideY)
|
684 |
+
lower_bound = float(self.cut_size / min_size_width)
|
685 |
+
|
686 |
+
for ii in range(self.cutn):
|
687 |
+
|
688 |
+
# size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
689 |
+
randsize = (
|
690 |
+
torch.zeros(
|
691 |
+
1,
|
692 |
+
)
|
693 |
+
.normal_(mean=0.8, std=0.3)
|
694 |
+
.clip(lower_bound, 1.0)
|
695 |
+
)
|
696 |
+
size_mult = randsize**self.cut_pow
|
697 |
+
size = int(
|
698 |
+
min_size_width * (size_mult.clip(lower_bound, 1.0))
|
699 |
+
) # replace .5 with a result for 224 the default large size is .95
|
700 |
+
# size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
|
701 |
+
|
702 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
703 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
704 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
705 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
706 |
+
|
707 |
+
cutouts = torch.cat(cutouts, dim=0)
|
708 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
709 |
+
|
710 |
+
# if args.use_augs:
|
711 |
+
cutouts = self.augs(cutouts)
|
712 |
+
if self.noise_fac:
|
713 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
714 |
+
0, self.noise_fac
|
715 |
+
)
|
716 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
717 |
+
return cutouts
|
718 |
+
|
719 |
+
class MakeCutoutsGinger(nn.Module):
|
720 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
721 |
+
super().__init__()
|
722 |
+
self.cut_size = cut_size
|
723 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
724 |
+
self.cutn = cutn
|
725 |
+
self.cut_pow = cut_pow
|
726 |
+
self.noise_fac = 0.1
|
727 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
728 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
729 |
+
self.augs = augs
|
730 |
+
"""
|
731 |
+
nn.Sequential(
|
732 |
+
K.RandomHorizontalFlip(p=0.5),
|
733 |
+
K.RandomSharpness(0.3,p=0.4),
|
734 |
+
K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
|
735 |
+
K.RandomGaussianNoise(p=0.5),
|
736 |
+
K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
737 |
+
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), # padding_mode=2
|
738 |
+
K.RandomPerspective(0.2,p=0.4, ),
|
739 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),)
|
740 |
+
"""
|
741 |
+
|
742 |
+
def set_cut_pow(self, cut_pow):
|
743 |
+
self.cut_pow = cut_pow
|
744 |
+
|
745 |
+
def forward(self, input):
|
746 |
+
sideY, sideX = input.shape[2:4]
|
747 |
+
max_size = min(sideX, sideY)
|
748 |
+
min_size = min(sideX, sideY, self.cut_size)
|
749 |
+
cutouts = []
|
750 |
+
cutouts_full = []
|
751 |
+
noise_fac = 0.1
|
752 |
+
|
753 |
+
min_size_width = min(sideX, sideY)
|
754 |
+
lower_bound = float(self.cut_size / min_size_width)
|
755 |
+
|
756 |
+
for ii in range(self.cutn):
|
757 |
+
|
758 |
+
# size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
759 |
+
randsize = (
|
760 |
+
torch.zeros(
|
761 |
+
1,
|
762 |
+
)
|
763 |
+
.normal_(mean=0.8, std=0.3)
|
764 |
+
.clip(lower_bound, 1.0)
|
765 |
+
)
|
766 |
+
size_mult = randsize**self.cut_pow
|
767 |
+
size = int(
|
768 |
+
min_size_width * (size_mult.clip(lower_bound, 1.0))
|
769 |
+
) # replace .5 with a result for 224 the default large size is .95
|
770 |
+
# size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
|
771 |
+
|
772 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
773 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
774 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
775 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
776 |
+
|
777 |
+
cutouts = torch.cat(cutouts, dim=0)
|
778 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
779 |
+
|
780 |
+
# if args.use_augs:
|
781 |
+
cutouts = self.augs(cutouts)
|
782 |
+
if self.noise_fac:
|
783 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
784 |
+
0, self.noise_fac
|
785 |
+
)
|
786 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
787 |
+
return cutouts
|
788 |
+
|
789 |
+
class MakeCutoutsZynth(nn.Module):
|
790 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
791 |
+
super().__init__()
|
792 |
+
self.cut_size = cut_size
|
793 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
794 |
+
self.cutn = cutn
|
795 |
+
self.cut_pow = cut_pow
|
796 |
+
self.noise_fac = 0.1
|
797 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
798 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
799 |
+
self.augs = nn.Sequential(
|
800 |
+
K.RandomHorizontalFlip(p=0.5),
|
801 |
+
# K.RandomSolarize(0.01, 0.01, p=0.7),
|
802 |
+
K.RandomSharpness(0.3, p=0.4),
|
803 |
+
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
|
804 |
+
K.RandomPerspective(0.2, p=0.4),
|
805 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
806 |
+
)
|
807 |
+
|
808 |
+
def set_cut_pow(self, cut_pow):
|
809 |
+
self.cut_pow = cut_pow
|
810 |
+
|
811 |
+
def forward(self, input):
|
812 |
+
sideY, sideX = input.shape[2:4]
|
813 |
+
max_size = min(sideX, sideY)
|
814 |
+
min_size = min(sideX, sideY, self.cut_size)
|
815 |
+
cutouts = []
|
816 |
+
cutouts_full = []
|
817 |
+
noise_fac = 0.1
|
818 |
+
|
819 |
+
min_size_width = min(sideX, sideY)
|
820 |
+
lower_bound = float(self.cut_size / min_size_width)
|
821 |
+
|
822 |
+
for ii in range(self.cutn):
|
823 |
+
|
824 |
+
# size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
825 |
+
randsize = (
|
826 |
+
torch.zeros(
|
827 |
+
1,
|
828 |
+
)
|
829 |
+
.normal_(mean=0.8, std=0.3)
|
830 |
+
.clip(lower_bound, 1.0)
|
831 |
+
)
|
832 |
+
size_mult = randsize**self.cut_pow
|
833 |
+
size = int(
|
834 |
+
min_size_width * (size_mult.clip(lower_bound, 1.0))
|
835 |
+
) # replace .5 with a result for 224 the default large size is .95
|
836 |
+
# size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
|
837 |
+
|
838 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
839 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
840 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
841 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
842 |
+
|
843 |
+
cutouts = torch.cat(cutouts, dim=0)
|
844 |
+
cutouts = clamp_with_grad(cutouts, 0, 1)
|
845 |
+
|
846 |
+
# if args.use_augs:
|
847 |
+
cutouts = self.augs(cutouts)
|
848 |
+
if self.noise_fac:
|
849 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
850 |
+
0, self.noise_fac
|
851 |
+
)
|
852 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
853 |
+
return cutouts
|
854 |
+
|
855 |
+
class MakeCutoutsWyvern(nn.Module):
|
856 |
+
def __init__(self, cut_size, cutn, cut_pow, augs):
|
857 |
+
super().__init__()
|
858 |
+
self.cut_size = cut_size
|
859 |
+
# tqdm.write(f"cut size: {self.cut_size}")
|
860 |
+
self.cutn = cutn
|
861 |
+
self.cut_pow = cut_pow
|
862 |
+
self.noise_fac = 0.1
|
863 |
+
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
|
864 |
+
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
|
865 |
+
self.augs = augs
|
866 |
+
|
867 |
+
def forward(self, input):
|
868 |
+
sideY, sideX = input.shape[2:4]
|
869 |
+
max_size = min(sideX, sideY)
|
870 |
+
min_size = min(sideX, sideY, self.cut_size)
|
871 |
+
cutouts = []
|
872 |
+
for _ in range(self.cutn):
|
873 |
+
size = int(
|
874 |
+
torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
|
875 |
+
)
|
876 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
877 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
878 |
+
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
|
879 |
+
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
|
880 |
+
return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
|
881 |
+
|
882 |
+
|
883 |
+
import PIL
|
884 |
+
|
885 |
+
def resize_image(image, out_size):
|
886 |
+
ratio = image.size[0] / image.size[1]
|
887 |
+
area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
|
888 |
+
size = round((area * ratio) ** 0.5), round((area / ratio) ** 0.5)
|
889 |
+
return image.resize(size, PIL.Image.LANCZOS)
|
890 |
+
|
891 |
+
class GaussianBlur2d(nn.Module):
|
892 |
+
def __init__(self, sigma, window=0, mode="reflect", value=0):
|
893 |
+
super().__init__()
|
894 |
+
self.mode = mode
|
895 |
+
self.value = value
|
896 |
+
if not window:
|
897 |
+
window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
|
898 |
+
if sigma:
|
899 |
+
kernel = torch.exp(
|
900 |
+
-((torch.arange(window) - window // 2) ** 2) / 2 / sigma**2
|
901 |
+
)
|
902 |
+
kernel /= kernel.sum()
|
903 |
+
else:
|
904 |
+
kernel = torch.ones([1])
|
905 |
+
self.register_buffer("kernel", kernel)
|
906 |
+
|
907 |
+
def forward(self, input):
|
908 |
+
n, c, h, w = input.shape
|
909 |
+
input = input.view([n * c, 1, h, w])
|
910 |
+
start_pad = (self.kernel.shape[0] - 1) // 2
|
911 |
+
end_pad = self.kernel.shape[0] // 2
|
912 |
+
input = F.pad(
|
913 |
+
input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value
|
914 |
+
)
|
915 |
+
input = F.conv2d(input, self.kernel[None, None, None, :])
|
916 |
+
input = F.conv2d(input, self.kernel[None, None, :, None])
|
917 |
+
return input.view([n, c, h, w])
|
918 |
+
|
919 |
+
BUF_SIZE = 65536
|
920 |
+
|
921 |
+
def get_digest(path, alg=hashlib.sha256):
|
922 |
+
hash = alg()
|
923 |
+
# print(path)
|
924 |
+
with open(path, "rb") as fp:
|
925 |
+
while True:
|
926 |
+
data = fp.read(BUF_SIZE)
|
927 |
+
if not data:
|
928 |
+
break
|
929 |
+
hash.update(data)
|
930 |
+
return b64encode(hash.digest()).decode("utf-8")
|
931 |
+
|
932 |
+
flavordict = {
|
933 |
+
"cumin": MakeCutoutsCumin,
|
934 |
+
"holywater": MakeCutoutsHolywater,
|
935 |
+
"old_holywater": MakeCutoutsOldHolywater,
|
936 |
+
"ginger": MakeCutoutsGinger,
|
937 |
+
"zynth": MakeCutoutsZynth,
|
938 |
+
"wyvern": MakeCutoutsWyvern,
|
939 |
+
"aaron": MakeCutoutsAaron,
|
940 |
+
"moth": MakeCutoutsMoth,
|
941 |
+
"juu": MakeCutoutsJuu,
|
942 |
+
"custom": MakeCutoutsCustom,
|
943 |
+
}
|
944 |
+
|
945 |
+
@torch.jit.script
|
946 |
+
def gelu_impl(x):
|
947 |
+
"""OpenAI's gelu implementation."""
|
948 |
+
return (
|
949 |
+
0.5
|
950 |
+
* x
|
951 |
+
* (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
|
952 |
+
)
|
953 |
+
|
954 |
+
def gelu(x):
|
955 |
+
return gelu_impl(x)
|
956 |
+
|
957 |
+
class MSEDecayLoss(nn.Module):
|
958 |
+
def __init__(self, init_weight, mse_decay_rate, mse_epoches, mse_quantize):
|
959 |
+
super().__init__()
|
960 |
+
|
961 |
+
self.init_weight = init_weight
|
962 |
+
self.has_init_image = False
|
963 |
+
self.mse_decay = init_weight / mse_epoches if init_weight else 0
|
964 |
+
self.mse_decay_rate = mse_decay_rate
|
965 |
+
self.mse_weight = init_weight
|
966 |
+
self.mse_epoches = mse_epoches
|
967 |
+
self.mse_quantize = mse_quantize
|
968 |
+
|
969 |
+
@torch.no_grad()
|
970 |
+
def set_target(self, z_tensor, model):
|
971 |
+
z_tensor = z_tensor.detach().clone()
|
972 |
+
if self.mse_quantize:
|
973 |
+
z_tensor = vector_quantize(
|
974 |
+
z_tensor.movedim(1, 3), model.quantize.embedding.weight
|
975 |
+
).movedim(
|
976 |
+
3, 1
|
977 |
+
) # z.average
|
978 |
+
self.z_orig = z_tensor
|
979 |
+
|
980 |
+
def forward(self, i, z):
|
981 |
+
if self.is_active(i):
|
982 |
+
return F.mse_loss(z, self.z_orig) * self.mse_weight / 2
|
983 |
+
return 0
|
984 |
+
|
985 |
+
def is_active(self, i):
|
986 |
+
if not self.init_weight:
|
987 |
+
return False
|
988 |
+
if i <= self.mse_decay_rate and not self.has_init_image:
|
989 |
+
return False
|
990 |
+
return True
|
991 |
+
|
992 |
+
@torch.no_grad()
|
993 |
+
def step(self, i):
|
994 |
+
|
995 |
+
if (
|
996 |
+
i % self.mse_decay_rate == 0
|
997 |
+
and i != 0
|
998 |
+
and i < self.mse_decay_rate * self.mse_epoches
|
999 |
+
):
|
1000 |
+
|
1001 |
+
if (
|
1002 |
+
self.mse_weight - self.mse_decay > 0
|
1003 |
+
and self.mse_weight - self.mse_decay >= self.mse_decay
|
1004 |
+
):
|
1005 |
+
self.mse_weight -= self.mse_decay
|
1006 |
+
else:
|
1007 |
+
self.mse_weight = 0
|
1008 |
+
# print(f"updated mse weight: {self.mse_weight}")
|
1009 |
+
|
1010 |
+
return True
|
1011 |
+
|
1012 |
+
return False
|
1013 |
+
|
1014 |
+
class TVLoss(nn.Module):
|
1015 |
+
def forward(self, input):
|
1016 |
+
input = F.pad(input, (0, 1, 0, 1), "replicate")
|
1017 |
+
x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
|
1018 |
+
y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
|
1019 |
+
diff = x_diff**2 + y_diff**2 + 1e-8
|
1020 |
+
return diff.mean(dim=1).sqrt().mean()
|
1021 |
+
|
1022 |
+
class MultiClipLoss(nn.Module):
|
1023 |
+
def __init__(
|
1024 |
+
self, clip_models, text_prompt, cutn, cut_pow=1.0, clip_weight=1.0
|
1025 |
+
):
|
1026 |
+
super().__init__()
|
1027 |
+
|
1028 |
+
# Load Clip
|
1029 |
+
self.perceptors = []
|
1030 |
+
for cm in clip_models:
|
1031 |
+
sys.stdout.write(f"Loading {cm[0]} ...\n")
|
1032 |
+
sys.stdout.flush()
|
1033 |
+
c = (
|
1034 |
+
clip.load(cm[0], jit=False)[0]
|
1035 |
+
.eval()
|
1036 |
+
.requires_grad_(False)
|
1037 |
+
.to(device)
|
1038 |
+
)
|
1039 |
+
self.perceptors.append(
|
1040 |
+
{
|
1041 |
+
"res": c.visual.input_resolution,
|
1042 |
+
"perceptor": c,
|
1043 |
+
"weight": cm[1],
|
1044 |
+
"prompts": [],
|
1045 |
+
}
|
1046 |
+
)
|
1047 |
+
self.perceptors.sort(key=lambda e: e["res"], reverse=True)
|
1048 |
+
|
1049 |
+
# Make Cutouts
|
1050 |
+
self.max_cut_size = self.perceptors[0]["res"]
|
1051 |
+
# self.make_cuts = flavordict[flavor](self.max_cut_size, cutn, cut_pow)
|
1052 |
+
# cutouts = flavordict[flavor](self.max_cut_size, cutn, cut_pow=cut_pow, augs=args.augs)
|
1053 |
+
|
1054 |
+
# Get Prompt Embedings
|
1055 |
+
# texts = [phrase.strip() for phrase in text_prompt.split("|")]
|
1056 |
+
# if text_prompt == ['']:
|
1057 |
+
# texts = []
|
1058 |
+
texts = text_prompt
|
1059 |
+
self.pMs = []
|
1060 |
+
for prompt in texts:
|
1061 |
+
txt, weight, stop = parse_prompt(prompt)
|
1062 |
+
clip_token = clip.tokenize(txt).to(device)
|
1063 |
+
for p in self.perceptors:
|
1064 |
+
embed = p["perceptor"].encode_text(clip_token).float()
|
1065 |
+
embed_normed = F.normalize(embed.unsqueeze(0), dim=2)
|
1066 |
+
p["prompts"].append(
|
1067 |
+
{
|
1068 |
+
"embed_normed": embed_normed,
|
1069 |
+
"weight": torch.as_tensor(weight, device=device),
|
1070 |
+
"stop": torch.as_tensor(stop, device=device),
|
1071 |
+
}
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
# Prep Augments
|
1075 |
+
self.normalize = transforms.Normalize(
|
1076 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
1077 |
+
std=[0.26862954, 0.26130258, 0.27577711],
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
self.augs = nn.Sequential(
|
1081 |
+
K.RandomHorizontalFlip(p=0.5),
|
1082 |
+
K.RandomSharpness(0.3, p=0.1),
|
1083 |
+
K.RandomAffine(
|
1084 |
+
degrees=30, translate=0.1, p=0.8, padding_mode="border"
|
1085 |
+
), # padding_mode=2
|
1086 |
+
K.RandomPerspective(
|
1087 |
+
0.2,
|
1088 |
+
p=0.4,
|
1089 |
+
),
|
1090 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
1091 |
+
K.RandomGrayscale(p=0.15),
|
1092 |
+
)
|
1093 |
+
self.noise_fac = 0.1
|
1094 |
+
|
1095 |
+
self.clip_weight = clip_weight
|
1096 |
+
|
1097 |
+
def prepare_cuts(self, img):
|
1098 |
+
cutouts = self.make_cuts(img)
|
1099 |
+
cutouts = self.augs(cutouts)
|
1100 |
+
if self.noise_fac:
|
1101 |
+
facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
|
1102 |
+
0, self.noise_fac
|
1103 |
+
)
|
1104 |
+
cutouts = cutouts + facs * torch.randn_like(cutouts)
|
1105 |
+
cutouts = self.normalize(cutouts)
|
1106 |
+
return cutouts
|
1107 |
+
|
1108 |
+
def forward(self, i, img):
|
1109 |
+
cutouts = checkpoint(self.prepare_cuts, img)
|
1110 |
+
loss = []
|
1111 |
+
|
1112 |
+
current_cuts = cutouts
|
1113 |
+
currentres = self.max_cut_size
|
1114 |
+
for p in self.perceptors:
|
1115 |
+
if currentres != p["res"]:
|
1116 |
+
current_cuts = resample(cutouts, (p["res"], p["res"]))
|
1117 |
+
currentres = p["res"]
|
1118 |
+
|
1119 |
+
iii = p["perceptor"].encode_image(current_cuts).float()
|
1120 |
+
input_normed = F.normalize(iii.unsqueeze(1), dim=2)
|
1121 |
+
for prompt in p["prompts"]:
|
1122 |
+
dists = (
|
1123 |
+
input_normed.sub(prompt["embed_normed"])
|
1124 |
+
.norm(dim=2)
|
1125 |
+
.div(2)
|
1126 |
+
.arcsin()
|
1127 |
+
.pow(2)
|
1128 |
+
.mul(2)
|
1129 |
+
)
|
1130 |
+
dists = dists * prompt["weight"].sign()
|
1131 |
+
l = (
|
1132 |
+
prompt["weight"].abs()
|
1133 |
+
* replace_grad(
|
1134 |
+
dists, torch.maximum(dists, prompt["stop"])
|
1135 |
+
).mean()
|
1136 |
+
)
|
1137 |
+
loss.append(l * p["weight"])
|
1138 |
+
|
1139 |
+
return loss
|
1140 |
+
|
1141 |
+
class ModelHost:
|
1142 |
+
def __init__(self, args):
|
1143 |
+
self.args = args
|
1144 |
+
self.model, self.perceptor = None, None
|
1145 |
+
self.make_cutouts = None
|
1146 |
+
self.alt_make_cutouts = None
|
1147 |
+
self.imageSize = None
|
1148 |
+
self.prompts = None
|
1149 |
+
self.opt = None
|
1150 |
+
self.normalize = None
|
1151 |
+
self.z, self.z_orig, self.z_min, self.z_max = None, None, None, None
|
1152 |
+
self.metadata = None
|
1153 |
+
self.mse_weight = 0
|
1154 |
+
self.normal_flip_optim = None
|
1155 |
+
self.usealtprompts = False
|
1156 |
+
|
1157 |
+
def setup_metadata(self, seed):
|
1158 |
+
metadata = {k: v for k, v in vars(self.args).items()}
|
1159 |
+
del metadata["max_iterations"]
|
1160 |
+
del metadata["display_freq"]
|
1161 |
+
metadata["seed"] = seed
|
1162 |
+
if metadata["init_image"]:
|
1163 |
+
path = metadata["init_image"]
|
1164 |
+
digest = get_digest(path)
|
1165 |
+
metadata["init_image"] = (path, digest)
|
1166 |
+
if metadata["image_prompts"]:
|
1167 |
+
prompts = []
|
1168 |
+
for prompt in metadata["image_prompts"]:
|
1169 |
+
path = prompt
|
1170 |
+
digest = get_digest(path)
|
1171 |
+
prompts.append((path, digest))
|
1172 |
+
metadata["image_prompts"] = prompts
|
1173 |
+
self.metadata = metadata
|
1174 |
+
|
1175 |
+
def setup_model(self, x):
|
1176 |
+
i = x
|
1177 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1178 |
+
|
1179 |
+
#perceptor = (
|
1180 |
+
# clip.load(args.clip_model, jit=False)[0]
|
1181 |
+
# .eval()
|
1182 |
+
# .requires_grad_(False)
|
1183 |
+
# .to(device)
|
1184 |
+
#)
|
1185 |
+
|
1186 |
+
cut_size = perceptor.visual.input_resolution
|
1187 |
+
|
1188 |
+
if self.args.is_gumbel:
|
1189 |
+
e_dim = model.quantize.embedding_dim
|
1190 |
+
else:
|
1191 |
+
e_dim = model.quantize.e_dim
|
1192 |
|
1193 |
+
f = 2 ** (model.decoder.num_resolutions - 1)
|
1194 |
+
|
1195 |
+
make_cutouts = flavordict[flavor](
|
1196 |
+
cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow, augs=args.augs
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
# make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow,augs=args.augs)
|
1200 |
+
if args.altprompts:
|
1201 |
+
self.usealtprompts = True
|
1202 |
+
self.alt_make_cutouts = flavordict[flavor](
|
1203 |
+
cut_size,
|
1204 |
+
args.mse_cutn,
|
1205 |
+
cut_pow=args.alt_mse_cut_pow,
|
1206 |
+
augs=args.altaugs,
|
1207 |
+
)
|
1208 |
+
# self.alt_make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.alt_mse_cut_pow,augs=args.altaugs)
|
1209 |
+
|
1210 |
+
if self.args.is_gumbel:
|
1211 |
+
n_toks = model.quantize.n_embed
|
1212 |
+
else:
|
1213 |
+
n_toks = model.quantize.n_e
|
1214 |
+
|
1215 |
+
toksX, toksY = args.size[0] // f, args.size[1] // f
|
1216 |
+
sideX, sideY = toksX * f, toksY * f
|
1217 |
+
|
1218 |
+
if self.args.is_gumbel:
|
1219 |
+
z_min = model.quantize.embed.weight.min(dim=0).values[
|
1220 |
+
None, :, None, None
|
1221 |
+
]
|
1222 |
+
z_max = model.quantize.embed.weight.max(dim=0).values[
|
1223 |
+
None, :, None, None
|
1224 |
+
]
|
1225 |
+
else:
|
1226 |
+
z_min = model.quantize.embedding.weight.min(dim=0).values[
|
1227 |
+
None, :, None, None
|
1228 |
+
]
|
1229 |
+
z_max = model.quantize.embedding.weight.max(dim=0).values[
|
1230 |
+
None, :, None, None
|
1231 |
+
]
|
1232 |
+
|
1233 |
+
from PIL import Image
|
1234 |
+
import cv2
|
1235 |
+
|
1236 |
+
# -------
|
1237 |
+
working_dir = self.args.folder_name
|
1238 |
+
|
1239 |
+
if self.args.init_image != "":
|
1240 |
+
img_0 = cv2.imread(init_image)
|
1241 |
+
z, *_ = model.encode(
|
1242 |
+
TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
|
1243 |
+
)
|
1244 |
+
elif not os.path.isfile(f"{working_dir}/steps/{i:04d}.png"):
|
1245 |
+
one_hot = F.one_hot(
|
1246 |
+
torch.randint(n_toks, [toksY * toksX], device=device), n_toks
|
1247 |
+
).float()
|
1248 |
+
if self.args.is_gumbel:
|
1249 |
+
z = one_hot @ model.quantize.embed.weight
|
1250 |
+
else:
|
1251 |
+
z = one_hot @ model.quantize.embedding.weight
|
1252 |
+
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
|
1253 |
+
else:
|
1254 |
+
center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2)
|
1255 |
+
trans_mat = np.float32([[1, 0, 10], [0, 1, 10]])
|
1256 |
+
rot_mat = cv2.getRotationMatrix2D(center, 10, 20)
|
1257 |
+
|
1258 |
+
trans_mat = np.vstack([trans_mat, [0, 0, 1]])
|
1259 |
+
rot_mat = np.vstack([rot_mat, [0, 0, 1]])
|
1260 |
+
transformation_matrix = np.matmul(rot_mat, trans_mat)
|
1261 |
+
|
1262 |
+
img_0 = cv2.warpPerspective(
|
1263 |
+
img_0,
|
1264 |
+
transformation_matrix,
|
1265 |
+
(img_0.shape[1], img_0.shape[0]),
|
1266 |
+
borderMode=cv2.BORDER_WRAP,
|
1267 |
+
)
|
1268 |
+
z, *_ = model.encode(
|
1269 |
+
TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
|
1270 |
+
)
|
1271 |
+
|
1272 |
+
def save_output(i, img, suffix="zoomed"):
|
1273 |
+
filename = f"{working_dir}/steps/{i:04}{'_' + suffix if suffix else ''}.png"
|
1274 |
+
imageio.imwrite(filename, np.array(img))
|
1275 |
+
|
1276 |
+
save_output(i, img_0)
|
1277 |
+
# -------
|
1278 |
+
if args.init_image:
|
1279 |
+
pil_image = Image.open(args.init_image).convert("RGB")
|
1280 |
+
pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
|
1281 |
+
z, *_ = model.encode(
|
1282 |
+
TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1
|
1283 |
+
)
|
1284 |
+
else:
|
1285 |
+
one_hot = F.one_hot(
|
1286 |
+
torch.randint(n_toks, [toksY * toksX], device=device), n_toks
|
1287 |
+
).float()
|
1288 |
+
if self.args.is_gumbel:
|
1289 |
+
z = one_hot @ model.quantize.embed.weight
|
1290 |
+
else:
|
1291 |
+
z = one_hot @ model.quantize.embedding.weight
|
1292 |
+
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
|
1293 |
+
z = EMATensor(z, args.ema_val)
|
1294 |
+
|
1295 |
+
if args.mse_with_zeros and not args.init_image:
|
1296 |
+
z_orig = torch.zeros_like(z.tensor)
|
1297 |
+
else:
|
1298 |
+
z_orig = z.tensor.clone()
|
1299 |
+
z.requires_grad_(True)
|
1300 |
+
# opt = optim.AdamW(z.parameters(), lr=args.mse_step_size, weight_decay=0.00000000)
|
1301 |
+
if self.normal_flip_optim == True:
|
1302 |
+
if randint(1, 2) == 1:
|
1303 |
+
opt = torch.optim.AdamW(
|
1304 |
+
z.parameters(), lr=args.step_size, weight_decay=0.00000000
|
1305 |
+
)
|
1306 |
+
# opt = Ranger21(z.parameters(), lr=args.step_size, weight_decay=0.00000000)
|
1307 |
+
else:
|
1308 |
+
opt = optim.DiffGrad(
|
1309 |
+
z.parameters(), lr=args.step_size, weight_decay=0.00000000
|
1310 |
+
)
|
1311 |
+
else:
|
1312 |
+
opt = torch.optim.AdamW(
|
1313 |
+
z.parameters(), lr=args.step_size, weight_decay=0.00000000
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
self.cur_step_size = args.mse_step_size
|
1317 |
+
|
1318 |
+
normalize = transforms.Normalize(
|
1319 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
1320 |
+
std=[0.26862954, 0.26130258, 0.27577711],
|
1321 |
+
)
|
1322 |
+
|
1323 |
+
pMs = []
|
1324 |
+
altpMs = []
|
1325 |
+
|
1326 |
+
for prompt in args.prompts:
|
1327 |
+
txt, weight, stop = parse_prompt(prompt)
|
1328 |
+
embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
|
1329 |
+
pMs.append(Prompt(embed, weight, stop).to(device))
|
1330 |
+
|
1331 |
+
for prompt in args.altprompts:
|
1332 |
+
txt, weight, stop = parse_prompt(prompt)
|
1333 |
+
embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
|
1334 |
+
altpMs.append(Prompt(embed, weight, stop).to(device))
|
1335 |
+
|
1336 |
+
from PIL import Image
|
1337 |
+
|
1338 |
+
for prompt in args.image_prompts:
|
1339 |
+
path, weight, stop = parse_prompt(prompt)
|
1340 |
+
img = resize_image(Image.open(path).convert("RGB"), (sideX, sideY))
|
1341 |
+
batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
|
1342 |
+
embed = perceptor.encode_image(normalize(batch)).float()
|
1343 |
+
pMs.append(Prompt(embed, weight, stop).to(device))
|
1344 |
+
|
1345 |
+
for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
|
1346 |
+
gen = torch.Generator().manual_seed(seed)
|
1347 |
+
embed = torch.empty([1, perceptor.visual.output_dim]).normal_(
|
1348 |
+
generator=gen
|
1349 |
+
)
|
1350 |
+
pMs.append(Prompt(embed, weight).to(device))
|
1351 |
+
if self.usealtprompts:
|
1352 |
+
altpMs.append(Prompt(embed, weight).to(device))
|
1353 |
+
|
1354 |
+
self.model, self.perceptor = model, perceptor
|
1355 |
+
self.make_cutouts = make_cutouts
|
1356 |
+
self.imageSize = (sideX, sideY)
|
1357 |
+
self.prompts = pMs
|
1358 |
+
self.altprompts = altpMs
|
1359 |
+
self.opt = opt
|
1360 |
+
self.normalize = normalize
|
1361 |
+
self.z, self.z_orig, self.z_min, self.z_max = z, z_orig, z_min, z_max
|
1362 |
+
self.setup_metadata(args2.seed)
|
1363 |
+
self.mse_weight = self.args.init_weight
|
1364 |
+
|
1365 |
+
def synth(self, z):
|
1366 |
+
if self.args.is_gumbel:
|
1367 |
+
z_q = vector_quantize(
|
1368 |
+
z.movedim(1, 3), self.model.quantize.embed.weight
|
1369 |
+
).movedim(3, 1)
|
1370 |
+
else:
|
1371 |
+
z_q = vector_quantize(
|
1372 |
+
z.movedim(1, 3), self.model.quantize.embedding.weight
|
1373 |
+
).movedim(3, 1)
|
1374 |
+
return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1)
|
1375 |
+
|
1376 |
+
def add_metadata(self, path, i):
|
1377 |
+
imfile = PngImageFile(path)
|
1378 |
+
meta = PngInfo()
|
1379 |
+
step_meta = {"iterations": i}
|
1380 |
+
step_meta.update(self.metadata)
|
1381 |
+
# meta.add_itxt('vqgan-params', json.dumps(step_meta), zip=True)
|
1382 |
+
imfile.save(path, pnginfo=meta)
|
1383 |
+
# Hey you. This one's for Glooperpogger#7353 on Discord (Gloop has a gun), they are a nice snek
|
1384 |
+
|
1385 |
+
@torch.no_grad()
|
1386 |
+
def checkin(self, i, losses, x):
|
1387 |
+
out = self.synth(self.z.average)
|
1388 |
+
|
1389 |
+
batchpath = "./"
|
1390 |
+
TF.to_pil_image(out[0].cpu()).save(args2.image_file)
|
1391 |
+
|
1392 |
+
def unique_index(self, batchpath):
|
1393 |
+
i = 0
|
1394 |
+
while i < 10000:
|
1395 |
+
if os.path.isfile(batchpath + "/" + str(i) + ".png"):
|
1396 |
+
i = i + 1
|
1397 |
+
else:
|
1398 |
+
return batchpath + "/" + str(i) + ".png"
|
1399 |
+
|
1400 |
+
def ascend_txt(self, i):
|
1401 |
+
out = self.synth(self.z.tensor)
|
1402 |
+
iii = self.perceptor.encode_image(
|
1403 |
+
self.normalize(self.make_cutouts(out))
|
1404 |
+
).float()
|
1405 |
+
|
1406 |
+
result = []
|
1407 |
+
if self.args.init_weight and self.mse_weight > 0:
|
1408 |
+
result.append(
|
1409 |
+
F.mse_loss(self.z.tensor, self.z_orig) * self.mse_weight / 2
|
1410 |
+
)
|
1411 |
+
|
1412 |
+
for prompt in self.prompts:
|
1413 |
+
result.append(prompt(iii))
|
1414 |
+
|
1415 |
+
if self.usealtprompts:
|
1416 |
+
iii = self.perceptor.encode_image(
|
1417 |
+
self.normalize(self.alt_make_cutouts(out))
|
1418 |
+
).float()
|
1419 |
+
for prompt in self.altprompts:
|
1420 |
+
result.append(prompt(iii))
|
1421 |
+
|
1422 |
+
return result
|
1423 |
+
|
1424 |
+
def train(self, i, x):
|
1425 |
+
self.opt.zero_grad()
|
1426 |
+
mse_decay = self.args.mse_decay
|
1427 |
+
mse_decay_rate = self.args.mse_decay_rate
|
1428 |
+
lossAll = self.ascend_txt(i)
|
1429 |
+
|
1430 |
+
sys.stdout.write("Iteration {}".format(i) + "\n")
|
1431 |
+
sys.stdout.flush()
|
1432 |
+
|
1433 |
+
if i % args2.update == 0:
|
1434 |
+
self.checkin(i, lossAll, x)
|
1435 |
+
|
1436 |
+
loss = sum(lossAll)
|
1437 |
+
loss.backward()
|
1438 |
+
self.opt.step()
|
1439 |
+
with torch.no_grad():
|
1440 |
+
if (
|
1441 |
+
self.mse_weight > 0
|
1442 |
+
and self.args.init_weight
|
1443 |
+
and i > 0
|
1444 |
+
and i % mse_decay_rate == 0
|
1445 |
+
):
|
1446 |
+
if self.args.is_gumbel:
|
1447 |
+
self.z_orig = vector_quantize(
|
1448 |
+
self.z.average.movedim(1, 3),
|
1449 |
+
self.model.quantize.embed.weight,
|
1450 |
+
).movedim(3, 1)
|
1451 |
+
else:
|
1452 |
+
self.z_orig = vector_quantize(
|
1453 |
+
self.z.average.movedim(1, 3),
|
1454 |
+
self.model.quantize.embedding.weight,
|
1455 |
+
).movedim(3, 1)
|
1456 |
+
if self.mse_weight - mse_decay > 0:
|
1457 |
+
self.mse_weight = self.mse_weight - mse_decay
|
1458 |
+
# print(f"updated mse weight: {self.mse_weight}")
|
1459 |
+
else:
|
1460 |
+
self.mse_weight = 0
|
1461 |
+
self.make_cutouts = flavordict[flavor](
|
1462 |
+
self.perceptor.visual.input_resolution,
|
1463 |
+
args.cutn,
|
1464 |
+
cut_pow=args.cut_pow,
|
1465 |
+
augs=args.augs,
|
1466 |
+
)
|
1467 |
+
if self.usealtprompts:
|
1468 |
+
self.alt_make_cutouts = flavordict[flavor](
|
1469 |
+
self.perceptor.visual.input_resolution,
|
1470 |
+
args.cutn,
|
1471 |
+
cut_pow=args.alt_cut_pow,
|
1472 |
+
augs=args.altaugs,
|
1473 |
+
)
|
1474 |
+
self.z = EMATensor(self.z.average, args.ema_val)
|
1475 |
+
self.new_step_size = args.step_size
|
1476 |
+
self.opt = torch.optim.AdamW(
|
1477 |
+
self.z.parameters(),
|
1478 |
+
lr=args.step_size,
|
1479 |
+
weight_decay=0.00000000,
|
1480 |
+
)
|
1481 |
+
# print(f"updated mse weight: {self.mse_weight}")
|
1482 |
+
if i > args.mse_end:
|
1483 |
+
if (
|
1484 |
+
args.step_size != args.final_step_size
|
1485 |
+
and args.max_iterations > 0
|
1486 |
+
):
|
1487 |
+
progress = (i - args.mse_end) / (args.max_iterations)
|
1488 |
+
self.cur_step_size = lerp(step_size, final_step_size, progress)
|
1489 |
+
for g in self.opt.param_groups:
|
1490 |
+
g["lr"] = self.cur_step_size
|
1491 |
+
|
1492 |
+
def run(self, x):
|
1493 |
+
j = 0
|
1494 |
+
try:
|
1495 |
+
before_start_time = time.perf_counter()
|
1496 |
+
total_steps = int(args.max_iterations + args.mse_end) - 1
|
1497 |
+
for _ in range(total_steps):
|
1498 |
+
self.train(j, x)
|
1499 |
+
if j > 0 and j % args.mse_decay_rate == 0 and self.mse_weight > 0:
|
1500 |
+
self.z = EMATensor(self.z.average, args.ema_val)
|
1501 |
+
self.opt = torch.optim.AdamW(
|
1502 |
+
self.z.parameters(),
|
1503 |
+
lr=args.mse_step_size,
|
1504 |
+
weight_decay=0.00000000,
|
1505 |
+
)
|
1506 |
+
if j >= total_steps:
|
1507 |
+
break
|
1508 |
+
self.z.update()
|
1509 |
+
j += 1
|
1510 |
+
time_past_seconds = time.perf_counter() - before_start_time
|
1511 |
+
iterations_per_second = j / time_past_seconds
|
1512 |
+
time_left = (total_steps - j) / iterations_per_second
|
1513 |
+
percentage = round((j / (total_steps + 1)) * 100)
|
1514 |
+
|
1515 |
+
import shutil
|
1516 |
+
import os
|
1517 |
+
|
1518 |
+
image_data = Image.open(args2.image_file)
|
1519 |
+
return(image_data)
|
1520 |
+
|
1521 |
+
except KeyboardInterrupt:
|
1522 |
+
pass
|
1523 |
+
except st.script_runner.StopException as e:
|
1524 |
+
torch.cuda.empty_cache()
|
1525 |
+
pass
|
1526 |
+
return j
|
1527 |
+
|
1528 |
+
def add_noise(img):
|
1529 |
+
|
1530 |
+
# Getting the dimensions of the image
|
1531 |
+
row, col = img.shape
|
1532 |
+
|
1533 |
+
# Randomly pick some pixels in the
|
1534 |
+
# image for coloring them white
|
1535 |
+
# Pick a random number between 300 and 10000
|
1536 |
+
number_of_pixels = random.randint(300, 10000)
|
1537 |
+
for i in range(number_of_pixels):
|
1538 |
+
|
1539 |
+
# Pick a random y coordinate
|
1540 |
+
y_coord = random.randint(0, row - 1)
|
1541 |
+
|
1542 |
+
# Pick a random x coordinate
|
1543 |
+
x_coord = random.randint(0, col - 1)
|
1544 |
+
|
1545 |
+
# Color that pixel to white
|
1546 |
+
img[y_coord][x_coord] = 255
|
1547 |
+
|
1548 |
+
# Randomly pick some pixels in
|
1549 |
+
# the image for coloring them black
|
1550 |
+
# Pick a random number between 300 and 10000
|
1551 |
+
number_of_pixels = random.randint(300, 10000)
|
1552 |
+
for i in range(number_of_pixels):
|
1553 |
+
|
1554 |
+
# Pick a random y coordinate
|
1555 |
+
y_coord = random.randint(0, row - 1)
|
1556 |
+
|
1557 |
+
# Pick a random x coordinate
|
1558 |
+
x_coord = random.randint(0, col - 1)
|
1559 |
+
|
1560 |
+
# Color that pixel to black
|
1561 |
+
img[y_coord][x_coord] = 0
|
1562 |
+
|
1563 |
+
return img
|
1564 |
+
|
1565 |
+
import io
|
1566 |
+
import base64
|
1567 |
+
|
1568 |
+
def image_to_data_url(img, ext):
|
1569 |
+
img_byte_arr = io.BytesIO()
|
1570 |
+
img.save(img_byte_arr, format=ext)
|
1571 |
+
img_byte_arr = img_byte_arr.getvalue()
|
1572 |
+
# ext = filename.split('.')[-1]
|
1573 |
+
prefix = f"data:image/{ext};base64,"
|
1574 |
+
return prefix + base64.b64encode(img_byte_arr).decode("utf-8")
|
1575 |
+
|
1576 |
+
import torch
|
1577 |
+
import math
|
1578 |
+
|
1579 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1580 |
+
|
1581 |
+
def rand_perlin_2d(
|
1582 |
+
shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
|
1583 |
+
):
|
1584 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
1585 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
1586 |
+
|
1587 |
+
grid = (
|
1588 |
+
torch.stack(
|
1589 |
+
torch.meshgrid(
|
1590 |
+
torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
|
1591 |
+
),
|
1592 |
+
dim=-1,
|
1593 |
+
)
|
1594 |
+
% 1
|
1595 |
+
)
|
1596 |
+
angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
|
1597 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
1598 |
+
|
1599 |
+
tile_grads = (
|
1600 |
+
lambda slice1, slice2: gradients[
|
1601 |
+
slice1[0] : slice1[1], slice2[0] : slice2[1]
|
1602 |
+
]
|
1603 |
+
.repeat_interleave(d[0], 0)
|
1604 |
+
.repeat_interleave(d[1], 1)
|
1605 |
+
)
|
1606 |
+
dot = lambda grad, shift: (
|
1607 |
+
torch.stack(
|
1608 |
+
(
|
1609 |
+
grid[: shape[0], : shape[1], 0] + shift[0],
|
1610 |
+
grid[: shape[0], : shape[1], 1] + shift[1],
|
1611 |
+
),
|
1612 |
+
dim=-1,
|
1613 |
+
)
|
1614 |
+
* grad[: shape[0], : shape[1]]
|
1615 |
+
).sum(dim=-1)
|
1616 |
+
|
1617 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
1618 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
1619 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
1620 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
1621 |
+
t = fade(grid[: shape[0], : shape[1]])
|
1622 |
+
return math.sqrt(2) * torch.lerp(
|
1623 |
+
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
|
1624 |
+
)
|
1625 |
+
|
1626 |
+
def rand_perlin_2d_octaves(desired_shape, octaves=1, persistence=0.5):
|
1627 |
+
shape = torch.tensor(desired_shape)
|
1628 |
+
shape = 2 ** torch.ceil(torch.log2(shape))
|
1629 |
+
shape = shape.type(torch.int)
|
1630 |
+
|
1631 |
+
max_octaves = int(
|
1632 |
+
min(
|
1633 |
+
octaves,
|
1634 |
+
math.log(shape[0]) / math.log(2),
|
1635 |
+
math.log(shape[1]) / math.log(2),
|
1636 |
+
)
|
1637 |
+
)
|
1638 |
+
res = torch.floor(shape / 2**max_octaves).type(torch.int)
|
1639 |
+
|
1640 |
+
noise = torch.zeros(list(shape))
|
1641 |
+
frequency = 1
|
1642 |
+
amplitude = 1
|
1643 |
+
for _ in range(max_octaves):
|
1644 |
+
noise += amplitude * rand_perlin_2d(
|
1645 |
+
shape, (frequency * res[0], frequency * res[1])
|
1646 |
+
)
|
1647 |
+
frequency *= 2
|
1648 |
+
amplitude *= persistence
|
1649 |
+
|
1650 |
+
return noise[: desired_shape[0], : desired_shape[1]]
|
1651 |
+
|
1652 |
+
def rand_perlin_rgb(desired_shape, amp=0.1, octaves=6):
|
1653 |
+
r = rand_perlin_2d_octaves(desired_shape, octaves)
|
1654 |
+
g = rand_perlin_2d_octaves(desired_shape, octaves)
|
1655 |
+
b = rand_perlin_2d_octaves(desired_shape, octaves)
|
1656 |
+
rgb = (torch.stack((r, g, b)) * amp + 1) * 0.5
|
1657 |
+
return rgb.unsqueeze(0).clip(0, 1).to(device)
|
1658 |
+
|
1659 |
+
def pyramid_noise_gen(shape, octaves=5, decay=1.0):
|
1660 |
+
n, c, h, w = shape
|
1661 |
+
noise = torch.zeros([n, c, 1, 1])
|
1662 |
+
max_octaves = int(min(math.log(h) / math.log(2), math.log(w) / math.log(2)))
|
1663 |
+
if octaves is not None and 0 < octaves:
|
1664 |
+
max_octaves = min(octaves, max_octaves)
|
1665 |
+
for i in reversed(range(max_octaves)):
|
1666 |
+
h_cur, w_cur = h // 2**i, w // 2**i
|
1667 |
+
noise = F.interpolate(
|
1668 |
+
noise, (h_cur, w_cur), mode="bicubic", align_corners=False
|
1669 |
+
)
|
1670 |
+
noise += (torch.randn([n, c, h_cur, w_cur]) / max_octaves) * decay ** (
|
1671 |
+
max_octaves - (i + 1)
|
1672 |
+
)
|
1673 |
+
return noise
|
1674 |
+
|
1675 |
+
def rand_z(model, toksX, toksY):
|
1676 |
+
e_dim = model.quantize.e_dim
|
1677 |
+
n_toks = model.quantize.n_e
|
1678 |
+
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
|
1679 |
+
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
|
1680 |
+
|
1681 |
+
one_hot = F.one_hot(
|
1682 |
+
torch.randint(n_toks, [toksY * toksX], device=device), n_toks
|
1683 |
+
).float()
|
1684 |
+
z = one_hot @ model.quantize.embedding.weight
|
1685 |
+
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
|
1686 |
+
|
1687 |
+
return z
|
1688 |
+
|
1689 |
+
def make_rand_init(
|
1690 |
+
mode,
|
1691 |
+
model,
|
1692 |
+
perlin_octaves,
|
1693 |
+
perlin_weight,
|
1694 |
+
pyramid_octaves,
|
1695 |
+
pyramid_decay,
|
1696 |
+
toksX,
|
1697 |
+
toksY,
|
1698 |
+
f,
|
1699 |
+
):
|
1700 |
+
|
1701 |
+
if mode == "VQGAN ZRand":
|
1702 |
+
return rand_z(model, toksX, toksY)
|
1703 |
+
elif mode == "Perlin Noise":
|
1704 |
+
rand_init = rand_perlin_rgb(
|
1705 |
+
(toksY * f, toksX * f), perlin_weight, perlin_octaves
|
1706 |
+
)
|
1707 |
+
z, *_ = model.encode(rand_init * 2 - 1)
|
1708 |
+
return z
|
1709 |
+
elif mode == "Pyramid Noise":
|
1710 |
+
rand_init = pyramid_noise_gen(
|
1711 |
+
(1, 3, toksY * f, toksX * f), pyramid_octaves, pyramid_decay
|
1712 |
+
).to(device)
|
1713 |
+
rand_init = (rand_init * 0.5 + 0.5).clip(0, 1)
|
1714 |
+
z, *_ = model.encode(rand_init * 2 - 1)
|
1715 |
+
return z
|
1716 |
+
|
1717 |
+
##################### JUICY MESS ###################################
|
1718 |
+
import os
|
1719 |
+
|
1720 |
+
imagenet_1024 = False # @param {type:"boolean"}
|
1721 |
+
imagenet_16384 = True # @param {type:"boolean"}
|
1722 |
+
gumbel_8192 = False # @param {type:"boolean"}
|
1723 |
+
sber_gumbel = False # @param {type:"boolean"}
|
1724 |
+
# imagenet_cin = False #@param {type:"boolean"}
|
1725 |
+
coco = False # @param {type:"boolean"}
|
1726 |
+
coco_1stage = False # @param {type:"boolean"}
|
1727 |
+
faceshq = False # @param {type:"boolean"}
|
1728 |
+
wikiart_1024 = False # @param {type:"boolean"}
|
1729 |
+
wikiart_16384 = False # @param {type:"boolean"}
|
1730 |
+
wikiart_7mil = False # @param {type:"boolean"}
|
1731 |
+
sflckr = False # @param {type:"boolean"}
|
1732 |
+
|
1733 |
+
##@markdown Experimental models (won't probably work, if you know how to make them work, go ahead :D):
|
1734 |
+
# celebahq = False #@param {type:"boolean"}
|
1735 |
+
# ade20k = False #@param {type:"boolean"}
|
1736 |
+
# drin = False #@param {type:"boolean"}
|
1737 |
+
# gumbel = False #@param {type:"boolean"}
|
1738 |
+
# gumbel_8192 = False #@param {type:"boolean"}
|
1739 |
+
|
1740 |
+
# Configure and run the model"""
|
1741 |
+
|
1742 |
+
# Commented out IPython magic to ensure Python compatibility.
|
1743 |
+
# @title <font color="lightgreen" size="+3">←</font> <font size="+2">🏃♂️</font> **Configure & Run** <font size="+2">🏃♂️</font>
|
1744 |
+
|
1745 |
+
import os
|
1746 |
+
import random
|
1747 |
+
import cv2
|
1748 |
+
|
1749 |
+
# from google.colab import drive
|
1750 |
+
from PIL import Image
|
1751 |
+
from importlib import reload
|
1752 |
+
|
1753 |
+
reload(PIL.TiffTags)
|
1754 |
+
# %cd /content/
|
1755 |
+
# @markdown >`prompts` is the list of prompts to give to the AI, separated by `|`. With more than one, it will attempt to mix them together. You can add weights to different parts of the prompt by adding a `p:x` at the end of a prompt (before a `|`) where `p` is the prompt and `x` is the weight.
|
1756 |
+
|
1757 |
+
# prompts = "A fantasy landscape, by Greg Rutkowski. A lush mountain.:1 | Trending on ArtStation, unreal engine. 4K HD, realism.:0.63" #@param {type:"string"}
|
1758 |
+
|
1759 |
+
prompts = args2.prompt
|
1760 |
+
|
1761 |
+
width = args2.sizex # @param {type:"number"}
|
1762 |
+
height = args2.sizey # @param {type:"number"}
|
1763 |
+
|
1764 |
+
# model = "ImageNet 16384" #@param ['ImageNet 16384', 'ImageNet 1024', "Gumbel 8192", "Sber Gumbel", 'WikiArt 1024', 'WikiArt 16384', 'WikiArt 7mil', 'COCO-Stuff', 'COCO 1 Stage', 'FacesHQ', 'S-FLCKR']
|
1765 |
+
#model = args2.vqgan_model
|
1766 |
+
|
1767 |
+
#if model == "Gumbel 8192" or model == "Sber Gumbel":
|
1768 |
+
# is_gumbel = True
|
1769 |
+
#else:
|
1770 |
+
# is_gumbel = False
|
1771 |
+
is_gumbel = False
|
1772 |
+
##@markdown The flavor effects the output greatly. Each has it's own characteristics and depending on what you choose, you'll get a widely different result with the same prompt and seed. Ginger is the default, nothing special. Cumin results more of a painting, while Holywater makes everythng super funky and/or colorful. Custom is a custom flavor, use the utilities above.
|
1773 |
+
# Type "old_holywater" to use the old holywater flavor from Hypertron V1
|
1774 |
+
flavor = (
|
1775 |
+
args2.flavor
|
1776 |
+
) #'ginger' #@param ["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]
|
1777 |
+
template = (
|
1778 |
+
args2.template
|
1779 |
+
) # @param ["none", "----------Parameter Tweaking----------", "Balanced", "Detailed", "Consistent Creativity", "Realistic", "Smooth", "Subtle MSE", "Hyper Fast Results", "----------Complete Overhaul----------", "flag", "planet", "creature", "human", "----------Sizes----------", "Size: Square", "Size: Landscape", "Size: Poster", "----------Prompt Modifiers----------", "Better - Fast", "Better - Slow", "Movie Poster", "Negative Prompt", "Better Quality"]
|
1780 |
+
##@markdown To use initial or target images, upload it on the left in the file browser. You can also use previous outputs by putting its path below, e.g. `batch_01/0.png`. If your previous output is saved to drive, you can use the checkbox so you don't have to type the whole path.
|
1781 |
+
init = "default noise" # @param ["default noise", "image", "random image", "salt and pepper noise", "salt and pepper noise on init image"]
|
1782 |
+
|
1783 |
+
if args2.seed_image is None:
|
1784 |
+
init_image = "" # args2.seed_image #""#@param {type:"string"}
|
1785 |
else:
|
1786 |
+
init_image = args2.seed_image # ""#@param {type:"string"}
|
1787 |
+
|
1788 |
+
if init == "random image":
|
1789 |
+
url = (
|
1790 |
+
"https://picsum.photos/"
|
1791 |
+
+ str(width)
|
1792 |
+
+ "/"
|
1793 |
+
+ str(height)
|
1794 |
+
+ "?blur="
|
1795 |
+
+ str(random.randrange(5, 10))
|
1796 |
+
)
|
1797 |
+
urllib.request.urlretrieve(url, "Init_Img/Image.png")
|
1798 |
+
init_image = "Init_Img/Image.png"
|
1799 |
+
elif init == "random image clear":
|
1800 |
+
url = "https://source.unsplash.com/random/" + str(width) + "x" + str(height)
|
1801 |
+
urllib.request.urlretrieve(url, "Init_Img/Image.png")
|
1802 |
+
init_image = "Init_Img/Image.png"
|
1803 |
+
elif init == "random image clear 2":
|
1804 |
+
url = "https://loremflickr.com/" + str(width) + "/" + str(height)
|
1805 |
+
urllib.request.urlretrieve(url, "Init_Img/Image.png")
|
1806 |
+
init_image = "Init_Img/Image.png"
|
1807 |
+
elif init == "salt and pepper noise":
|
1808 |
+
urllib.request.urlretrieve(
|
1809 |
+
"https://i.stack.imgur.com/olrL8.png", "Init_Img/Image.png"
|
1810 |
+
)
|
1811 |
+
import cv2
|
1812 |
+
|
1813 |
+
img = cv2.imread("Init_Img/Image.png", 0)
|
1814 |
+
cv2.imwrite("Init_Img/Image.png", add_noise(img))
|
1815 |
+
init_image = "Init_Img/Image.png"
|
1816 |
+
elif init == "salt and pepper noise on init image":
|
1817 |
+
img = cv2.imread(init_image, 0)
|
1818 |
+
cv2.imwrite("Init_Img/Image.png", add_noise(img))
|
1819 |
+
init_image = "Init_Img/Image.png"
|
1820 |
+
elif init == "perlin noise":
|
1821 |
+
# For some reason Colab started crashing from this
|
1822 |
+
import noise
|
1823 |
+
import numpy as np
|
1824 |
+
from PIL import Image
|
1825 |
+
|
1826 |
+
shape = (width, height)
|
1827 |
+
scale = 100
|
1828 |
+
octaves = 6
|
1829 |
+
persistence = 0.5
|
1830 |
+
lacunarity = 2.0
|
1831 |
+
seed = np.random.randint(0, 100000)
|
1832 |
+
world = np.zeros(shape)
|
1833 |
+
for i in range(shape[0]):
|
1834 |
+
for j in range(shape[1]):
|
1835 |
+
world[i][j] = noise.pnoise2(
|
1836 |
+
i / scale,
|
1837 |
+
j / scale,
|
1838 |
+
octaves=octaves,
|
1839 |
+
persistence=persistence,
|
1840 |
+
lacunarity=lacunarity,
|
1841 |
+
repeatx=1024,
|
1842 |
+
repeaty=1024,
|
1843 |
+
base=seed,
|
1844 |
+
)
|
1845 |
+
Image.fromarray(prep_world(world)).convert("L").save("Init_Img/Image.png")
|
1846 |
+
init_image = "Init_Img/Image.png"
|
1847 |
+
elif init == "black and white":
|
1848 |
+
url = "https://www.random.org/bitmaps/?format=png&width=300&height=300&zoom=1"
|
1849 |
+
urllib.request.urlretrieve(url, "Init_Img/Image.png")
|
1850 |
+
init_image = "Init_Img/Image.png"
|
1851 |
+
|
1852 |
+
seed = args2.seed # @param {type:"number"}
|
1853 |
+
# @markdown >iterations excludes iterations spent during the mse phase, if it is being used. The total iterations will be more if `mse_decay_rate` is more than 0.
|
1854 |
+
iterations = args2.iterations # @param {type:"number"}
|
1855 |
+
transparent_png = False # @param {type:"boolean"}
|
1856 |
+
|
1857 |
+
# @markdown <font size="+3">⚠</font> **ADVANCED SETTINGS** <font size="+3">⚠</font>
|
1858 |
+
# @markdown ---
|
1859 |
+
# @markdown ---
|
1860 |
+
|
1861 |
+
# @markdown >If you want to make multiple images with different prompts, use this. Seperate different prompts for different images with a `~` (example: `prompt1~prompt1~prompt3`). Iter is the iterations you want each image to run for. If you use MSE, I'd type a pretty low number (about 10).
|
1862 |
+
multiple_prompt_batches = False # @param {type:"boolean"}
|
1863 |
+
multiple_prompt_batches_iter = 300 # @param {type:"number"}
|
1864 |
+
|
1865 |
+
# @markdown >`folder_name` is the name of the folder you want to output your result(s) to. Previous outputs will NOT be overwritten. By default, it will be saved to the colab's root folder, but the `save_to_drive` checkbox will save it to `MyDrive\VQGAN_Output` instead.
|
1866 |
+
folder_name = "" # @param {type:"string"}
|
1867 |
+
save_to_drive = False # @param {type:"boolean"}
|
1868 |
+
prompt_experiment = "None" # @param ['None', 'Fever Dream', 'Philipuss’s Basement', 'Vivid Turmoil', 'Mad Dad', 'Platinum', 'Negative Energy']
|
1869 |
+
if prompt_experiment == "Fever Dream":
|
1870 |
+
prompts = "<|startoftext|>" + prompts + "<|endoftext|>"
|
1871 |
+
elif prompt_experiment == "Vivid Turmoil":
|
1872 |
+
prompts = prompts.replace(" ", "¡")
|
1873 |
+
prompts = "¬" + prompts + "®"
|
1874 |
+
elif prompt_experiment == "Mad Dad":
|
1875 |
+
prompts = prompts.replace(" ", "\\s+")
|
1876 |
+
elif prompt_experiment == "Platinum":
|
1877 |
+
prompts = "~!" + prompts + "!~"
|
1878 |
+
prompts = prompts.replace(" ", "</w>")
|
1879 |
+
elif prompt_experiment == "Philipuss’s Basement":
|
1880 |
+
prompts = "<|startoftext|>" + prompts
|
1881 |
+
prompts = prompts.replace(" ", "<|endoftext|><|startoftext|>")
|
1882 |
+
elif prompt_experiment == "Lowercase":
|
1883 |
+
prompts = prompts.lower()
|
1884 |
+
|
1885 |
+
|
1886 |
+
# @markdown >Target images work like prompts, write the name of the image. You can add multiple target images by seperating them with a `|`.
|
1887 |
+
target_images = "" # @param {type:"string"}
|
1888 |
+
|
1889 |
+
# @markdown ><font size="+2">☢</font> Advanced values. Values of cut_pow below 1 prioritize structure over detail, and vice versa for above 1. Step_size affects how wild the change between iterations is, and if final_step_size is not 0, step_size will interpolate towards it over time.
|
1890 |
+
# @markdown >Cutn affects on 'Creativity': less cutout will lead to more random/creative results, sometimes barely readable, while higher values (90+) lead to very stable, photo-like outputs
|
1891 |
+
cutn = 130 # @param {type:"number"}
|
1892 |
+
cut_pow = 1 # @param {type:"number"}
|
1893 |
+
# @markdown >Step_size is like weirdness. Lower: more accurate/realistic, slower; Higher: less accurate/more funky, faster.
|
1894 |
+
step_size = 0.1 # @param {type:"number"}
|
1895 |
+
# @markdown >Start_step_size is a temporary step_size that will be active only in the first 10 iterations. It (sometimes) helps with speed. If it's set to 0, it won't be used.
|
1896 |
+
start_step_size = 0 # @param {type:"number"}
|
1897 |
+
# @markdown >Final_step_size is a goal step_size which the AI will try and reach. If set to 0, it won't be used.
|
1898 |
+
final_step_size = 0 # @param {type:"number"}
|
1899 |
+
if start_step_size <= 0:
|
1900 |
+
start_step_size = step_size
|
1901 |
+
if final_step_size <= 0:
|
1902 |
+
final_step_size = step_size
|
1903 |
+
|
1904 |
+
# @markdown ---
|
1905 |
+
|
1906 |
+
# @markdown >EMA maintains a moving average of trained parameters. The number below is the rate of decay (higher means slower).
|
1907 |
+
ema_val = 0.98 # @param {type:"number"}
|
1908 |
+
|
1909 |
+
# @markdown >If you want to keep starting from the same point, set `gen_seed` to a positive number. `-1` will make it random every time.
|
1910 |
+
gen_seed = -1 # @param {type:'number'}
|
1911 |
+
|
1912 |
+
init_image_in_drive = False # @param {type:"boolean"}
|
1913 |
+
if init_image_in_drive and init_image:
|
1914 |
+
init_image = "/content/drive/MyDrive/VQGAN_Output/" + init_image
|
1915 |
+
|
1916 |
+
images_interval = args2.update # @param {type:"number"}
|
1917 |
+
|
1918 |
+
# I think you should give "Free Thoughts on the Proceedings of the Continental Congress" a read, really funny and actually well-written, Hamilton presented it in a bad light IMO.
|
1919 |
+
|
1920 |
+
batch_size = 1 # @param {type:"number"}
|
1921 |
+
|
1922 |
+
# @markdown ---
|
1923 |
+
|
1924 |
+
# @markdown <font size="+1">🔮</font> **MSE Regulization** <font size="+1">🔮</font>
|
1925 |
+
# Based off of this notebook: https://colab.research.google.com/drive/1gFn9u3oPOgsNzJWEFmdK-N9h_y65b8fj?usp=sharing - already in credits
|
1926 |
+
use_mse = args2.mse # @param {type:"boolean"}
|
1927 |
+
mse_images_interval = images_interval
|
1928 |
+
mse_init_weight = 0.2 # @param {type:"number"}
|
1929 |
+
mse_decay_rate = 160 # @param {type:"number"}
|
1930 |
+
mse_epoches = 10 # @param {type:"number"}
|
1931 |
+
##@param {type:"number"}
|
1932 |
+
|
1933 |
+
# @markdown >Overwrites the usual values during the mse phase if included. If any value is 0, its normal counterpart is used instead.
|
1934 |
+
mse_with_zeros = True # @param {type:"boolean"}
|
1935 |
+
mse_step_size = 0.87 # @param {type:"number"}
|
1936 |
+
mse_cutn = 42 # @param {type:"number"}
|
1937 |
+
mse_cut_pow = 0.75 # @param {type:"number"}
|
1938 |
+
|
1939 |
+
# @markdown >normal_flip_optim flips between two optimizers during the normal (not MSE) phase. It can improve quality, but it's kind of experimental, use at your own risk.
|
1940 |
+
normal_flip_optim = True # @param {type:"boolean"}
|
1941 |
+
##@markdown >Adding some TV may make the image blurrier but also helps to get rid of noise. A good value to try might be 0.1.
|
1942 |
+
# tv_weight = 0.1 #@param {type:'number'}
|
1943 |
+
# @markdown ---
|
1944 |
+
|
1945 |
+
# @markdown >`altprompts` is a set of prompts that take in a different augmentation pipeline, and can have their own cut_pow. At the moment, the default "alt augment" settings flip the picture cutouts upside down before evaluating. This can be good for optical illusion images. If either cut_pow value is 0, it will use the same value as the normal prompts.
|
1946 |
+
altprompts = "" # @param {type:"string"}
|
1947 |
+
altprompt_mode = "flipped"
|
1948 |
+
##@param ["normal" , "flipped", "sideways"]
|
1949 |
+
alt_cut_pow = 0 # @param {type:"number"}
|
1950 |
+
alt_mse_cut_pow = 0 # @param {type:"number"}
|
1951 |
+
# altprompt_type = "upside-down" #@param ['upside-down', 'as']
|
1952 |
+
|
1953 |
+
##@markdown ---
|
1954 |
+
##@markdown <font size="+1">💫</font> **Zooming and Moving** <font size="+1">💫</font>
|
1955 |
+
zoom = False
|
1956 |
+
##@param {type:"boolean"}
|
1957 |
+
zoom_speed = 100
|
1958 |
+
##@param {type:"number"}
|
1959 |
+
zoom_frequency = 20
|
1960 |
+
##@param {type:"number"}
|
1961 |
+
|
1962 |
+
# @markdown ---
|
1963 |
+
# @markdown On an unrelated note, if you get any errors while running this, restart the runtime and run the first cell again. If that doesn't work either, message me on Discord (Philipuss#4066).
|
1964 |
+
|
1965 |
+
model_names = {
|
1966 |
+
"vqgan_imagenet_f16_16384": "vqgan_imagenet_f16_16384",
|
1967 |
+
"ImageNet 1024": "vqgan_imagenet_f16_1024",
|
1968 |
+
"Gumbel 8192": "gumbel_8192",
|
1969 |
+
"Sber Gumbel": "sber_gumbel",
|
1970 |
+
"imagenet_cin": "imagenet_cin",
|
1971 |
+
"WikiArt 1024": "wikiart_1024",
|
1972 |
+
"WikiArt 16384": "wikiart_16384",
|
1973 |
+
"COCO-Stuff": "coco",
|
1974 |
+
"FacesHQ": "faceshq",
|
1975 |
+
"S-FLCKR": "sflckr",
|
1976 |
+
"WikiArt 7mil": "wikiart_7mil",
|
1977 |
+
"COCO 1 Stage": "coco_1stage",
|
1978 |
+
}
|
1979 |
+
|
1980 |
+
if template == "Better - Fast":
|
1981 |
+
prompts = prompts + ". Detailed artwork. ArtStationHQ. unreal engine. 4K HD."
|
1982 |
+
elif template == "Better - Slow":
|
1983 |
+
prompts = (
|
1984 |
+
prompts
|
1985 |
+
+ ". Detailed artwork. Trending on ArtStation. unreal engine. | Rendered in Maya. "
|
1986 |
+
+ prompts
|
1987 |
+
+ ". 4K HD."
|
1988 |
+
)
|
1989 |
+
elif template == "Movie Poster":
|
1990 |
+
prompts = prompts + ". Movie poster. Rendered in unreal engine. ArtStationHQ."
|
1991 |
+
width = 400
|
1992 |
+
height = 592
|
1993 |
+
elif template == "flag":
|
1994 |
+
prompts = (
|
1995 |
+
"A photo of a flag of the country "
|
1996 |
+
+ prompts
|
1997 |
+
+ " | Flag of "
|
1998 |
+
+ prompts
|
1999 |
+
+ ". White background."
|
2000 |
+
)
|
2001 |
+
# import cv2
|
2002 |
+
# img = cv2.imread('templates/flag.png', 0)
|
2003 |
+
# cv2.imwrite('templates/final_flag.png', add_noise(img))
|
2004 |
+
init_image = "templates/flag.png"
|
2005 |
+
transparent_png = True
|
2006 |
+
elif template == "planet":
|
2007 |
+
import cv2
|
2008 |
+
|
2009 |
+
img = cv2.imread("templates/planet.png", 0)
|
2010 |
+
cv2.imwrite("templates/final_planet.png", add_noise(img))
|
2011 |
+
prompts = (
|
2012 |
+
"A photo of the planet "
|
2013 |
+
+ prompts
|
2014 |
+
+ ". Planet in the middle with black background. | The planet of "
|
2015 |
+
+ prompts
|
2016 |
+
+ ". Photo of a planet. Black background. Trending on ArtStation. | Colorful."
|
2017 |
+
)
|
2018 |
+
init_image = "templates/final_planet.png"
|
2019 |
+
elif template == "creature":
|
2020 |
+
# import cv2
|
2021 |
+
# img = cv2.imread('templates/planet.png', 0)
|
2022 |
+
# cv2.imwrite('templates/final_planet.png', add_noise(img))
|
2023 |
+
prompts = (
|
2024 |
+
"A photo of a creature with "
|
2025 |
+
+ prompts
|
2026 |
+
+ ". Animal in the middle with white background. | The creature has "
|
2027 |
+
+ prompts
|
2028 |
+
+ ". Photo of a creature/animal. White background. Detailed image of a creature. | White background."
|
2029 |
+
)
|
2030 |
+
init_image = "templates/creature.png"
|
2031 |
+
# transparent_png = True
|
2032 |
+
elif template == "Detailed":
|
2033 |
+
prompts = (
|
2034 |
+
prompts
|
2035 |
+
+ ", by Puer Udger. Detailed artwork, trending on artstation. 4K HD, realism."
|
2036 |
+
)
|
2037 |
+
flavor = "cumin"
|
2038 |
+
elif template == "human":
|
2039 |
+
init_image = "/content/templates/human.png"
|
2040 |
+
elif template == "Realistic":
|
2041 |
+
cutn = 200
|
2042 |
+
step_size = 0.03
|
2043 |
+
cut_pow = 0.2
|
2044 |
+
flavor = "holywater"
|
2045 |
+
elif template == "Consistent Creativity":
|
2046 |
+
flavor = "cumin"
|
2047 |
+
cut_pow = 0.01
|
2048 |
+
cutn = 136
|
2049 |
+
step_size = 0.08
|
2050 |
+
mse_step_size = 0.41
|
2051 |
+
mse_cut_pow = 0.3
|
2052 |
+
ema_val = 0.99
|
2053 |
+
normal_flip_optim = False
|
2054 |
+
elif template == "Smooth":
|
2055 |
+
flavor = "wyvern"
|
2056 |
+
step_size = 0.10
|
2057 |
+
cutn = 120
|
2058 |
+
normal_flip_optim = False
|
2059 |
+
tv_weight = 10
|
2060 |
+
elif template == "Subtle MSE":
|
2061 |
+
mse_init_weight = 0.07
|
2062 |
+
mse_decay_rate = 130
|
2063 |
+
mse_step_size = 0.2
|
2064 |
+
mse_cutn = 100
|
2065 |
+
mse_cut_pow = 0.6
|
2066 |
+
elif template == "Balanced":
|
2067 |
+
cutn = 130
|
2068 |
+
cut_pow = 1
|
2069 |
+
step_size = 0.16
|
2070 |
+
final_step_size = 0
|
2071 |
+
ema_val = 0.98
|
2072 |
+
mse_init_weight = 0.2
|
2073 |
+
mse_decay_rate = 130
|
2074 |
+
mse_with_zeros = True
|
2075 |
+
mse_step_size = 0.9
|
2076 |
+
mse_cutn = 50
|
2077 |
+
mse_cut_pow = 0.8
|
2078 |
+
normal_flip_optim = True
|
2079 |
+
elif template == "Size: Square":
|
2080 |
+
width = 450
|
2081 |
+
height = 450
|
2082 |
+
elif template == "Size: Landscape":
|
2083 |
+
width = 480
|
2084 |
+
height = 336
|
2085 |
+
elif template == "Size: Poster":
|
2086 |
+
width = 336
|
2087 |
+
height = 480
|
2088 |
+
elif template == "Negative Prompt":
|
2089 |
+
prompts = prompts.replace(":", ":-")
|
2090 |
+
prompts = prompts.replace(":--", ":")
|
2091 |
+
elif template == "Hyper Fast Results":
|
2092 |
+
step_size = 1
|
2093 |
+
ema_val = 0.3
|
2094 |
+
cutn = 30
|
2095 |
+
elif template == "Better Quality":
|
2096 |
+
prompts = (
|
2097 |
+
prompts + ":1 | Watermark, blurry, cropped, confusing, cut, incoherent:-1"
|
2098 |
+
)
|
2099 |
+
|
2100 |
+
mse_decay = 0
|
2101 |
+
|
2102 |
+
if use_mse == False:
|
2103 |
+
mse_init_weight = 0.0
|
2104 |
+
else:
|
2105 |
+
mse_decay = mse_init_weight / mse_epoches
|
2106 |
+
|
2107 |
+
|
2108 |
+
if seed == -1:
|
2109 |
+
seed = None
|
2110 |
+
if init_image == "None":
|
2111 |
+
init_image = None
|
2112 |
+
if target_images == "None" or not target_images:
|
2113 |
+
target_images = []
|
2114 |
+
else:
|
2115 |
+
target_images = target_images.split("|")
|
2116 |
+
target_images = [image.strip() for image in target_images]
|
2117 |
+
|
2118 |
+
prompts = [phrase.strip() for phrase in prompts.split("|")]
|
2119 |
+
if prompts == [""]:
|
2120 |
+
prompts = []
|
2121 |
+
|
2122 |
+
altprompts = [phrase.strip() for phrase in altprompts.split("|")]
|
2123 |
+
if altprompts == [""]:
|
2124 |
+
altprompts = []
|
2125 |
+
|
2126 |
+
if mse_images_interval == 0:
|
2127 |
+
mse_images_interval = images_interval
|
2128 |
+
if mse_step_size == 0:
|
2129 |
+
mse_step_size = step_size
|
2130 |
+
if mse_cutn == 0:
|
2131 |
+
mse_cutn = cutn
|
2132 |
+
if mse_cut_pow == 0:
|
2133 |
+
mse_cut_pow = cut_pow
|
2134 |
+
if alt_cut_pow == 0:
|
2135 |
+
alt_cut_pow = cut_pow
|
2136 |
+
if alt_mse_cut_pow == 0:
|
2137 |
+
alt_mse_cut_pow = mse_cut_pow
|
2138 |
+
|
2139 |
+
augs = nn.Sequential(
|
2140 |
+
K.RandomHorizontalFlip(p=0.5),
|
2141 |
+
K.RandomSharpness(0.3, p=0.4),
|
2142 |
+
K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
|
2143 |
+
# K.RandomGaussianNoise(p=0.5),
|
2144 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
2145 |
+
K.RandomAffine(
|
2146 |
+
degrees=30, translate=0.1, p=0.8, padding_mode="border"
|
2147 |
+
), # padding_mode=2
|
2148 |
+
K.RandomPerspective(
|
2149 |
+
0.2,
|
2150 |
+
p=0.4,
|
2151 |
+
),
|
2152 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
2153 |
+
K.RandomGrayscale(p=0.1),
|
2154 |
+
)
|
2155 |
+
|
2156 |
+
if altprompt_mode == "normal":
|
2157 |
+
altaugs = nn.Sequential(
|
2158 |
+
K.RandomRotation(degrees=90.0, return_transform=True),
|
2159 |
+
K.RandomHorizontalFlip(p=0.5),
|
2160 |
+
K.RandomSharpness(0.3, p=0.4),
|
2161 |
+
K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
|
2162 |
+
# K.RandomGaussianNoise(p=0.5),
|
2163 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
2164 |
+
K.RandomAffine(
|
2165 |
+
degrees=30, translate=0.1, p=0.8, padding_mode="border"
|
2166 |
+
), # padding_mode=2
|
2167 |
+
K.RandomPerspective(
|
2168 |
+
0.2,
|
2169 |
+
p=0.4,
|
2170 |
+
),
|
2171 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
2172 |
+
K.RandomGrayscale(p=0.1),
|
2173 |
+
)
|
2174 |
+
elif altprompt_mode == "flipped":
|
2175 |
+
altaugs = nn.Sequential(
|
2176 |
+
K.RandomHorizontalFlip(p=0.5),
|
2177 |
+
# K.RandomRotation(degrees=90.0),
|
2178 |
+
K.RandomVerticalFlip(p=1),
|
2179 |
+
K.RandomSharpness(0.3, p=0.4),
|
2180 |
+
K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
|
2181 |
+
# K.RandomGaussianNoise(p=0.5),
|
2182 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
2183 |
+
K.RandomAffine(
|
2184 |
+
degrees=30, translate=0.1, p=0.8, padding_mode="border"
|
2185 |
+
), # padding_mode=2
|
2186 |
+
K.RandomPerspective(
|
2187 |
+
0.2,
|
2188 |
+
p=0.4,
|
2189 |
+
),
|
2190 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
2191 |
+
K.RandomGrayscale(p=0.1),
|
2192 |
+
)
|
2193 |
+
elif altprompt_mode == "sideways":
|
2194 |
+
altaugs = nn.Sequential(
|
2195 |
+
K.RandomHorizontalFlip(p=0.5),
|
2196 |
+
# K.RandomRotation(degrees=90.0),
|
2197 |
+
K.RandomVerticalFlip(p=1),
|
2198 |
+
K.RandomSharpness(0.3, p=0.4),
|
2199 |
+
K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
|
2200 |
+
# K.RandomGaussianNoise(p=0.5),
|
2201 |
+
# K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
|
2202 |
+
K.RandomAffine(
|
2203 |
+
degrees=30, translate=0.1, p=0.8, padding_mode="border"
|
2204 |
+
), # padding_mode=2
|
2205 |
+
K.RandomPerspective(
|
2206 |
+
0.2,
|
2207 |
+
p=0.4,
|
2208 |
+
),
|
2209 |
+
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
|
2210 |
+
K.RandomGrayscale(p=0.1),
|
2211 |
+
)
|
2212 |
+
|
2213 |
+
if multiple_prompt_batches:
|
2214 |
+
prompts_all = str(prompts).split("~")
|
2215 |
+
else:
|
2216 |
+
prompts_all = prompts
|
2217 |
+
multiple_prompt_batches_iter = iterations
|
2218 |
+
|
2219 |
+
if multiple_prompt_batches:
|
2220 |
+
mtpl_prmpts_btchs = len(prompts_all)
|
2221 |
+
else:
|
2222 |
+
mtpl_prmpts_btchs = 1
|
2223 |
+
|
2224 |
+
# print(mtpl_prmpts_btchs)
|
2225 |
+
|
2226 |
+
steps_path = "./"
|
2227 |
+
zoom_path = "./"
|
2228 |
+
|
2229 |
+
path = "./"
|
2230 |
+
|
2231 |
+
iterations = multiple_prompt_batches_iter
|
2232 |
+
|
2233 |
+
for pr in range(0, mtpl_prmpts_btchs):
|
2234 |
+
# print(prompts_all[pr].replace('[\'', '').replace('\']', ''))
|
2235 |
+
if multiple_prompt_batches:
|
2236 |
+
prompts = prompts_all[pr].replace("['", "").replace("']", "")
|
2237 |
+
|
2238 |
+
if zoom:
|
2239 |
+
mdf_iter = round(iterations / zoom_frequency)
|
2240 |
+
else:
|
2241 |
+
mdf_iter = 2
|
2242 |
+
zoom_frequency = iterations
|
2243 |
+
|
2244 |
+
for iter in range(1, mdf_iter):
|
2245 |
+
if zoom:
|
2246 |
+
if iter != 0:
|
2247 |
+
image = Image.open("progress.png")
|
2248 |
+
area = (0, 0, width - zoom_speed, height - zoom_speed)
|
2249 |
+
cropped_img = image.crop(area)
|
2250 |
+
cropped_img.show()
|
2251 |
+
|
2252 |
+
new_image = cropped_img.resize((width, height))
|
2253 |
+
new_image.save("zoom.png")
|
2254 |
+
init_image = "zoom.png"
|
2255 |
+
|
2256 |
+
args = argparse.Namespace(
|
2257 |
+
prompts=prompts,
|
2258 |
+
altprompts=altprompts,
|
2259 |
+
image_prompts=target_images,
|
2260 |
+
noise_prompt_seeds=[],
|
2261 |
+
noise_prompt_weights=[],
|
2262 |
+
size=[width, height],
|
2263 |
+
init_image=init_image,
|
2264 |
+
png=transparent_png,
|
2265 |
+
init_weight=mse_init_weight,
|
2266 |
+
vqgan_model=model_names[model],
|
2267 |
+
step_size=step_size,
|
2268 |
+
start_step_size=start_step_size,
|
2269 |
+
final_step_size=final_step_size,
|
2270 |
+
cutn=cutn,
|
2271 |
+
cut_pow=cut_pow,
|
2272 |
+
mse_cutn=mse_cutn,
|
2273 |
+
mse_cut_pow=mse_cut_pow,
|
2274 |
+
mse_step_size=mse_step_size,
|
2275 |
+
display_freq=images_interval,
|
2276 |
+
mse_display_freq=mse_images_interval,
|
2277 |
+
max_iterations=zoom_frequency,
|
2278 |
+
mse_end=0,
|
2279 |
+
seed=seed,
|
2280 |
+
folder_name=folder_name,
|
2281 |
+
save_to_drive=save_to_drive,
|
2282 |
+
mse_decay_rate=mse_decay_rate,
|
2283 |
+
mse_decay=mse_decay,
|
2284 |
+
mse_with_zeros=mse_with_zeros,
|
2285 |
+
normal_flip_optim=normal_flip_optim,
|
2286 |
+
ema_val=ema_val,
|
2287 |
+
augs=augs,
|
2288 |
+
altaugs=altaugs,
|
2289 |
+
alt_cut_pow=alt_cut_pow,
|
2290 |
+
alt_mse_cut_pow=alt_mse_cut_pow,
|
2291 |
+
is_gumbel=is_gumbel,
|
2292 |
+
gen_seed=gen_seed,
|
2293 |
+
)
|
2294 |
+
|
2295 |
+
mh = ModelHost(args)
|
2296 |
+
x = 0
|
2297 |
+
|
2298 |
+
for x in range(batch_size):
|
2299 |
+
mh.setup_model(x)
|
2300 |
+
last_iter = mh.run(x)
|
2301 |
+
x = x + 1
|
2302 |
+
|
2303 |
+
#if batch_size != 1:
|
2304 |
+
# clear_output()
|
2305 |
+
# print("===============================================================================")
|
2306 |
+
#q = 0
|
2307 |
+
#while q < batch_size:
|
2308 |
+
#display(Image("/content/" + folder_name + "/" + str(q) + ".png"))
|
2309 |
+
# print("Image" + str(q) + '.png')
|
2310 |
+
#q += 1
|
2311 |
+
|
2312 |
+
if zoom:
|
2313 |
+
files = os.listdir(steps_path)
|
2314 |
+
for index, file in enumerate(files):
|
2315 |
+
os.rename(
|
2316 |
+
os.path.join(steps_path, file),
|
2317 |
+
os.path.join(
|
2318 |
+
steps_path,
|
2319 |
+
"".join([str(index + 1 + zoom_frequency * iter), ".png"]),
|
2320 |
+
),
|
2321 |
+
)
|
2322 |
+
index = index + 1
|
2323 |
+
|
2324 |
+
from pathlib import Path
|
2325 |
+
import shutil
|
2326 |
+
|
2327 |
+
src_path = steps_path
|
2328 |
+
trg_path = zoom_path
|
2329 |
+
|
2330 |
+
for src_file in range(1, mdf_iter):
|
2331 |
+
shutil.move(os.path.join(src_path, src_file), trg_path)
|
2332 |
+
|
2333 |
+
##################### START GRADIO HERE ############################
|
2334 |
+
image = gr.outputs.Image(type="pil", label="Your result")
|
2335 |
+
iface = gr.Interface(
|
2336 |
+
fn=run,
|
2337 |
+
inputs=[
|
2338 |
+
gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
|
2339 |
+
gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=45,maximum=50,minimum=1,step=1),
|
2340 |
+
gr.inputs.Dropdown(label="Style",choices=["none","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"]),
|
2341 |
+
gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=256),
|
2342 |
+
gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=256),
|
2343 |
+
],
|
2344 |
+
outputs=[image],
|
2345 |
+
title="Generate images from text with VQGAN+CLIP",
|
2346 |
+
#description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/CompVis/latent-diffusion' target='_blank'>Latent Diffusion</a> is a text-to-image model created by <a href='https://github.com/CompVis' target='_blank'>CompVis</a>, trained on the <a href='https://laion.ai/laion-400-open-dataset/'>LAION-400M dataset.</a><br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
|
2347 |
+
#article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exarcbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>:<i> \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The model was trained on an unfiltered version the LAION-400M dataset, which scrapped non-curated image-text-pairs from the internet (the exception being the the removal of illegal content) and is meant to be used for research purposes, such as this one. <a href='https://laion.ai/laion-400-open-dataset/' target='_blank'>You can read more on LAION's website</a></div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definetly not me! Probably you do. I say probably because the Copyright discussion about AI generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>"
|
2348 |
+
)
|
2349 |
+
iface.launch(enable_queue=True)
|
requirements.txt
CHANGED
@@ -1 +1,20 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-e git+https://github.com/CompVis/taming-transformers.git#egg=taming-transformers
|
2 |
+
ftfy
|
3 |
+
regex
|
4 |
+
pandas
|
5 |
+
omegaconf
|
6 |
+
pytorch-lightning
|
7 |
+
torch-fidelity
|
8 |
+
transformers
|
9 |
+
einops
|
10 |
+
gradio
|
11 |
+
torch
|
12 |
+
open_clip_torch
|
13 |
+
numpy
|
14 |
+
tqdm
|
15 |
+
torchvision
|
16 |
+
Pillow
|
17 |
+
autokeras
|
18 |
+
huggingface_hub
|
19 |
+
kornia
|
20 |
+
clip
|