7bd1acf8301721733c956a62ee480dc281b7213caae348c6d8690ec8224ac24f
Browse files- repositories/stable-diffusion-stability-ai/ldm/modules/karlo/kakao/template.py +141 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/__init__.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/api.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py +170 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/__init__.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/base_model.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/blocks.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/dpt_depth.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/transforms.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/vit.cpython-310.pyc +0 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py +16 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py +342 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py +109 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py +76 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py +128 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py +234 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py +491 -0
- repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py +189 -0
- repositories/stable-diffusion-stability-ai/ldm/util.py +207 -0
- repositories/stable-diffusion-stability-ai/modelcard.md +153 -0
- repositories/stable-diffusion-stability-ai/requirements.txt +19 -0
- repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py +184 -0
- repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py +195 -0
- repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py +197 -0
- repositories/stable-diffusion-stability-ai/scripts/img2img.py +279 -0
- repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py +157 -0
- repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py +195 -0
- repositories/stable-diffusion-stability-ai/scripts/streamlit/stableunclip.py +416 -0
- repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py +170 -0
- repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py +18 -0
- repositories/stable-diffusion-stability-ai/scripts/txt2img.py +388 -0
- repositories/stable-diffusion-stability-ai/setup.py +13 -0
- requirements-test.txt +3 -0
- requirements.txt +33 -0
- requirements_versions.txt +31 -0
- screenshot.png +0 -0
- script.js +163 -0
- scripts/__pycache__/custom_code.cpython-310.pyc +0 -0
- scripts/__pycache__/img2imgalt.cpython-310.pyc +0 -0
- scripts/__pycache__/loopback.cpython-310.pyc +0 -0
- scripts/__pycache__/outpainting_mk_2.cpython-310.pyc +0 -0
- scripts/__pycache__/poor_mans_outpainting.cpython-310.pyc +0 -0
- scripts/__pycache__/postprocessing_codeformer.cpython-310.pyc +0 -0
- scripts/__pycache__/postprocessing_gfpgan.cpython-310.pyc +0 -0
- scripts/__pycache__/postprocessing_upscale.cpython-310.pyc +0 -0
- scripts/__pycache__/prompt_matrix.cpython-310.pyc +0 -0
repositories/stable-diffusion-stability-ai/ldm/modules/karlo/kakao/template.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Karlo-v1.0.alpha
|
3 |
+
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
4 |
+
# ------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
|
12 |
+
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
|
13 |
+
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
|
14 |
+
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
|
15 |
+
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
|
16 |
+
|
17 |
+
|
18 |
+
SAMPLING_CONF = {
|
19 |
+
"default": {
|
20 |
+
"prior_sm": "25",
|
21 |
+
"prior_n_samples": 1,
|
22 |
+
"prior_cf_scale": 4.0,
|
23 |
+
"decoder_sm": "50",
|
24 |
+
"decoder_cf_scale": 8.0,
|
25 |
+
"sr_sm": "7",
|
26 |
+
},
|
27 |
+
"fast": {
|
28 |
+
"prior_sm": "25",
|
29 |
+
"prior_n_samples": 1,
|
30 |
+
"prior_cf_scale": 4.0,
|
31 |
+
"decoder_sm": "25",
|
32 |
+
"decoder_cf_scale": 8.0,
|
33 |
+
"sr_sm": "7",
|
34 |
+
},
|
35 |
+
}
|
36 |
+
|
37 |
+
CKPT_PATH = {
|
38 |
+
"prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
|
39 |
+
"decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
|
40 |
+
"sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class BaseSampler:
|
45 |
+
_PRIOR_CLASS = PriorDiffusionModel
|
46 |
+
_DECODER_CLASS = Text2ImProgressiveModel
|
47 |
+
_SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
root_dir: str,
|
52 |
+
sampling_type: str = "fast",
|
53 |
+
):
|
54 |
+
self._root_dir = root_dir
|
55 |
+
|
56 |
+
sampling_type = SAMPLING_CONF[sampling_type]
|
57 |
+
self._prior_sm = sampling_type["prior_sm"]
|
58 |
+
self._prior_n_samples = sampling_type["prior_n_samples"]
|
59 |
+
self._prior_cf_scale = sampling_type["prior_cf_scale"]
|
60 |
+
|
61 |
+
assert self._prior_n_samples == 1
|
62 |
+
|
63 |
+
self._decoder_sm = sampling_type["decoder_sm"]
|
64 |
+
self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
|
65 |
+
|
66 |
+
self._sr_sm = sampling_type["sr_sm"]
|
67 |
+
|
68 |
+
def __repr__(self):
|
69 |
+
line = ""
|
70 |
+
line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
|
71 |
+
line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
|
72 |
+
line += f"SR(64->256), sampling method: {self._sr_sm}"
|
73 |
+
|
74 |
+
return line
|
75 |
+
|
76 |
+
def load_clip(self, clip_path: str):
|
77 |
+
clip = CustomizedCLIP.load_from_checkpoint(
|
78 |
+
os.path.join(self._root_dir, clip_path)
|
79 |
+
)
|
80 |
+
clip = torch.jit.script(clip)
|
81 |
+
clip.cuda()
|
82 |
+
clip.eval()
|
83 |
+
|
84 |
+
self._clip = clip
|
85 |
+
self._tokenizer = CustomizedTokenizer()
|
86 |
+
|
87 |
+
def load_prior(
|
88 |
+
self,
|
89 |
+
ckpt_path: str,
|
90 |
+
clip_stat_path: str,
|
91 |
+
prior_config: str = "configs/prior_1B_vit_l.yaml"
|
92 |
+
):
|
93 |
+
logging.info(f"Loading prior: {ckpt_path}")
|
94 |
+
|
95 |
+
config = OmegaConf.load(prior_config)
|
96 |
+
clip_mean, clip_std = torch.load(
|
97 |
+
os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
|
98 |
+
)
|
99 |
+
|
100 |
+
prior = self._PRIOR_CLASS.load_from_checkpoint(
|
101 |
+
config,
|
102 |
+
self._tokenizer,
|
103 |
+
clip_mean,
|
104 |
+
clip_std,
|
105 |
+
os.path.join(self._root_dir, ckpt_path),
|
106 |
+
strict=True,
|
107 |
+
)
|
108 |
+
prior.cuda()
|
109 |
+
prior.eval()
|
110 |
+
logging.info("done.")
|
111 |
+
|
112 |
+
self._prior = prior
|
113 |
+
|
114 |
+
def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
|
115 |
+
logging.info(f"Loading decoder: {ckpt_path}")
|
116 |
+
|
117 |
+
config = OmegaConf.load(decoder_config)
|
118 |
+
decoder = self._DECODER_CLASS.load_from_checkpoint(
|
119 |
+
config,
|
120 |
+
self._tokenizer,
|
121 |
+
os.path.join(self._root_dir, ckpt_path),
|
122 |
+
strict=True,
|
123 |
+
)
|
124 |
+
decoder.cuda()
|
125 |
+
decoder.eval()
|
126 |
+
logging.info("done.")
|
127 |
+
|
128 |
+
self._decoder = decoder
|
129 |
+
|
130 |
+
def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
|
131 |
+
logging.info(f"Loading SR(64->256): {ckpt_path}")
|
132 |
+
|
133 |
+
config = OmegaConf.load(sr_config)
|
134 |
+
sr = self._SR256_CLASS.load_from_checkpoint(
|
135 |
+
config, os.path.join(self._root_dir, ckpt_path), strict=True
|
136 |
+
)
|
137 |
+
sr.cuda()
|
138 |
+
sr.eval()
|
139 |
+
logging.info("done.")
|
140 |
+
|
141 |
+
self._sr_64_256 = sr
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py
ADDED
File without changes
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (189 Bytes). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/api.cpython-310.pyc
ADDED
Binary file (3.63 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchvision.transforms import Compose
|
7 |
+
|
8 |
+
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
9 |
+
from ldm.modules.midas.midas.midas_net import MidasNet
|
10 |
+
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
11 |
+
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
12 |
+
|
13 |
+
|
14 |
+
ISL_PATHS = {
|
15 |
+
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
16 |
+
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
|
17 |
+
"midas_v21": "",
|
18 |
+
"midas_v21_small": "",
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def disabled_train(self, mode=True):
|
23 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
24 |
+
does not change anymore."""
|
25 |
+
return self
|
26 |
+
|
27 |
+
|
28 |
+
def load_midas_transform(model_type):
|
29 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
30 |
+
# load transform only
|
31 |
+
if model_type == "dpt_large": # DPT-Large
|
32 |
+
net_w, net_h = 384, 384
|
33 |
+
resize_mode = "minimal"
|
34 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
35 |
+
|
36 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
37 |
+
net_w, net_h = 384, 384
|
38 |
+
resize_mode = "minimal"
|
39 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
40 |
+
|
41 |
+
elif model_type == "midas_v21":
|
42 |
+
net_w, net_h = 384, 384
|
43 |
+
resize_mode = "upper_bound"
|
44 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
45 |
+
|
46 |
+
elif model_type == "midas_v21_small":
|
47 |
+
net_w, net_h = 256, 256
|
48 |
+
resize_mode = "upper_bound"
|
49 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
50 |
+
|
51 |
+
else:
|
52 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
53 |
+
|
54 |
+
transform = Compose(
|
55 |
+
[
|
56 |
+
Resize(
|
57 |
+
net_w,
|
58 |
+
net_h,
|
59 |
+
resize_target=None,
|
60 |
+
keep_aspect_ratio=True,
|
61 |
+
ensure_multiple_of=32,
|
62 |
+
resize_method=resize_mode,
|
63 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
64 |
+
),
|
65 |
+
normalization,
|
66 |
+
PrepareForNet(),
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
return transform
|
71 |
+
|
72 |
+
|
73 |
+
def load_model(model_type):
|
74 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
75 |
+
# load network
|
76 |
+
model_path = ISL_PATHS[model_type]
|
77 |
+
if model_type == "dpt_large": # DPT-Large
|
78 |
+
model = DPTDepthModel(
|
79 |
+
path=model_path,
|
80 |
+
backbone="vitl16_384",
|
81 |
+
non_negative=True,
|
82 |
+
)
|
83 |
+
net_w, net_h = 384, 384
|
84 |
+
resize_mode = "minimal"
|
85 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
86 |
+
|
87 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
88 |
+
model = DPTDepthModel(
|
89 |
+
path=model_path,
|
90 |
+
backbone="vitb_rn50_384",
|
91 |
+
non_negative=True,
|
92 |
+
)
|
93 |
+
net_w, net_h = 384, 384
|
94 |
+
resize_mode = "minimal"
|
95 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
96 |
+
|
97 |
+
elif model_type == "midas_v21":
|
98 |
+
model = MidasNet(model_path, non_negative=True)
|
99 |
+
net_w, net_h = 384, 384
|
100 |
+
resize_mode = "upper_bound"
|
101 |
+
normalization = NormalizeImage(
|
102 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
103 |
+
)
|
104 |
+
|
105 |
+
elif model_type == "midas_v21_small":
|
106 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
107 |
+
non_negative=True, blocks={'expand': True})
|
108 |
+
net_w, net_h = 256, 256
|
109 |
+
resize_mode = "upper_bound"
|
110 |
+
normalization = NormalizeImage(
|
111 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
112 |
+
)
|
113 |
+
|
114 |
+
else:
|
115 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
116 |
+
assert False
|
117 |
+
|
118 |
+
transform = Compose(
|
119 |
+
[
|
120 |
+
Resize(
|
121 |
+
net_w,
|
122 |
+
net_h,
|
123 |
+
resize_target=None,
|
124 |
+
keep_aspect_ratio=True,
|
125 |
+
ensure_multiple_of=32,
|
126 |
+
resize_method=resize_mode,
|
127 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
128 |
+
),
|
129 |
+
normalization,
|
130 |
+
PrepareForNet(),
|
131 |
+
]
|
132 |
+
)
|
133 |
+
|
134 |
+
return model.eval(), transform
|
135 |
+
|
136 |
+
|
137 |
+
class MiDaSInference(nn.Module):
|
138 |
+
MODEL_TYPES_TORCH_HUB = [
|
139 |
+
"DPT_Large",
|
140 |
+
"DPT_Hybrid",
|
141 |
+
"MiDaS_small"
|
142 |
+
]
|
143 |
+
MODEL_TYPES_ISL = [
|
144 |
+
"dpt_large",
|
145 |
+
"dpt_hybrid",
|
146 |
+
"midas_v21",
|
147 |
+
"midas_v21_small",
|
148 |
+
]
|
149 |
+
|
150 |
+
def __init__(self, model_type):
|
151 |
+
super().__init__()
|
152 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
153 |
+
model, _ = load_model(model_type)
|
154 |
+
self.model = model
|
155 |
+
self.model.train = disabled_train
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
|
159 |
+
# NOTE: we expect that the correct transform has been called during dataloading.
|
160 |
+
with torch.no_grad():
|
161 |
+
prediction = self.model(x)
|
162 |
+
prediction = torch.nn.functional.interpolate(
|
163 |
+
prediction.unsqueeze(1),
|
164 |
+
size=x.shape[2:],
|
165 |
+
mode="bicubic",
|
166 |
+
align_corners=False,
|
167 |
+
)
|
168 |
+
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
169 |
+
return prediction
|
170 |
+
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py
ADDED
File without changes
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (195 Bytes). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/base_model.cpython-310.pyc
ADDED
Binary file (723 Bytes). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (7.24 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/dpt_depth.cpython-310.pyc
ADDED
Binary file (2.95 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net.cpython-310.pyc
ADDED
Binary file (2.63 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc
ADDED
Binary file (3.75 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (5.71 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/vit.cpython-310.pyc
ADDED
Binary file (9.4 kB). View file
|
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
203 |
+
|
204 |
+
# 32, 48, 136, 384
|
205 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
206 |
+
readout_oper[0],
|
207 |
+
Transpose(1, 2),
|
208 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
209 |
+
nn.Conv2d(
|
210 |
+
in_channels=vit_features,
|
211 |
+
out_channels=features[0],
|
212 |
+
kernel_size=1,
|
213 |
+
stride=1,
|
214 |
+
padding=0,
|
215 |
+
),
|
216 |
+
nn.ConvTranspose2d(
|
217 |
+
in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for monoDepth."""
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def read_pfm(path):
|
10 |
+
"""Read pfm file.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
path (str): path to file
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: (data, scale)
|
17 |
+
"""
|
18 |
+
with open(path, "rb") as file:
|
19 |
+
|
20 |
+
color = None
|
21 |
+
width = None
|
22 |
+
height = None
|
23 |
+
scale = None
|
24 |
+
endian = None
|
25 |
+
|
26 |
+
header = file.readline().rstrip()
|
27 |
+
if header.decode("ascii") == "PF":
|
28 |
+
color = True
|
29 |
+
elif header.decode("ascii") == "Pf":
|
30 |
+
color = False
|
31 |
+
else:
|
32 |
+
raise Exception("Not a PFM file: " + path)
|
33 |
+
|
34 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
35 |
+
if dim_match:
|
36 |
+
width, height = list(map(int, dim_match.groups()))
|
37 |
+
else:
|
38 |
+
raise Exception("Malformed PFM header.")
|
39 |
+
|
40 |
+
scale = float(file.readline().decode("ascii").rstrip())
|
41 |
+
if scale < 0:
|
42 |
+
# little-endian
|
43 |
+
endian = "<"
|
44 |
+
scale = -scale
|
45 |
+
else:
|
46 |
+
# big-endian
|
47 |
+
endian = ">"
|
48 |
+
|
49 |
+
data = np.fromfile(file, endian + "f")
|
50 |
+
shape = (height, width, 3) if color else (height, width)
|
51 |
+
|
52 |
+
data = np.reshape(data, shape)
|
53 |
+
data = np.flipud(data)
|
54 |
+
|
55 |
+
return data, scale
|
56 |
+
|
57 |
+
|
58 |
+
def write_pfm(path, image, scale=1):
|
59 |
+
"""Write pfm file.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
path (str): pathto file
|
63 |
+
image (array): data
|
64 |
+
scale (int, optional): Scale. Defaults to 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
with open(path, "wb") as file:
|
68 |
+
color = None
|
69 |
+
|
70 |
+
if image.dtype.name != "float32":
|
71 |
+
raise Exception("Image dtype must be float32.")
|
72 |
+
|
73 |
+
image = np.flipud(image)
|
74 |
+
|
75 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
76 |
+
color = True
|
77 |
+
elif (
|
78 |
+
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
79 |
+
): # greyscale
|
80 |
+
color = False
|
81 |
+
else:
|
82 |
+
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
83 |
+
|
84 |
+
file.write("PF\n" if color else "Pf\n".encode())
|
85 |
+
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
86 |
+
|
87 |
+
endian = image.dtype.byteorder
|
88 |
+
|
89 |
+
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
90 |
+
scale = -scale
|
91 |
+
|
92 |
+
file.write("%f\n".encode() % scale)
|
93 |
+
|
94 |
+
image.tofile(file)
|
95 |
+
|
96 |
+
|
97 |
+
def read_image(path):
|
98 |
+
"""Read image and output RGB image (0-1).
|
99 |
+
|
100 |
+
Args:
|
101 |
+
path (str): path to file
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
array: RGB image (0-1)
|
105 |
+
"""
|
106 |
+
img = cv2.imread(path)
|
107 |
+
|
108 |
+
if img.ndim == 2:
|
109 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
110 |
+
|
111 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
112 |
+
|
113 |
+
return img
|
114 |
+
|
115 |
+
|
116 |
+
def resize_image(img):
|
117 |
+
"""Resize image and make it fit for network.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
img (array): image
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
tensor: data ready for network
|
124 |
+
"""
|
125 |
+
height_orig = img.shape[0]
|
126 |
+
width_orig = img.shape[1]
|
127 |
+
|
128 |
+
if width_orig > height_orig:
|
129 |
+
scale = width_orig / 384
|
130 |
+
else:
|
131 |
+
scale = height_orig / 384
|
132 |
+
|
133 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
134 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
135 |
+
|
136 |
+
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
137 |
+
|
138 |
+
img_resized = (
|
139 |
+
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
140 |
+
)
|
141 |
+
img_resized = img_resized.unsqueeze(0)
|
142 |
+
|
143 |
+
return img_resized
|
144 |
+
|
145 |
+
|
146 |
+
def resize_depth(depth, width, height):
|
147 |
+
"""Resize depth map and bring to CPU (numpy).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
depth (tensor): depth
|
151 |
+
width (int): image width
|
152 |
+
height (int): image height
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
array: processed depth
|
156 |
+
"""
|
157 |
+
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
158 |
+
|
159 |
+
depth_resized = cv2.resize(
|
160 |
+
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
161 |
+
)
|
162 |
+
|
163 |
+
return depth_resized
|
164 |
+
|
165 |
+
def write_depth(path, depth, bits=1):
|
166 |
+
"""Write depth map to pfm and png file.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
path (str): filepath without extension
|
170 |
+
depth (array): depth
|
171 |
+
"""
|
172 |
+
write_pfm(path + ".pfm", depth.astype(np.float32))
|
173 |
+
|
174 |
+
depth_min = depth.min()
|
175 |
+
depth_max = depth.max()
|
176 |
+
|
177 |
+
max_val = (2**(8*bits))-1
|
178 |
+
|
179 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
180 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
181 |
+
else:
|
182 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
183 |
+
|
184 |
+
if bits == 1:
|
185 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
186 |
+
elif bits == 2:
|
187 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
188 |
+
|
189 |
+
return
|
repositories/stable-diffusion-stability-ai/ldm/util.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import optim
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inspect import isfunction
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
|
11 |
+
def autocast(f):
|
12 |
+
def do_autocast(*args, **kwargs):
|
13 |
+
with torch.cuda.amp.autocast(enabled=True,
|
14 |
+
dtype=torch.get_autocast_gpu_dtype(),
|
15 |
+
cache_enabled=torch.is_autocast_cache_enabled()):
|
16 |
+
return f(*args, **kwargs)
|
17 |
+
|
18 |
+
return do_autocast
|
19 |
+
|
20 |
+
|
21 |
+
def log_txt_as_img(wh, xc, size=10):
|
22 |
+
# wh a tuple of (width, height)
|
23 |
+
# xc a list of captions to plot
|
24 |
+
b = len(xc)
|
25 |
+
txts = list()
|
26 |
+
for bi in range(b):
|
27 |
+
txt = Image.new("RGB", wh, color="white")
|
28 |
+
draw = ImageDraw.Draw(txt)
|
29 |
+
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
30 |
+
nc = int(40 * (wh[0] / 256))
|
31 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
32 |
+
|
33 |
+
try:
|
34 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
35 |
+
except UnicodeEncodeError:
|
36 |
+
print("Cant encode string for logging. Skipping.")
|
37 |
+
|
38 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
39 |
+
txts.append(txt)
|
40 |
+
txts = np.stack(txts)
|
41 |
+
txts = torch.tensor(txts)
|
42 |
+
return txts
|
43 |
+
|
44 |
+
|
45 |
+
def ismap(x):
|
46 |
+
if not isinstance(x, torch.Tensor):
|
47 |
+
return False
|
48 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
49 |
+
|
50 |
+
|
51 |
+
def isimage(x):
|
52 |
+
if not isinstance(x,torch.Tensor):
|
53 |
+
return False
|
54 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
55 |
+
|
56 |
+
|
57 |
+
def exists(x):
|
58 |
+
return x is not None
|
59 |
+
|
60 |
+
|
61 |
+
def default(val, d):
|
62 |
+
if exists(val):
|
63 |
+
return val
|
64 |
+
return d() if isfunction(d) else d
|
65 |
+
|
66 |
+
|
67 |
+
def mean_flat(tensor):
|
68 |
+
"""
|
69 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
70 |
+
Take the mean over all non-batch dimensions.
|
71 |
+
"""
|
72 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
73 |
+
|
74 |
+
|
75 |
+
def count_params(model, verbose=False):
|
76 |
+
total_params = sum(p.numel() for p in model.parameters())
|
77 |
+
if verbose:
|
78 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
79 |
+
return total_params
|
80 |
+
|
81 |
+
|
82 |
+
def instantiate_from_config(config):
|
83 |
+
if not "target" in config:
|
84 |
+
if config == '__is_first_stage__':
|
85 |
+
return None
|
86 |
+
elif config == "__is_unconditional__":
|
87 |
+
return None
|
88 |
+
raise KeyError("Expected key `target` to instantiate.")
|
89 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
90 |
+
|
91 |
+
|
92 |
+
def get_obj_from_str(string, reload=False):
|
93 |
+
module, cls = string.rsplit(".", 1)
|
94 |
+
if reload:
|
95 |
+
module_imp = importlib.import_module(module)
|
96 |
+
importlib.reload(module_imp)
|
97 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
98 |
+
|
99 |
+
|
100 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
101 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
102 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
103 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
104 |
+
ema_power=1., param_names=()):
|
105 |
+
"""AdamW that saves EMA versions of the parameters."""
|
106 |
+
if not 0.0 <= lr:
|
107 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
108 |
+
if not 0.0 <= eps:
|
109 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
110 |
+
if not 0.0 <= betas[0] < 1.0:
|
111 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
112 |
+
if not 0.0 <= betas[1] < 1.0:
|
113 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
114 |
+
if not 0.0 <= weight_decay:
|
115 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
116 |
+
if not 0.0 <= ema_decay <= 1.0:
|
117 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
118 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
119 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
120 |
+
ema_power=ema_power, param_names=param_names)
|
121 |
+
super().__init__(params, defaults)
|
122 |
+
|
123 |
+
def __setstate__(self, state):
|
124 |
+
super().__setstate__(state)
|
125 |
+
for group in self.param_groups:
|
126 |
+
group.setdefault('amsgrad', False)
|
127 |
+
|
128 |
+
@torch.no_grad()
|
129 |
+
def step(self, closure=None):
|
130 |
+
"""Performs a single optimization step.
|
131 |
+
Args:
|
132 |
+
closure (callable, optional): A closure that reevaluates the model
|
133 |
+
and returns the loss.
|
134 |
+
"""
|
135 |
+
loss = None
|
136 |
+
if closure is not None:
|
137 |
+
with torch.enable_grad():
|
138 |
+
loss = closure()
|
139 |
+
|
140 |
+
for group in self.param_groups:
|
141 |
+
params_with_grad = []
|
142 |
+
grads = []
|
143 |
+
exp_avgs = []
|
144 |
+
exp_avg_sqs = []
|
145 |
+
ema_params_with_grad = []
|
146 |
+
state_sums = []
|
147 |
+
max_exp_avg_sqs = []
|
148 |
+
state_steps = []
|
149 |
+
amsgrad = group['amsgrad']
|
150 |
+
beta1, beta2 = group['betas']
|
151 |
+
ema_decay = group['ema_decay']
|
152 |
+
ema_power = group['ema_power']
|
153 |
+
|
154 |
+
for p in group['params']:
|
155 |
+
if p.grad is None:
|
156 |
+
continue
|
157 |
+
params_with_grad.append(p)
|
158 |
+
if p.grad.is_sparse:
|
159 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
160 |
+
grads.append(p.grad)
|
161 |
+
|
162 |
+
state = self.state[p]
|
163 |
+
|
164 |
+
# State initialization
|
165 |
+
if len(state) == 0:
|
166 |
+
state['step'] = 0
|
167 |
+
# Exponential moving average of gradient values
|
168 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
169 |
+
# Exponential moving average of squared gradient values
|
170 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
171 |
+
if amsgrad:
|
172 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
173 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
174 |
+
# Exponential moving average of parameter values
|
175 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
176 |
+
|
177 |
+
exp_avgs.append(state['exp_avg'])
|
178 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
179 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
180 |
+
|
181 |
+
if amsgrad:
|
182 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
183 |
+
|
184 |
+
# update the steps for each param group update
|
185 |
+
state['step'] += 1
|
186 |
+
# record the step after step update
|
187 |
+
state_steps.append(state['step'])
|
188 |
+
|
189 |
+
optim._functional.adamw(params_with_grad,
|
190 |
+
grads,
|
191 |
+
exp_avgs,
|
192 |
+
exp_avg_sqs,
|
193 |
+
max_exp_avg_sqs,
|
194 |
+
state_steps,
|
195 |
+
amsgrad=amsgrad,
|
196 |
+
beta1=beta1,
|
197 |
+
beta2=beta2,
|
198 |
+
lr=group['lr'],
|
199 |
+
weight_decay=group['weight_decay'],
|
200 |
+
eps=group['eps'],
|
201 |
+
maximize=False)
|
202 |
+
|
203 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
204 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
205 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
206 |
+
|
207 |
+
return loss
|
repositories/stable-diffusion-stability-ai/modelcard.md
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stable Diffusion v2 Model Card
|
2 |
+
This model card focuses on the models associated with the Stable Diffusion v2, available [here](https://github.com/Stability-AI/stablediffusion/).
|
3 |
+
|
4 |
+
## Model Details
|
5 |
+
- **Developed by:** Robin Rombach, Patrick Esser
|
6 |
+
- **Model type:** Diffusion-based text-to-image generation model
|
7 |
+
- **Language(s):** English
|
8 |
+
- **License:** CreativeML Open RAIL++-M License
|
9 |
+
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip)).
|
10 |
+
- **Resources for more information:** [GitHub Repository](https://github.com/Stability-AI/).
|
11 |
+
- **Cite as:**
|
12 |
+
|
13 |
+
@InProceedings{Rombach_2022_CVPR,
|
14 |
+
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
15 |
+
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
16 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
17 |
+
month = {June},
|
18 |
+
year = {2022},
|
19 |
+
pages = {10684-10695}
|
20 |
+
}
|
21 |
+
|
22 |
+
# Uses
|
23 |
+
|
24 |
+
## Direct Use
|
25 |
+
The model is intended for research purposes only. Possible research areas and tasks include
|
26 |
+
|
27 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
28 |
+
- Probing and understanding the limitations and biases of generative models.
|
29 |
+
- Generation of artworks and use in design and other artistic processes.
|
30 |
+
- Applications in educational or creative tools.
|
31 |
+
- Research on generative models.
|
32 |
+
|
33 |
+
Excluded uses are described below.
|
34 |
+
|
35 |
+
### Misuse, Malicious Use, and Out-of-Scope Use
|
36 |
+
_Note: This section is originally taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), was used for Stable Diffusion v1, but applies in the same way to Stable Diffusion v2_.
|
37 |
+
|
38 |
+
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
39 |
+
|
40 |
+
#### Out-of-Scope Use
|
41 |
+
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
42 |
+
|
43 |
+
#### Misuse and Malicious Use
|
44 |
+
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
45 |
+
|
46 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
47 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
48 |
+
- Impersonating individuals without their consent.
|
49 |
+
- Sexual content without consent of the people who might see it.
|
50 |
+
- Mis- and disinformation
|
51 |
+
- Representations of egregious violence and gore
|
52 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
53 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
54 |
+
|
55 |
+
## Limitations and Bias
|
56 |
+
|
57 |
+
### Limitations
|
58 |
+
|
59 |
+
- The model does not achieve perfect photorealism
|
60 |
+
- The model cannot render legible text
|
61 |
+
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
62 |
+
- Faces and people in general may not be generated properly.
|
63 |
+
- The model was trained mainly with English captions and will not work as well in other languages.
|
64 |
+
- The autoencoding part of the model is lossy
|
65 |
+
- The model was trained on a subset of the large-scale dataset
|
66 |
+
[LAION-5B](https://laion.ai/blog/laion-5b/), which contains adult, violent and sexual content. To partially mitigate this, we have filtered the dataset using LAION's NFSW detector (see Training section).
|
67 |
+
|
68 |
+
### Bias
|
69 |
+
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
70 |
+
Stable Diffusion vw was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
71 |
+
which consists of images that are limited to English descriptions.
|
72 |
+
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
73 |
+
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
74 |
+
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
75 |
+
Stable Diffusion v2 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent.
|
76 |
+
|
77 |
+
|
78 |
+
## Training
|
79 |
+
|
80 |
+
**Training Data**
|
81 |
+
The model developers used the following dataset for training the model:
|
82 |
+
|
83 |
+
- LAION-5B and subsets (details below). The training data is further filtered using LAION's NSFW detector. For more details, please refer to LAION-5B's [NeurIPS 2022](https://openreview.net/forum?id=M3Y74vmsMcY) paper and reviewer discussions on the topic.
|
84 |
+
|
85 |
+
**Training Procedure**
|
86 |
+
Stable Diffusion v2 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
87 |
+
|
88 |
+
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
89 |
+
- Text prompts are encoded through the OpenCLIP-ViT/H text-encoder.
|
90 |
+
- The output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
91 |
+
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. We also use the so-called _v-objective_, see https://arxiv.org/abs/2202.00512.
|
92 |
+
|
93 |
+
We currently provide the following checkpoints, for various versions:
|
94 |
+
|
95 |
+
### Version 2.1
|
96 |
+
|
97 |
+
- `512-base-ema.ckpt`: Fine-tuned on `512-base-ema.ckpt` 2.0 with 220k extra steps taken, with `punsafe=0.98` on the same dataset.
|
98 |
+
- `768-v-ema.ckpt`: Resumed from `768-v-ema.ckpt` 2.0 with an additional 55k steps on the same dataset (`punsafe=0.1`), and then fine-tuned for another 155k extra steps with `punsafe=0.98`.
|
99 |
+
|
100 |
+
**SD-unCLIP 2.1** is a finetuned version of Stable Diffusion 2.1, modified to accept (noisy) CLIP image embedding in addition to the text prompt, and can be used to create image variations ([Examples](https://github.com/Stability-AI/stablediffusion/blob/main/doc/UNCLIP.MD)) or can be chained with text-to-image CLIP priors. The amount of noise added to the image embedding can be specified via the `noise_level` (0 means no noise, 1000 full noise).
|
101 |
+
|
102 |
+
If you plan on building applications on top of the model that the general public may use, you are responsible for adding the guardrails to minimize or prevent misuse of the application, especially for use-cases highlighted in the earlier section, Misuse, Malicious Use, and Out-of-Scope Use.
|
103 |
+
|
104 |
+
A public demo of SD-unCLIP is already available at [clipdrop.co/stable-diffusion-reimagine](https://clipdrop.co/stable-diffusion-reimagine)
|
105 |
+
|
106 |
+
### Version 2.0
|
107 |
+
|
108 |
+
- `512-base-ema.ckpt`: 550k steps at resolution `256x256` on a subset of [LAION-5B](https://laion.ai/blog/laion-5b/) filtered for explicit pornographic material, using the [LAION-NSFW classifier](https://github.com/LAION-AI/CLIP-based-NSFW-Detector) with `punsafe=0.1` and an [aesthetic score](https://github.com/christophschuhmann/improved-aesthetic-predictor) >= `4.5`.
|
109 |
+
850k steps at resolution `512x512` on the same dataset with resolution `>= 512x512`.
|
110 |
+
- `768-v-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for 150k steps using a [v-objective](https://arxiv.org/abs/2202.00512) on the same dataset. Resumed for another 140k steps on a `768x768` subset of our dataset.
|
111 |
+
- `512-depth-ema.ckpt`: Resumed from `512-base-ema.ckpt` and finetuned for 200k steps. Added an extra input channel to process the (relative) depth prediction produced by [MiDaS](https://github.com/isl-org/MiDaS) (`dpt_hybrid`) which is used as an additional conditioning.
|
112 |
+
The additional input channels of the U-Net which process this extra information were zero-initialized.
|
113 |
+
- `512-inpainting-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for another 200k steps. Follows the mask-generation strategy presented in [LAMA](https://github.com/saic-mdal/lama) which, in combination with the latent VAE representations of the masked image, are used as an additional conditioning.
|
114 |
+
The additional input channels of the U-Net which process this extra information were zero-initialized. The same strategy was used to train the [1.5-inpainting checkpoint](https://github.com/saic-mdal/lama).
|
115 |
+
- `x4-upscaling-ema.ckpt`: Trained for 1.25M steps on a 10M subset of LAION containing images `>2048x2048`. The model was trained on crops of size `512x512` and is a text-guided [latent upscaling diffusion model](https://arxiv.org/abs/2112.10752).
|
116 |
+
In addition to the textual input, it receives a `noise_level` as an input parameter, which can be used to add noise to the low-resolution input according to a [predefined diffusion schedule](configs/stable-diffusion/x4-upscaling.yaml).
|
117 |
+
|
118 |
+
- **Hardware:** 32 x 8 x A100 GPUs
|
119 |
+
- **Optimizer:** AdamW
|
120 |
+
- **Gradient Accumulations**: 1
|
121 |
+
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
122 |
+
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
123 |
+
|
124 |
+
## Evaluation Results
|
125 |
+
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
126 |
+
5.0, 6.0, 7.0, 8.0) and 50 steps DDIM sampling steps show the relative improvements of the checkpoints:
|
127 |
+
|
128 |
+

|
129 |
+
|
130 |
+
Evaluated using 50 DDIM steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
131 |
+
|
132 |
+
## Environmental Impact
|
133 |
+
|
134 |
+
**Stable Diffusion v1** **Estimated Emissions**
|
135 |
+
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
136 |
+
|
137 |
+
- **Hardware Type:** A100 PCIe 40GB
|
138 |
+
- **Hours used:** 200000
|
139 |
+
- **Cloud Provider:** AWS
|
140 |
+
- **Compute Region:** US-east
|
141 |
+
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 15000 kg CO2 eq.
|
142 |
+
|
143 |
+
## Citation
|
144 |
+
@InProceedings{Rombach_2022_CVPR,
|
145 |
+
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
146 |
+
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
147 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
148 |
+
month = {June},
|
149 |
+
year = {2022},
|
150 |
+
pages = {10684-10695}
|
151 |
+
}
|
152 |
+
|
153 |
+
*This model card was written by: Robin Rombach, Patrick Esser and David Ha and is based on the [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/Stable_Diffusion_v1_Model_Card.md) and [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
repositories/stable-diffusion-stability-ai/requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
albumentations==0.4.3
|
2 |
+
opencv-python
|
3 |
+
pudb==2019.2
|
4 |
+
imageio==2.9.0
|
5 |
+
imageio-ffmpeg==0.4.2
|
6 |
+
pytorch-lightning==1.4.2
|
7 |
+
torchmetrics==0.6
|
8 |
+
omegaconf==2.1.1
|
9 |
+
test-tube>=0.7.5
|
10 |
+
streamlit>=0.73.1
|
11 |
+
einops==0.3.0
|
12 |
+
transformers==4.19.2
|
13 |
+
webdataset==0.2.5
|
14 |
+
open-clip-torch==2.7.0
|
15 |
+
gradio==3.13.2
|
16 |
+
kornia==0.6
|
17 |
+
invisible-watermark>=0.1.5
|
18 |
+
streamlit-drawable-canvas==0.8.0
|
19 |
+
-e .
|
repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from einops import repeat, rearrange
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from imwatermark import WatermarkEncoder
|
10 |
+
|
11 |
+
from scripts.txt2img import put_watermark
|
12 |
+
from ldm.util import instantiate_from_config
|
13 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
14 |
+
from ldm.data.util import AddMiDaS
|
15 |
+
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
|
18 |
+
|
19 |
+
def initialize_model(config, ckpt):
|
20 |
+
config = OmegaConf.load(config)
|
21 |
+
model = instantiate_from_config(config.model)
|
22 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
23 |
+
|
24 |
+
device = torch.device(
|
25 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
26 |
+
model = model.to(device)
|
27 |
+
sampler = DDIMSampler(model)
|
28 |
+
return sampler
|
29 |
+
|
30 |
+
|
31 |
+
def make_batch_sd(
|
32 |
+
image,
|
33 |
+
txt,
|
34 |
+
device,
|
35 |
+
num_samples=1,
|
36 |
+
model_type="dpt_hybrid"
|
37 |
+
):
|
38 |
+
image = np.array(image.convert("RGB"))
|
39 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
40 |
+
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
41 |
+
midas_trafo = AddMiDaS(model_type=model_type)
|
42 |
+
batch = {
|
43 |
+
"jpg": image,
|
44 |
+
"txt": num_samples * [txt],
|
45 |
+
}
|
46 |
+
batch = midas_trafo(batch)
|
47 |
+
batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
|
48 |
+
batch["jpg"] = repeat(batch["jpg"].to(device=device),
|
49 |
+
"1 ... -> n ...", n=num_samples)
|
50 |
+
batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
|
51 |
+
device=device), "1 ... -> n ...", n=num_samples)
|
52 |
+
return batch
|
53 |
+
|
54 |
+
|
55 |
+
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
|
56 |
+
do_full_sample=False):
|
57 |
+
device = torch.device(
|
58 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
59 |
+
model = sampler.model
|
60 |
+
seed_everything(seed)
|
61 |
+
|
62 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
63 |
+
wm = "SDV2"
|
64 |
+
wm_encoder = WatermarkEncoder()
|
65 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
66 |
+
|
67 |
+
with torch.no_grad(),\
|
68 |
+
torch.autocast("cuda"):
|
69 |
+
batch = make_batch_sd(
|
70 |
+
image, txt=prompt, device=device, num_samples=num_samples)
|
71 |
+
z = model.get_first_stage_encoding(model.encode_first_stage(
|
72 |
+
batch[model.first_stage_key])) # move to latent space
|
73 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
74 |
+
c_cat = list()
|
75 |
+
for ck in model.concat_keys:
|
76 |
+
cc = batch[ck]
|
77 |
+
cc = model.depth_model(cc)
|
78 |
+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
79 |
+
keepdim=True)
|
80 |
+
display_depth = (cc - depth_min) / (depth_max - depth_min)
|
81 |
+
depth_image = Image.fromarray(
|
82 |
+
(display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
|
83 |
+
cc = torch.nn.functional.interpolate(
|
84 |
+
cc,
|
85 |
+
size=z.shape[2:],
|
86 |
+
mode="bicubic",
|
87 |
+
align_corners=False,
|
88 |
+
)
|
89 |
+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
90 |
+
keepdim=True)
|
91 |
+
cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
|
92 |
+
c_cat.append(cc)
|
93 |
+
c_cat = torch.cat(c_cat, dim=1)
|
94 |
+
# cond
|
95 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
96 |
+
|
97 |
+
# uncond cond
|
98 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
99 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
100 |
+
if not do_full_sample:
|
101 |
+
# encode (scaled latent)
|
102 |
+
z_enc = sampler.stochastic_encode(
|
103 |
+
z, torch.tensor([t_enc] * num_samples).to(model.device))
|
104 |
+
else:
|
105 |
+
z_enc = torch.randn_like(z)
|
106 |
+
# decode it
|
107 |
+
samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
|
108 |
+
unconditional_conditioning=uc_full, callback=callback)
|
109 |
+
x_samples_ddim = model.decode_first_stage(samples)
|
110 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
111 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
112 |
+
return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
113 |
+
|
114 |
+
|
115 |
+
def pad_image(input_image):
|
116 |
+
pad_w, pad_h = np.max(((2, 2), np.ceil(
|
117 |
+
np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
|
118 |
+
im_padded = Image.fromarray(
|
119 |
+
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
120 |
+
return im_padded
|
121 |
+
|
122 |
+
|
123 |
+
def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
|
124 |
+
init_image = input_image.convert("RGB")
|
125 |
+
image = pad_image(init_image) # resize to integer multiple of 32
|
126 |
+
|
127 |
+
sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
|
128 |
+
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
129 |
+
do_full_sample = strength == 1.
|
130 |
+
t_enc = min(int(strength * steps), steps-1)
|
131 |
+
result = paint(
|
132 |
+
sampler=sampler,
|
133 |
+
image=image,
|
134 |
+
prompt=prompt,
|
135 |
+
t_enc=t_enc,
|
136 |
+
seed=seed,
|
137 |
+
scale=scale,
|
138 |
+
num_samples=num_samples,
|
139 |
+
callback=None,
|
140 |
+
do_full_sample=do_full_sample
|
141 |
+
)
|
142 |
+
return result
|
143 |
+
|
144 |
+
|
145 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
146 |
+
|
147 |
+
block = gr.Blocks().queue()
|
148 |
+
with block:
|
149 |
+
with gr.Row():
|
150 |
+
gr.Markdown("## Stable Diffusion Depth2Img")
|
151 |
+
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column():
|
154 |
+
input_image = gr.Image(source='upload', type="pil")
|
155 |
+
prompt = gr.Textbox(label="Prompt")
|
156 |
+
run_button = gr.Button(label="Run")
|
157 |
+
with gr.Accordion("Advanced options", open=False):
|
158 |
+
num_samples = gr.Slider(
|
159 |
+
label="Images", minimum=1, maximum=4, value=1, step=1)
|
160 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1,
|
161 |
+
maximum=50, value=50, step=1)
|
162 |
+
scale = gr.Slider(
|
163 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
|
164 |
+
)
|
165 |
+
strength = gr.Slider(
|
166 |
+
label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
|
167 |
+
)
|
168 |
+
seed = gr.Slider(
|
169 |
+
label="Seed",
|
170 |
+
minimum=0,
|
171 |
+
maximum=2147483647,
|
172 |
+
step=1,
|
173 |
+
randomize=True,
|
174 |
+
)
|
175 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
176 |
+
with gr.Column():
|
177 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
178 |
+
grid=[2], height="auto")
|
179 |
+
|
180 |
+
run_button.click(fn=predict, inputs=[
|
181 |
+
input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])
|
182 |
+
|
183 |
+
|
184 |
+
block.launch()
|
repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from einops import repeat
|
9 |
+
from imwatermark import WatermarkEncoder
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.util import instantiate_from_config
|
14 |
+
|
15 |
+
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
|
18 |
+
|
19 |
+
def put_watermark(img, wm_encoder=None):
|
20 |
+
if wm_encoder is not None:
|
21 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
22 |
+
img = wm_encoder.encode(img, 'dwtDct')
|
23 |
+
img = Image.fromarray(img[:, :, ::-1])
|
24 |
+
return img
|
25 |
+
|
26 |
+
|
27 |
+
def initialize_model(config, ckpt):
|
28 |
+
config = OmegaConf.load(config)
|
29 |
+
model = instantiate_from_config(config.model)
|
30 |
+
|
31 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
32 |
+
|
33 |
+
device = torch.device(
|
34 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
35 |
+
model = model.to(device)
|
36 |
+
sampler = DDIMSampler(model)
|
37 |
+
|
38 |
+
return sampler
|
39 |
+
|
40 |
+
|
41 |
+
def make_batch_sd(
|
42 |
+
image,
|
43 |
+
mask,
|
44 |
+
txt,
|
45 |
+
device,
|
46 |
+
num_samples=1):
|
47 |
+
image = np.array(image.convert("RGB"))
|
48 |
+
image = image[None].transpose(0, 3, 1, 2)
|
49 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
50 |
+
|
51 |
+
mask = np.array(mask.convert("L"))
|
52 |
+
mask = mask.astype(np.float32) / 255.0
|
53 |
+
mask = mask[None, None]
|
54 |
+
mask[mask < 0.5] = 0
|
55 |
+
mask[mask >= 0.5] = 1
|
56 |
+
mask = torch.from_numpy(mask)
|
57 |
+
|
58 |
+
masked_image = image * (mask < 0.5)
|
59 |
+
|
60 |
+
batch = {
|
61 |
+
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
62 |
+
"txt": num_samples * [txt],
|
63 |
+
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
64 |
+
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
65 |
+
}
|
66 |
+
return batch
|
67 |
+
|
68 |
+
|
69 |
+
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
|
70 |
+
device = torch.device(
|
71 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
72 |
+
model = sampler.model
|
73 |
+
|
74 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
75 |
+
wm = "SDV2"
|
76 |
+
wm_encoder = WatermarkEncoder()
|
77 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
78 |
+
|
79 |
+
prng = np.random.RandomState(seed)
|
80 |
+
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
|
81 |
+
start_code = torch.from_numpy(start_code).to(
|
82 |
+
device=device, dtype=torch.float32)
|
83 |
+
|
84 |
+
with torch.no_grad(), \
|
85 |
+
torch.autocast("cuda"):
|
86 |
+
batch = make_batch_sd(image, mask, txt=prompt,
|
87 |
+
device=device, num_samples=num_samples)
|
88 |
+
|
89 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
90 |
+
|
91 |
+
c_cat = list()
|
92 |
+
for ck in model.concat_keys:
|
93 |
+
cc = batch[ck].float()
|
94 |
+
if ck != model.masked_image_key:
|
95 |
+
bchw = [num_samples, 4, h // 8, w // 8]
|
96 |
+
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
97 |
+
else:
|
98 |
+
cc = model.get_first_stage_encoding(
|
99 |
+
model.encode_first_stage(cc))
|
100 |
+
c_cat.append(cc)
|
101 |
+
c_cat = torch.cat(c_cat, dim=1)
|
102 |
+
|
103 |
+
# cond
|
104 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
105 |
+
|
106 |
+
# uncond cond
|
107 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
108 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
109 |
+
|
110 |
+
shape = [model.channels, h // 8, w // 8]
|
111 |
+
samples_cfg, intermediates = sampler.sample(
|
112 |
+
ddim_steps,
|
113 |
+
num_samples,
|
114 |
+
shape,
|
115 |
+
cond,
|
116 |
+
verbose=False,
|
117 |
+
eta=1.0,
|
118 |
+
unconditional_guidance_scale=scale,
|
119 |
+
unconditional_conditioning=uc_full,
|
120 |
+
x_T=start_code,
|
121 |
+
)
|
122 |
+
x_samples_ddim = model.decode_first_stage(samples_cfg)
|
123 |
+
|
124 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
|
125 |
+
min=0.0, max=1.0)
|
126 |
+
|
127 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
128 |
+
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
129 |
+
|
130 |
+
def pad_image(input_image):
|
131 |
+
pad_w, pad_h = np.max(((2, 2), np.ceil(
|
132 |
+
np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
|
133 |
+
im_padded = Image.fromarray(
|
134 |
+
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
135 |
+
return im_padded
|
136 |
+
|
137 |
+
def predict(input_image, prompt, ddim_steps, num_samples, scale, seed):
|
138 |
+
init_image = input_image["image"].convert("RGB")
|
139 |
+
init_mask = input_image["mask"].convert("RGB")
|
140 |
+
image = pad_image(init_image) # resize to integer multiple of 32
|
141 |
+
mask = pad_image(init_mask) # resize to integer multiple of 32
|
142 |
+
width, height = image.size
|
143 |
+
print("Inpainting...", width, height)
|
144 |
+
|
145 |
+
result = inpaint(
|
146 |
+
sampler=sampler,
|
147 |
+
image=image,
|
148 |
+
mask=mask,
|
149 |
+
prompt=prompt,
|
150 |
+
seed=seed,
|
151 |
+
scale=scale,
|
152 |
+
ddim_steps=ddim_steps,
|
153 |
+
num_samples=num_samples,
|
154 |
+
h=height, w=width
|
155 |
+
)
|
156 |
+
|
157 |
+
return result
|
158 |
+
|
159 |
+
|
160 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
161 |
+
|
162 |
+
block = gr.Blocks().queue()
|
163 |
+
with block:
|
164 |
+
with gr.Row():
|
165 |
+
gr.Markdown("## Stable Diffusion Inpainting")
|
166 |
+
|
167 |
+
with gr.Row():
|
168 |
+
with gr.Column():
|
169 |
+
input_image = gr.Image(source='upload', tool='sketch', type="pil")
|
170 |
+
prompt = gr.Textbox(label="Prompt")
|
171 |
+
run_button = gr.Button(label="Run")
|
172 |
+
with gr.Accordion("Advanced options", open=False):
|
173 |
+
num_samples = gr.Slider(
|
174 |
+
label="Images", minimum=1, maximum=4, value=4, step=1)
|
175 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1,
|
176 |
+
maximum=50, value=45, step=1)
|
177 |
+
scale = gr.Slider(
|
178 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
|
179 |
+
)
|
180 |
+
seed = gr.Slider(
|
181 |
+
label="Seed",
|
182 |
+
minimum=0,
|
183 |
+
maximum=2147483647,
|
184 |
+
step=1,
|
185 |
+
randomize=True,
|
186 |
+
)
|
187 |
+
with gr.Column():
|
188 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
189 |
+
grid=[2], height="auto")
|
190 |
+
|
191 |
+
run_button.click(fn=predict, inputs=[
|
192 |
+
input_image, prompt, ddim_steps, num_samples, scale, seed], outputs=[gallery])
|
193 |
+
|
194 |
+
|
195 |
+
block.launch()
|
repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from einops import repeat, rearrange
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from imwatermark import WatermarkEncoder
|
10 |
+
|
11 |
+
from scripts.txt2img import put_watermark
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
|
14 |
+
from ldm.util import exists, instantiate_from_config
|
15 |
+
|
16 |
+
|
17 |
+
torch.set_grad_enabled(False)
|
18 |
+
|
19 |
+
|
20 |
+
def initialize_model(config, ckpt):
|
21 |
+
config = OmegaConf.load(config)
|
22 |
+
model = instantiate_from_config(config.model)
|
23 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
24 |
+
|
25 |
+
device = torch.device(
|
26 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
27 |
+
model = model.to(device)
|
28 |
+
sampler = DDIMSampler(model)
|
29 |
+
return sampler
|
30 |
+
|
31 |
+
|
32 |
+
def make_batch_sd(
|
33 |
+
image,
|
34 |
+
txt,
|
35 |
+
device,
|
36 |
+
num_samples=1,
|
37 |
+
):
|
38 |
+
image = np.array(image.convert("RGB"))
|
39 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
40 |
+
batch = {
|
41 |
+
"lr": rearrange(image, 'h w c -> 1 c h w'),
|
42 |
+
"txt": num_samples * [txt],
|
43 |
+
}
|
44 |
+
batch["lr"] = repeat(batch["lr"].to(device=device),
|
45 |
+
"1 ... -> n ...", n=num_samples)
|
46 |
+
return batch
|
47 |
+
|
48 |
+
|
49 |
+
def make_noise_augmentation(model, batch, noise_level=None):
|
50 |
+
x_low = batch[model.low_scale_key]
|
51 |
+
x_low = x_low.to(memory_format=torch.contiguous_format).float()
|
52 |
+
x_aug, noise_level = model.low_scale_model(x_low, noise_level)
|
53 |
+
return x_aug, noise_level
|
54 |
+
|
55 |
+
|
56 |
+
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
|
57 |
+
device = torch.device(
|
58 |
+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
59 |
+
model = sampler.model
|
60 |
+
seed_everything(seed)
|
61 |
+
prng = np.random.RandomState(seed)
|
62 |
+
start_code = prng.randn(num_samples, model.channels, h, w)
|
63 |
+
start_code = torch.from_numpy(start_code).to(
|
64 |
+
device=device, dtype=torch.float32)
|
65 |
+
|
66 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
67 |
+
wm = "SDV2"
|
68 |
+
wm_encoder = WatermarkEncoder()
|
69 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
70 |
+
with torch.no_grad(),\
|
71 |
+
torch.autocast("cuda"):
|
72 |
+
batch = make_batch_sd(
|
73 |
+
image, txt=prompt, device=device, num_samples=num_samples)
|
74 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
75 |
+
c_cat = list()
|
76 |
+
if isinstance(model, LatentUpscaleFinetuneDiffusion):
|
77 |
+
for ck in model.concat_keys:
|
78 |
+
cc = batch[ck]
|
79 |
+
if exists(model.reshuffle_patch_size):
|
80 |
+
assert isinstance(model.reshuffle_patch_size, int)
|
81 |
+
cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
|
82 |
+
p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
|
83 |
+
c_cat.append(cc)
|
84 |
+
c_cat = torch.cat(c_cat, dim=1)
|
85 |
+
# cond
|
86 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
87 |
+
# uncond cond
|
88 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
89 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
90 |
+
elif isinstance(model, LatentUpscaleDiffusion):
|
91 |
+
x_augment, noise_level = make_noise_augmentation(
|
92 |
+
model, batch, noise_level)
|
93 |
+
cond = {"c_concat": [x_augment],
|
94 |
+
"c_crossattn": [c], "c_adm": noise_level}
|
95 |
+
# uncond cond
|
96 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
97 |
+
uc_full = {"c_concat": [x_augment], "c_crossattn": [
|
98 |
+
uc_cross], "c_adm": noise_level}
|
99 |
+
else:
|
100 |
+
raise NotImplementedError()
|
101 |
+
|
102 |
+
shape = [model.channels, h, w]
|
103 |
+
samples, intermediates = sampler.sample(
|
104 |
+
steps,
|
105 |
+
num_samples,
|
106 |
+
shape,
|
107 |
+
cond,
|
108 |
+
verbose=False,
|
109 |
+
eta=eta,
|
110 |
+
unconditional_guidance_scale=scale,
|
111 |
+
unconditional_conditioning=uc_full,
|
112 |
+
x_T=start_code,
|
113 |
+
callback=callback
|
114 |
+
)
|
115 |
+
with torch.no_grad():
|
116 |
+
x_samples_ddim = model.decode_first_stage(samples)
|
117 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
118 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
119 |
+
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
120 |
+
|
121 |
+
|
122 |
+
def pad_image(input_image):
|
123 |
+
pad_w, pad_h = np.max(((2, 2), np.ceil(
|
124 |
+
np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
|
125 |
+
im_padded = Image.fromarray(
|
126 |
+
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
127 |
+
return im_padded
|
128 |
+
|
129 |
+
|
130 |
+
def predict(input_image, prompt, steps, num_samples, scale, seed, eta, noise_level):
|
131 |
+
init_image = input_image.convert("RGB")
|
132 |
+
image = pad_image(init_image) # resize to integer multiple of 32
|
133 |
+
width, height = image.size
|
134 |
+
|
135 |
+
noise_level = torch.Tensor(
|
136 |
+
num_samples * [noise_level]).to(sampler.model.device).long()
|
137 |
+
sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
|
138 |
+
result = paint(
|
139 |
+
sampler=sampler,
|
140 |
+
image=image,
|
141 |
+
prompt=prompt,
|
142 |
+
seed=seed,
|
143 |
+
scale=scale,
|
144 |
+
h=height, w=width, steps=steps,
|
145 |
+
num_samples=num_samples,
|
146 |
+
callback=None,
|
147 |
+
noise_level=noise_level
|
148 |
+
)
|
149 |
+
return result
|
150 |
+
|
151 |
+
|
152 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
153 |
+
|
154 |
+
block = gr.Blocks().queue()
|
155 |
+
with block:
|
156 |
+
with gr.Row():
|
157 |
+
gr.Markdown("## Stable Diffusion Upscaling")
|
158 |
+
|
159 |
+
with gr.Row():
|
160 |
+
with gr.Column():
|
161 |
+
input_image = gr.Image(source='upload', type="pil")
|
162 |
+
gr.Markdown(
|
163 |
+
"Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat")
|
164 |
+
prompt = gr.Textbox(label="Prompt")
|
165 |
+
run_button = gr.Button(label="Run")
|
166 |
+
with gr.Accordion("Advanced options", open=False):
|
167 |
+
num_samples = gr.Slider(
|
168 |
+
label="Number of Samples", minimum=1, maximum=4, value=1, step=1)
|
169 |
+
steps = gr.Slider(label="DDIM Steps", minimum=2,
|
170 |
+
maximum=200, value=75, step=1)
|
171 |
+
scale = gr.Slider(
|
172 |
+
label="Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
|
173 |
+
)
|
174 |
+
seed = gr.Slider(
|
175 |
+
label="Seed",
|
176 |
+
minimum=0,
|
177 |
+
maximum=2147483647,
|
178 |
+
step=1,
|
179 |
+
randomize=True,
|
180 |
+
)
|
181 |
+
eta = gr.Number(label="eta (DDIM)",
|
182 |
+
value=0.0, min=0.0, max=1.0)
|
183 |
+
noise_level = None
|
184 |
+
if isinstance(sampler.model, LatentUpscaleDiffusion):
|
185 |
+
# TODO: make this work for all models
|
186 |
+
noise_level = gr.Number(
|
187 |
+
label="Noise Augmentation", min=0, max=350, value=20, step=1)
|
188 |
+
|
189 |
+
with gr.Column():
|
190 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
191 |
+
grid=[2], height="auto")
|
192 |
+
|
193 |
+
run_button.click(fn=predict, inputs=[
|
194 |
+
input_image, prompt, steps, num_samples, scale, seed, eta, noise_level], outputs=[gallery])
|
195 |
+
|
196 |
+
|
197 |
+
block.launch()
|
repositories/stable-diffusion-stability-ai/scripts/img2img.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""make variations of input image"""
|
2 |
+
|
3 |
+
import argparse, os
|
4 |
+
import PIL
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm, trange
|
10 |
+
from itertools import islice
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from torchvision.utils import make_grid
|
13 |
+
from torch import autocast
|
14 |
+
from contextlib import nullcontext
|
15 |
+
from pytorch_lightning import seed_everything
|
16 |
+
from imwatermark import WatermarkEncoder
|
17 |
+
|
18 |
+
|
19 |
+
from scripts.txt2img import put_watermark
|
20 |
+
from ldm.util import instantiate_from_config
|
21 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
22 |
+
|
23 |
+
|
24 |
+
def chunk(it, size):
|
25 |
+
it = iter(it)
|
26 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
27 |
+
|
28 |
+
|
29 |
+
def load_model_from_config(config, ckpt, verbose=False):
|
30 |
+
print(f"Loading model from {ckpt}")
|
31 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
32 |
+
if "global_step" in pl_sd:
|
33 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
34 |
+
sd = pl_sd["state_dict"]
|
35 |
+
model = instantiate_from_config(config.model)
|
36 |
+
m, u = model.load_state_dict(sd, strict=False)
|
37 |
+
if len(m) > 0 and verbose:
|
38 |
+
print("missing keys:")
|
39 |
+
print(m)
|
40 |
+
if len(u) > 0 and verbose:
|
41 |
+
print("unexpected keys:")
|
42 |
+
print(u)
|
43 |
+
|
44 |
+
model.cuda()
|
45 |
+
model.eval()
|
46 |
+
return model
|
47 |
+
|
48 |
+
|
49 |
+
def load_img(path):
|
50 |
+
image = Image.open(path).convert("RGB")
|
51 |
+
w, h = image.size
|
52 |
+
print(f"loaded input image of size ({w}, {h}) from {path}")
|
53 |
+
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
54 |
+
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
55 |
+
image = np.array(image).astype(np.float32) / 255.0
|
56 |
+
image = image[None].transpose(0, 3, 1, 2)
|
57 |
+
image = torch.from_numpy(image)
|
58 |
+
return 2. * image - 1.
|
59 |
+
|
60 |
+
|
61 |
+
def main():
|
62 |
+
parser = argparse.ArgumentParser()
|
63 |
+
|
64 |
+
parser.add_argument(
|
65 |
+
"--prompt",
|
66 |
+
type=str,
|
67 |
+
nargs="?",
|
68 |
+
default="a painting of a virus monster playing guitar",
|
69 |
+
help="the prompt to render"
|
70 |
+
)
|
71 |
+
|
72 |
+
parser.add_argument(
|
73 |
+
"--init-img",
|
74 |
+
type=str,
|
75 |
+
nargs="?",
|
76 |
+
help="path to the input image"
|
77 |
+
)
|
78 |
+
|
79 |
+
parser.add_argument(
|
80 |
+
"--outdir",
|
81 |
+
type=str,
|
82 |
+
nargs="?",
|
83 |
+
help="dir to write results to",
|
84 |
+
default="outputs/img2img-samples"
|
85 |
+
)
|
86 |
+
|
87 |
+
parser.add_argument(
|
88 |
+
"--ddim_steps",
|
89 |
+
type=int,
|
90 |
+
default=50,
|
91 |
+
help="number of ddim sampling steps",
|
92 |
+
)
|
93 |
+
|
94 |
+
parser.add_argument(
|
95 |
+
"--fixed_code",
|
96 |
+
action='store_true',
|
97 |
+
help="if enabled, uses the same starting code across all samples ",
|
98 |
+
)
|
99 |
+
|
100 |
+
parser.add_argument(
|
101 |
+
"--ddim_eta",
|
102 |
+
type=float,
|
103 |
+
default=0.0,
|
104 |
+
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--n_iter",
|
108 |
+
type=int,
|
109 |
+
default=1,
|
110 |
+
help="sample this often",
|
111 |
+
)
|
112 |
+
|
113 |
+
parser.add_argument(
|
114 |
+
"--C",
|
115 |
+
type=int,
|
116 |
+
default=4,
|
117 |
+
help="latent channels",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--f",
|
121 |
+
type=int,
|
122 |
+
default=8,
|
123 |
+
help="downsampling factor, most often 8 or 16",
|
124 |
+
)
|
125 |
+
|
126 |
+
parser.add_argument(
|
127 |
+
"--n_samples",
|
128 |
+
type=int,
|
129 |
+
default=2,
|
130 |
+
help="how many samples to produce for each given prompt. A.k.a batch size",
|
131 |
+
)
|
132 |
+
|
133 |
+
parser.add_argument(
|
134 |
+
"--n_rows",
|
135 |
+
type=int,
|
136 |
+
default=0,
|
137 |
+
help="rows in the grid (default: n_samples)",
|
138 |
+
)
|
139 |
+
|
140 |
+
parser.add_argument(
|
141 |
+
"--scale",
|
142 |
+
type=float,
|
143 |
+
default=9.0,
|
144 |
+
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--strength",
|
149 |
+
type=float,
|
150 |
+
default=0.8,
|
151 |
+
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--from-file",
|
156 |
+
type=str,
|
157 |
+
help="if specified, load prompts from this file",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--config",
|
161 |
+
type=str,
|
162 |
+
default="configs/stable-diffusion/v2-inference.yaml",
|
163 |
+
help="path to config which constructs model",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--ckpt",
|
167 |
+
type=str,
|
168 |
+
help="path to checkpoint of model",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--seed",
|
172 |
+
type=int,
|
173 |
+
default=42,
|
174 |
+
help="the seed (for reproducible sampling)",
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--precision",
|
178 |
+
type=str,
|
179 |
+
help="evaluate at this precision",
|
180 |
+
choices=["full", "autocast"],
|
181 |
+
default="autocast"
|
182 |
+
)
|
183 |
+
|
184 |
+
opt = parser.parse_args()
|
185 |
+
seed_everything(opt.seed)
|
186 |
+
|
187 |
+
config = OmegaConf.load(f"{opt.config}")
|
188 |
+
model = load_model_from_config(config, f"{opt.ckpt}")
|
189 |
+
|
190 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
191 |
+
model = model.to(device)
|
192 |
+
|
193 |
+
sampler = DDIMSampler(model)
|
194 |
+
|
195 |
+
os.makedirs(opt.outdir, exist_ok=True)
|
196 |
+
outpath = opt.outdir
|
197 |
+
|
198 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
199 |
+
wm = "SDV2"
|
200 |
+
wm_encoder = WatermarkEncoder()
|
201 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
202 |
+
|
203 |
+
batch_size = opt.n_samples
|
204 |
+
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
205 |
+
if not opt.from_file:
|
206 |
+
prompt = opt.prompt
|
207 |
+
assert prompt is not None
|
208 |
+
data = [batch_size * [prompt]]
|
209 |
+
|
210 |
+
else:
|
211 |
+
print(f"reading prompts from {opt.from_file}")
|
212 |
+
with open(opt.from_file, "r") as f:
|
213 |
+
data = f.read().splitlines()
|
214 |
+
data = list(chunk(data, batch_size))
|
215 |
+
|
216 |
+
sample_path = os.path.join(outpath, "samples")
|
217 |
+
os.makedirs(sample_path, exist_ok=True)
|
218 |
+
base_count = len(os.listdir(sample_path))
|
219 |
+
grid_count = len(os.listdir(outpath)) - 1
|
220 |
+
|
221 |
+
assert os.path.isfile(opt.init_img)
|
222 |
+
init_image = load_img(opt.init_img).to(device)
|
223 |
+
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
224 |
+
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
225 |
+
|
226 |
+
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
227 |
+
|
228 |
+
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
229 |
+
t_enc = int(opt.strength * opt.ddim_steps)
|
230 |
+
print(f"target t_enc is {t_enc} steps")
|
231 |
+
|
232 |
+
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
233 |
+
with torch.no_grad():
|
234 |
+
with precision_scope("cuda"):
|
235 |
+
with model.ema_scope():
|
236 |
+
all_samples = list()
|
237 |
+
for n in trange(opt.n_iter, desc="Sampling"):
|
238 |
+
for prompts in tqdm(data, desc="data"):
|
239 |
+
uc = None
|
240 |
+
if opt.scale != 1.0:
|
241 |
+
uc = model.get_learned_conditioning(batch_size * [""])
|
242 |
+
if isinstance(prompts, tuple):
|
243 |
+
prompts = list(prompts)
|
244 |
+
c = model.get_learned_conditioning(prompts)
|
245 |
+
|
246 |
+
# encode (scaled latent)
|
247 |
+
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
|
248 |
+
# decode it
|
249 |
+
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
|
250 |
+
unconditional_conditioning=uc, )
|
251 |
+
|
252 |
+
x_samples = model.decode_first_stage(samples)
|
253 |
+
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
254 |
+
|
255 |
+
for x_sample in x_samples:
|
256 |
+
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
257 |
+
img = Image.fromarray(x_sample.astype(np.uint8))
|
258 |
+
img = put_watermark(img, wm_encoder)
|
259 |
+
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
|
260 |
+
base_count += 1
|
261 |
+
all_samples.append(x_samples)
|
262 |
+
|
263 |
+
# additionally, save as grid
|
264 |
+
grid = torch.stack(all_samples, 0)
|
265 |
+
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
266 |
+
grid = make_grid(grid, nrow=n_rows)
|
267 |
+
|
268 |
+
# to image
|
269 |
+
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
270 |
+
grid = Image.fromarray(grid.astype(np.uint8))
|
271 |
+
grid = put_watermark(grid, wm_encoder)
|
272 |
+
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
273 |
+
grid_count += 1
|
274 |
+
|
275 |
+
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
276 |
+
|
277 |
+
|
278 |
+
if __name__ == "__main__":
|
279 |
+
main()
|
repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from einops import repeat, rearrange
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from imwatermark import WatermarkEncoder
|
10 |
+
|
11 |
+
from scripts.txt2img import put_watermark
|
12 |
+
from ldm.util import instantiate_from_config
|
13 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
14 |
+
from ldm.data.util import AddMiDaS
|
15 |
+
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
|
18 |
+
|
19 |
+
@st.cache(allow_output_mutation=True)
|
20 |
+
def initialize_model(config, ckpt):
|
21 |
+
config = OmegaConf.load(config)
|
22 |
+
model = instantiate_from_config(config.model)
|
23 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
24 |
+
|
25 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
26 |
+
model = model.to(device)
|
27 |
+
sampler = DDIMSampler(model)
|
28 |
+
return sampler
|
29 |
+
|
30 |
+
|
31 |
+
def make_batch_sd(
|
32 |
+
image,
|
33 |
+
txt,
|
34 |
+
device,
|
35 |
+
num_samples=1,
|
36 |
+
model_type="dpt_hybrid"
|
37 |
+
):
|
38 |
+
image = np.array(image.convert("RGB"))
|
39 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
40 |
+
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
41 |
+
midas_trafo = AddMiDaS(model_type=model_type)
|
42 |
+
batch = {
|
43 |
+
"jpg": image,
|
44 |
+
"txt": num_samples * [txt],
|
45 |
+
}
|
46 |
+
batch = midas_trafo(batch)
|
47 |
+
batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
|
48 |
+
batch["jpg"] = repeat(batch["jpg"].to(device=device), "1 ... -> n ...", n=num_samples)
|
49 |
+
batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(device=device), "1 ... -> n ...", n=num_samples)
|
50 |
+
return batch
|
51 |
+
|
52 |
+
|
53 |
+
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
|
54 |
+
do_full_sample=False):
|
55 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
56 |
+
model = sampler.model
|
57 |
+
seed_everything(seed)
|
58 |
+
|
59 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
60 |
+
wm = "SDV2"
|
61 |
+
wm_encoder = WatermarkEncoder()
|
62 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
63 |
+
|
64 |
+
with torch.no_grad(),\
|
65 |
+
torch.autocast("cuda"):
|
66 |
+
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
67 |
+
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
|
68 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
69 |
+
c_cat = list()
|
70 |
+
for ck in model.concat_keys:
|
71 |
+
cc = batch[ck]
|
72 |
+
cc = model.depth_model(cc)
|
73 |
+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
74 |
+
keepdim=True)
|
75 |
+
display_depth = (cc - depth_min) / (depth_max - depth_min)
|
76 |
+
st.image(Image.fromarray((display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8)))
|
77 |
+
cc = torch.nn.functional.interpolate(
|
78 |
+
cc,
|
79 |
+
size=z.shape[2:],
|
80 |
+
mode="bicubic",
|
81 |
+
align_corners=False,
|
82 |
+
)
|
83 |
+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
84 |
+
keepdim=True)
|
85 |
+
cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
|
86 |
+
c_cat.append(cc)
|
87 |
+
c_cat = torch.cat(c_cat, dim=1)
|
88 |
+
# cond
|
89 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
90 |
+
|
91 |
+
# uncond cond
|
92 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
93 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
94 |
+
if not do_full_sample:
|
95 |
+
# encode (scaled latent)
|
96 |
+
z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
|
97 |
+
else:
|
98 |
+
z_enc = torch.randn_like(z)
|
99 |
+
# decode it
|
100 |
+
samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
|
101 |
+
unconditional_conditioning=uc_full, callback=callback)
|
102 |
+
x_samples_ddim = model.decode_first_stage(samples)
|
103 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
104 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
105 |
+
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
106 |
+
|
107 |
+
|
108 |
+
def run():
|
109 |
+
st.title("Stable Diffusion Depth2Img")
|
110 |
+
# run via streamlit run scripts/demo/depth2img.py <path-tp-config> <path-to-ckpt>
|
111 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
112 |
+
|
113 |
+
image = st.file_uploader("Image", ["jpg", "png"])
|
114 |
+
if image:
|
115 |
+
image = Image.open(image)
|
116 |
+
w, h = image.size
|
117 |
+
st.text(f"loaded input image of size ({w}, {h})")
|
118 |
+
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
119 |
+
image = image.resize((width, height))
|
120 |
+
st.text(f"resized input image to size ({width}, {height} (w, h))")
|
121 |
+
st.image(image)
|
122 |
+
|
123 |
+
prompt = st.text_input("Prompt")
|
124 |
+
|
125 |
+
seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
|
126 |
+
num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
|
127 |
+
scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
|
128 |
+
steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
|
129 |
+
strength = st.slider("Strength", min_value=0., max_value=1., value=0.9)
|
130 |
+
|
131 |
+
t_progress = st.progress(0)
|
132 |
+
def t_callback(t):
|
133 |
+
t_progress.progress(min((t + 1) / t_enc, 1.))
|
134 |
+
|
135 |
+
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
136 |
+
do_full_sample = strength == 1.
|
137 |
+
t_enc = min(int(strength * steps), steps-1)
|
138 |
+
sampler.make_schedule(steps, ddim_eta=0., verbose=True)
|
139 |
+
if st.button("Sample"):
|
140 |
+
result = paint(
|
141 |
+
sampler=sampler,
|
142 |
+
image=image,
|
143 |
+
prompt=prompt,
|
144 |
+
t_enc=t_enc,
|
145 |
+
seed=seed,
|
146 |
+
scale=scale,
|
147 |
+
num_samples=num_samples,
|
148 |
+
callback=t_callback,
|
149 |
+
do_full_sample=do_full_sample,
|
150 |
+
)
|
151 |
+
st.write("Result")
|
152 |
+
for image in result:
|
153 |
+
st.image(image, output_format='PNG')
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
run()
|
repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import streamlit as st
|
6 |
+
from PIL import Image
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from einops import repeat
|
9 |
+
from streamlit_drawable_canvas import st_canvas
|
10 |
+
from imwatermark import WatermarkEncoder
|
11 |
+
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.util import instantiate_from_config
|
14 |
+
|
15 |
+
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
|
18 |
+
|
19 |
+
def put_watermark(img, wm_encoder=None):
|
20 |
+
if wm_encoder is not None:
|
21 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
22 |
+
img = wm_encoder.encode(img, 'dwtDct')
|
23 |
+
img = Image.fromarray(img[:, :, ::-1])
|
24 |
+
return img
|
25 |
+
|
26 |
+
|
27 |
+
@st.cache(allow_output_mutation=True)
|
28 |
+
def initialize_model(config, ckpt):
|
29 |
+
config = OmegaConf.load(config)
|
30 |
+
model = instantiate_from_config(config.model)
|
31 |
+
|
32 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
33 |
+
|
34 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
35 |
+
model = model.to(device)
|
36 |
+
sampler = DDIMSampler(model)
|
37 |
+
|
38 |
+
return sampler
|
39 |
+
|
40 |
+
|
41 |
+
def make_batch_sd(
|
42 |
+
image,
|
43 |
+
mask,
|
44 |
+
txt,
|
45 |
+
device,
|
46 |
+
num_samples=1):
|
47 |
+
image = np.array(image.convert("RGB"))
|
48 |
+
image = image[None].transpose(0, 3, 1, 2)
|
49 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
50 |
+
|
51 |
+
mask = np.array(mask.convert("L"))
|
52 |
+
mask = mask.astype(np.float32) / 255.0
|
53 |
+
mask = mask[None, None]
|
54 |
+
mask[mask < 0.5] = 0
|
55 |
+
mask[mask >= 0.5] = 1
|
56 |
+
mask = torch.from_numpy(mask)
|
57 |
+
|
58 |
+
masked_image = image * (mask < 0.5)
|
59 |
+
|
60 |
+
batch = {
|
61 |
+
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
62 |
+
"txt": num_samples * [txt],
|
63 |
+
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
64 |
+
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
65 |
+
}
|
66 |
+
return batch
|
67 |
+
|
68 |
+
|
69 |
+
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
|
70 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
71 |
+
model = sampler.model
|
72 |
+
|
73 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
74 |
+
wm = "SDV2"
|
75 |
+
wm_encoder = WatermarkEncoder()
|
76 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
77 |
+
|
78 |
+
prng = np.random.RandomState(seed)
|
79 |
+
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
|
80 |
+
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
|
81 |
+
|
82 |
+
with torch.no_grad(), \
|
83 |
+
torch.autocast("cuda"):
|
84 |
+
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
|
85 |
+
|
86 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
87 |
+
|
88 |
+
c_cat = list()
|
89 |
+
for ck in model.concat_keys:
|
90 |
+
cc = batch[ck].float()
|
91 |
+
if ck != model.masked_image_key:
|
92 |
+
bchw = [num_samples, 4, h // 8, w // 8]
|
93 |
+
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
94 |
+
else:
|
95 |
+
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
96 |
+
c_cat.append(cc)
|
97 |
+
c_cat = torch.cat(c_cat, dim=1)
|
98 |
+
|
99 |
+
# cond
|
100 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
101 |
+
|
102 |
+
# uncond cond
|
103 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
104 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
105 |
+
|
106 |
+
shape = [model.channels, h // 8, w // 8]
|
107 |
+
samples_cfg, intermediates = sampler.sample(
|
108 |
+
ddim_steps,
|
109 |
+
num_samples,
|
110 |
+
shape,
|
111 |
+
cond,
|
112 |
+
verbose=False,
|
113 |
+
eta=eta,
|
114 |
+
unconditional_guidance_scale=scale,
|
115 |
+
unconditional_conditioning=uc_full,
|
116 |
+
x_T=start_code,
|
117 |
+
)
|
118 |
+
x_samples_ddim = model.decode_first_stage(samples_cfg)
|
119 |
+
|
120 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
|
121 |
+
min=0.0, max=1.0)
|
122 |
+
|
123 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
124 |
+
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
125 |
+
|
126 |
+
|
127 |
+
def run():
|
128 |
+
st.title("Stable Diffusion Inpainting")
|
129 |
+
|
130 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
131 |
+
|
132 |
+
image = st.file_uploader("Image", ["jpg", "png"])
|
133 |
+
if image:
|
134 |
+
image = Image.open(image)
|
135 |
+
w, h = image.size
|
136 |
+
print(f"loaded input image of size ({w}, {h})")
|
137 |
+
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
|
138 |
+
image = image.resize((width, height))
|
139 |
+
|
140 |
+
prompt = st.text_input("Prompt")
|
141 |
+
|
142 |
+
seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
|
143 |
+
num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
|
144 |
+
scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=10., step=0.1)
|
145 |
+
ddim_steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
|
146 |
+
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
147 |
+
|
148 |
+
fill_color = "rgba(255, 255, 255, 0.0)"
|
149 |
+
stroke_width = st.number_input("Brush Size",
|
150 |
+
value=64,
|
151 |
+
min_value=1,
|
152 |
+
max_value=100)
|
153 |
+
stroke_color = "rgba(255, 255, 255, 1.0)"
|
154 |
+
bg_color = "rgba(0, 0, 0, 1.0)"
|
155 |
+
drawing_mode = "freedraw"
|
156 |
+
|
157 |
+
st.write("Canvas")
|
158 |
+
st.caption(
|
159 |
+
"Draw a mask to inpaint, then click the 'Send to Streamlit' button (bottom left, with an arrow on it).")
|
160 |
+
canvas_result = st_canvas(
|
161 |
+
fill_color=fill_color,
|
162 |
+
stroke_width=stroke_width,
|
163 |
+
stroke_color=stroke_color,
|
164 |
+
background_color=bg_color,
|
165 |
+
background_image=image,
|
166 |
+
update_streamlit=False,
|
167 |
+
height=height,
|
168 |
+
width=width,
|
169 |
+
drawing_mode=drawing_mode,
|
170 |
+
key="canvas",
|
171 |
+
)
|
172 |
+
if canvas_result:
|
173 |
+
mask = canvas_result.image_data
|
174 |
+
mask = mask[:, :, -1] > 0
|
175 |
+
if mask.sum() > 0:
|
176 |
+
mask = Image.fromarray(mask)
|
177 |
+
|
178 |
+
result = inpaint(
|
179 |
+
sampler=sampler,
|
180 |
+
image=image,
|
181 |
+
mask=mask,
|
182 |
+
prompt=prompt,
|
183 |
+
seed=seed,
|
184 |
+
scale=scale,
|
185 |
+
ddim_steps=ddim_steps,
|
186 |
+
num_samples=num_samples,
|
187 |
+
h=height, w=width, eta=eta
|
188 |
+
)
|
189 |
+
st.write("Inpainted")
|
190 |
+
for image in result:
|
191 |
+
st.image(image, output_format='PNG')
|
192 |
+
|
193 |
+
|
194 |
+
if __name__ == "__main__":
|
195 |
+
run()
|
repositories/stable-diffusion-stability-ai/scripts/streamlit/stableunclip.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import trange
|
10 |
+
import io, os
|
11 |
+
from torch import autocast
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from torchvision.utils import make_grid
|
14 |
+
from pytorch_lightning import seed_everything
|
15 |
+
from contextlib import nullcontext
|
16 |
+
|
17 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
18 |
+
from ldm.models.diffusion.plms import PLMSSampler
|
19 |
+
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
20 |
+
|
21 |
+
torch.set_grad_enabled(False)
|
22 |
+
|
23 |
+
PROMPTS_ROOT = "scripts/prompts/"
|
24 |
+
SAVE_PATH = "outputs/demo/stable-unclip/"
|
25 |
+
|
26 |
+
VERSION2SPECS = {
|
27 |
+
"Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
|
28 |
+
"Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
|
29 |
+
"Full Karlo": {}
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def get_obj_from_str(string, reload=False):
|
34 |
+
module, cls = string.rsplit(".", 1)
|
35 |
+
importlib.invalidate_caches()
|
36 |
+
if reload:
|
37 |
+
module_imp = importlib.import_module(module)
|
38 |
+
importlib.reload(module_imp)
|
39 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
40 |
+
|
41 |
+
|
42 |
+
def instantiate_from_config(config):
|
43 |
+
if not "target" in config:
|
44 |
+
raise KeyError("Expected key `target` to instantiate.")
|
45 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
46 |
+
|
47 |
+
|
48 |
+
def get_interactive_image(key=None):
|
49 |
+
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
50 |
+
if image is not None:
|
51 |
+
image = Image.open(image)
|
52 |
+
if not image.mode == "RGB":
|
53 |
+
image = image.convert("RGB")
|
54 |
+
return image
|
55 |
+
|
56 |
+
|
57 |
+
def load_img(display=True, key=None):
|
58 |
+
image = get_interactive_image(key=key)
|
59 |
+
if display:
|
60 |
+
st.image(image)
|
61 |
+
w, h = image.size
|
62 |
+
print(f"loaded input image of size ({w}, {h})")
|
63 |
+
w, h = map(lambda x: x - x % 64, (w, h))
|
64 |
+
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
65 |
+
image = np.array(image).astype(np.float32) / 255.0
|
66 |
+
image = image[None].transpose(0, 3, 1, 2)
|
67 |
+
image = torch.from_numpy(image)
|
68 |
+
return 2. * image - 1.
|
69 |
+
|
70 |
+
|
71 |
+
def get_init_img(batch_size=1, key=None):
|
72 |
+
init_image = load_img(key=key).cuda()
|
73 |
+
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
74 |
+
return init_image
|
75 |
+
|
76 |
+
|
77 |
+
def sample(
|
78 |
+
model,
|
79 |
+
prompt,
|
80 |
+
n_runs=3,
|
81 |
+
n_samples=2,
|
82 |
+
H=512,
|
83 |
+
W=512,
|
84 |
+
C=4,
|
85 |
+
f=8,
|
86 |
+
scale=10.0,
|
87 |
+
ddim_steps=50,
|
88 |
+
ddim_eta=0.0,
|
89 |
+
callback=None,
|
90 |
+
skip_single_save=False,
|
91 |
+
save_grid=True,
|
92 |
+
ucg_schedule=None,
|
93 |
+
negative_prompt="",
|
94 |
+
adm_cond=None,
|
95 |
+
adm_uc=None,
|
96 |
+
use_full_precision=False,
|
97 |
+
only_adm_cond=False
|
98 |
+
):
|
99 |
+
batch_size = n_samples
|
100 |
+
precision_scope = autocast if not use_full_precision else nullcontext
|
101 |
+
# decoderscope = autocast if not use_full_precision else nullcontext
|
102 |
+
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
103 |
+
if isinstance(prompt, str):
|
104 |
+
prompt = [prompt]
|
105 |
+
prompts = batch_size * prompt
|
106 |
+
|
107 |
+
outputs = st.empty()
|
108 |
+
|
109 |
+
with precision_scope("cuda"):
|
110 |
+
with model.ema_scope():
|
111 |
+
all_samples = list()
|
112 |
+
for n in trange(n_runs, desc="Sampling"):
|
113 |
+
shape = [C, H // f, W // f]
|
114 |
+
if not only_adm_cond:
|
115 |
+
uc = None
|
116 |
+
if scale != 1.0:
|
117 |
+
uc = model.get_learned_conditioning(batch_size * [negative_prompt])
|
118 |
+
if isinstance(prompts, tuple):
|
119 |
+
prompts = list(prompts)
|
120 |
+
c = model.get_learned_conditioning(prompts)
|
121 |
+
|
122 |
+
if adm_cond is not None:
|
123 |
+
if adm_cond.shape[0] == 1:
|
124 |
+
adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
|
125 |
+
if adm_uc is None:
|
126 |
+
st.warning("Not guiding via c_adm")
|
127 |
+
adm_uc = adm_cond
|
128 |
+
else:
|
129 |
+
if adm_uc.shape[0] == 1:
|
130 |
+
adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
|
131 |
+
if not only_adm_cond:
|
132 |
+
c = {"c_crossattn": [c], "c_adm": adm_cond}
|
133 |
+
uc = {"c_crossattn": [uc], "c_adm": adm_uc}
|
134 |
+
else:
|
135 |
+
c = adm_cond
|
136 |
+
uc = adm_uc
|
137 |
+
samples_ddim, _ = sampler.sample(S=ddim_steps,
|
138 |
+
conditioning=c,
|
139 |
+
batch_size=batch_size,
|
140 |
+
shape=shape,
|
141 |
+
verbose=False,
|
142 |
+
unconditional_guidance_scale=scale,
|
143 |
+
unconditional_conditioning=uc,
|
144 |
+
eta=ddim_eta,
|
145 |
+
x_T=None,
|
146 |
+
callback=callback,
|
147 |
+
ucg_schedule=ucg_schedule
|
148 |
+
)
|
149 |
+
x_samples = model.decode_first_stage(samples_ddim)
|
150 |
+
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
151 |
+
|
152 |
+
if not skip_single_save:
|
153 |
+
base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
|
154 |
+
for x_sample in x_samples:
|
155 |
+
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
156 |
+
Image.fromarray(x_sample.astype(np.uint8)).save(
|
157 |
+
os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
|
158 |
+
base_count += 1
|
159 |
+
|
160 |
+
all_samples.append(x_samples)
|
161 |
+
|
162 |
+
# get grid of all samples
|
163 |
+
grid = torch.stack(all_samples, 0)
|
164 |
+
grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
|
165 |
+
outputs.image(grid.cpu().numpy())
|
166 |
+
|
167 |
+
# additionally, save grid
|
168 |
+
grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
|
169 |
+
if save_grid:
|
170 |
+
grid_count = len(os.listdir(SAVE_PATH)) - 1
|
171 |
+
grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))
|
172 |
+
|
173 |
+
return x_samples
|
174 |
+
|
175 |
+
|
176 |
+
def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
|
177 |
+
schedule = list()
|
178 |
+
for i in range(num_steps):
|
179 |
+
if float(i / num_steps) < 0.1:
|
180 |
+
schedule.append(max_weight)
|
181 |
+
elif i % 2 == 0:
|
182 |
+
schedule.append(min_weight)
|
183 |
+
else:
|
184 |
+
schedule.append(max_weight)
|
185 |
+
print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
|
186 |
+
return schedule
|
187 |
+
|
188 |
+
|
189 |
+
def torch2np(x):
|
190 |
+
x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
|
191 |
+
x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
|
196 |
+
def init(version="Stable unCLIP-L", load_karlo_prior=False):
|
197 |
+
state = dict()
|
198 |
+
if not "model" in state:
|
199 |
+
if version == "Stable unCLIP-L":
|
200 |
+
config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
|
201 |
+
ckpt = "checkpoints/sd21-unclip-l.ckpt"
|
202 |
+
|
203 |
+
elif version == "Stable unOpenCLIP-H":
|
204 |
+
config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
|
205 |
+
ckpt = "checkpoints/sd21-unclip-h.ckpt"
|
206 |
+
|
207 |
+
elif version == "Full Karlo":
|
208 |
+
from ldm.modules.karlo.kakao.sampler import T2ISampler
|
209 |
+
st.info("Loading full KARLO..")
|
210 |
+
karlo = T2ISampler.from_pretrained(
|
211 |
+
root_dir="checkpoints/karlo_models",
|
212 |
+
clip_model_path="ViT-L-14.pt",
|
213 |
+
clip_stat_path="ViT-L-14_stats.th",
|
214 |
+
sampling_type="default",
|
215 |
+
)
|
216 |
+
state["karlo_prior"] = karlo
|
217 |
+
state["msg"] = "loaded full Karlo"
|
218 |
+
return state
|
219 |
+
else:
|
220 |
+
raise ValueError(f"version {version} unknown!")
|
221 |
+
|
222 |
+
config = OmegaConf.load(config)
|
223 |
+
model, msg = load_model_from_config(config, ckpt, vae_sd=None)
|
224 |
+
state["msg"] = msg
|
225 |
+
|
226 |
+
if load_karlo_prior:
|
227 |
+
from ldm.modules.karlo.kakao.sampler import PriorSampler
|
228 |
+
st.info("Loading KARLO CLIP prior...")
|
229 |
+
karlo_prior = PriorSampler.from_pretrained(
|
230 |
+
root_dir="checkpoints/karlo_models",
|
231 |
+
clip_model_path="ViT-L-14.pt",
|
232 |
+
clip_stat_path="ViT-L-14_stats.th",
|
233 |
+
sampling_type="default",
|
234 |
+
)
|
235 |
+
state["karlo_prior"] = karlo_prior
|
236 |
+
state["model"] = model
|
237 |
+
state["ckpt"] = ckpt
|
238 |
+
state["config"] = config
|
239 |
+
return state
|
240 |
+
|
241 |
+
|
242 |
+
def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
|
243 |
+
print(f"Loading model from {ckpt}")
|
244 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
245 |
+
msg = None
|
246 |
+
if "global_step" in pl_sd:
|
247 |
+
msg = f"This is global step {pl_sd['global_step']}. "
|
248 |
+
if "model_ema.num_updates" in pl_sd["state_dict"]:
|
249 |
+
msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
|
250 |
+
global_step = pl_sd.get("global_step", "?")
|
251 |
+
sd = pl_sd["state_dict"]
|
252 |
+
if vae_sd is not None:
|
253 |
+
for k in sd.keys():
|
254 |
+
if "first_stage" in k:
|
255 |
+
sd[k] = vae_sd[k[len("first_stage_model."):]]
|
256 |
+
|
257 |
+
model = instantiate_from_config(config.model)
|
258 |
+
m, u = model.load_state_dict(sd, strict=False)
|
259 |
+
if len(m) > 0 and verbose:
|
260 |
+
print("missing keys:")
|
261 |
+
print(m)
|
262 |
+
if len(u) > 0 and verbose:
|
263 |
+
print("unexpected keys:")
|
264 |
+
print(u)
|
265 |
+
|
266 |
+
model.cuda()
|
267 |
+
model.eval()
|
268 |
+
print(f"Loaded global step {global_step}")
|
269 |
+
return model, msg
|
270 |
+
|
271 |
+
|
272 |
+
if __name__ == "__main__":
|
273 |
+
st.title("Stable unCLIP")
|
274 |
+
mode = "txt2img"
|
275 |
+
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
276 |
+
use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
277 |
+
state = init(version=version, load_karlo_prior=use_karlo_prior)
|
278 |
+
prompt = st.text_input("Prompt", "a professional photograph")
|
279 |
+
negative_prompt = st.text_input("Negative Prompt", "")
|
280 |
+
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
281 |
+
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
282 |
+
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
283 |
+
steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
|
284 |
+
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
285 |
+
force_full_precision = st.sidebar.checkbox("Force FP32", False) # TODO: check if/where things break.
|
286 |
+
if version != "Full Karlo":
|
287 |
+
H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
|
288 |
+
W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
|
289 |
+
C = VERSION2SPECS[version]["C"]
|
290 |
+
f = VERSION2SPECS[version]["f"]
|
291 |
+
|
292 |
+
SAVE_PATH = os.path.join(SAVE_PATH, version)
|
293 |
+
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
|
294 |
+
|
295 |
+
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
296 |
+
seed_everything(seed)
|
297 |
+
|
298 |
+
ucg_schedule = None
|
299 |
+
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
|
300 |
+
if version == "Full Karlo":
|
301 |
+
pass
|
302 |
+
else:
|
303 |
+
if sampler == "DPM":
|
304 |
+
sampler = DPMSolverSampler(state["model"])
|
305 |
+
elif sampler == "DDIM":
|
306 |
+
sampler = DDIMSampler(state["model"])
|
307 |
+
else:
|
308 |
+
raise ValueError(f"unknown sampler {sampler}!")
|
309 |
+
|
310 |
+
adm_cond, adm_uc = None, None
|
311 |
+
if use_karlo_prior:
|
312 |
+
# uses the prior
|
313 |
+
karlo_sampler = state["karlo_prior"]
|
314 |
+
noise_level = None
|
315 |
+
if state["model"].noise_augmentor is not None:
|
316 |
+
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
317 |
+
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
318 |
+
with torch.no_grad():
|
319 |
+
karlo_prediction = iter(
|
320 |
+
karlo_sampler(
|
321 |
+
prompt=prompt,
|
322 |
+
bsz=number_cols,
|
323 |
+
progressive_mode="final",
|
324 |
+
)
|
325 |
+
).__next__()
|
326 |
+
adm_cond = karlo_prediction
|
327 |
+
if noise_level is not None:
|
328 |
+
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
329 |
+
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
330 |
+
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
331 |
+
adm_uc = torch.zeros_like(adm_cond)
|
332 |
+
elif version == "Full Karlo":
|
333 |
+
pass
|
334 |
+
else:
|
335 |
+
num_inputs = st.number_input("Number of Input Images", 1)
|
336 |
+
|
337 |
+
|
338 |
+
def make_conditionings_from_input(num=1, key=None):
|
339 |
+
init_img = get_init_img(batch_size=number_cols, key=key)
|
340 |
+
with torch.no_grad():
|
341 |
+
adm_cond = state["model"].embedder(init_img)
|
342 |
+
weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
|
343 |
+
if state["model"].noise_augmentor is not None:
|
344 |
+
noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
|
345 |
+
max_value=state["model"].noise_augmentor.max_noise_level - 1,
|
346 |
+
value=0, )
|
347 |
+
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
348 |
+
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
349 |
+
adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
|
350 |
+
adm_uc = torch.zeros_like(adm_cond)
|
351 |
+
return adm_cond, adm_uc, weight
|
352 |
+
|
353 |
+
|
354 |
+
adm_inputs = list()
|
355 |
+
weights = list()
|
356 |
+
for n in range(num_inputs):
|
357 |
+
adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
|
358 |
+
weights.append(w)
|
359 |
+
adm_inputs.append(adm_cond)
|
360 |
+
adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
|
361 |
+
if num_inputs > 1:
|
362 |
+
if st.checkbox("Apply Noise to Embedding Mix", True):
|
363 |
+
noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
|
364 |
+
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=50, )
|
365 |
+
c_adm, noise_level_emb = state["model"].noise_augmentor(
|
366 |
+
adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
|
367 |
+
noise_level=repeat(
|
368 |
+
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
369 |
+
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
370 |
+
|
371 |
+
if st.button("Sample"):
|
372 |
+
print("running prompt:", prompt)
|
373 |
+
st.text("Sampling")
|
374 |
+
t_progress = st.progress(0)
|
375 |
+
result = st.empty()
|
376 |
+
|
377 |
+
|
378 |
+
def t_callback(t):
|
379 |
+
t_progress.progress(min((t + 1) / steps, 1.))
|
380 |
+
|
381 |
+
|
382 |
+
if version == "Full Karlo":
|
383 |
+
outputs = st.empty()
|
384 |
+
karlo_sampler = state["karlo_prior"]
|
385 |
+
all_samples = list()
|
386 |
+
with torch.no_grad():
|
387 |
+
for _ in range(number_rows):
|
388 |
+
karlo_prediction = iter(
|
389 |
+
karlo_sampler(
|
390 |
+
prompt=prompt,
|
391 |
+
bsz=number_cols,
|
392 |
+
progressive_mode="final",
|
393 |
+
)
|
394 |
+
).__next__()
|
395 |
+
all_samples.append(karlo_prediction)
|
396 |
+
grid = torch.stack(all_samples, 0)
|
397 |
+
grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
|
398 |
+
outputs.image(grid.cpu().numpy())
|
399 |
+
|
400 |
+
else:
|
401 |
+
samples = sample(
|
402 |
+
state["model"],
|
403 |
+
prompt,
|
404 |
+
n_runs=number_rows,
|
405 |
+
n_samples=number_cols,
|
406 |
+
H=H, W=W, C=C, f=f,
|
407 |
+
scale=scale,
|
408 |
+
ddim_steps=steps,
|
409 |
+
ddim_eta=eta,
|
410 |
+
callback=t_callback,
|
411 |
+
ucg_schedule=ucg_schedule,
|
412 |
+
negative_prompt=negative_prompt,
|
413 |
+
adm_cond=adm_cond, adm_uc=adm_uc,
|
414 |
+
use_full_precision=force_full_precision,
|
415 |
+
only_adm_cond=False
|
416 |
+
)
|
repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from einops import repeat, rearrange
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from imwatermark import WatermarkEncoder
|
10 |
+
|
11 |
+
from scripts.txt2img import put_watermark
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
|
14 |
+
from ldm.util import exists, instantiate_from_config
|
15 |
+
|
16 |
+
|
17 |
+
torch.set_grad_enabled(False)
|
18 |
+
|
19 |
+
|
20 |
+
@st.cache(allow_output_mutation=True)
|
21 |
+
def initialize_model(config, ckpt):
|
22 |
+
config = OmegaConf.load(config)
|
23 |
+
model = instantiate_from_config(config.model)
|
24 |
+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
25 |
+
|
26 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
27 |
+
model = model.to(device)
|
28 |
+
sampler = DDIMSampler(model)
|
29 |
+
return sampler
|
30 |
+
|
31 |
+
|
32 |
+
def make_batch_sd(
|
33 |
+
image,
|
34 |
+
txt,
|
35 |
+
device,
|
36 |
+
num_samples=1,
|
37 |
+
):
|
38 |
+
image = np.array(image.convert("RGB"))
|
39 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
40 |
+
batch = {
|
41 |
+
"lr": rearrange(image, 'h w c -> 1 c h w'),
|
42 |
+
"txt": num_samples * [txt],
|
43 |
+
}
|
44 |
+
batch["lr"] = repeat(batch["lr"].to(device=device), "1 ... -> n ...", n=num_samples)
|
45 |
+
return batch
|
46 |
+
|
47 |
+
|
48 |
+
def make_noise_augmentation(model, batch, noise_level=None):
|
49 |
+
x_low = batch[model.low_scale_key]
|
50 |
+
x_low = x_low.to(memory_format=torch.contiguous_format).float()
|
51 |
+
x_aug, noise_level = model.low_scale_model(x_low, noise_level)
|
52 |
+
return x_aug, noise_level
|
53 |
+
|
54 |
+
|
55 |
+
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
|
56 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
57 |
+
model = sampler.model
|
58 |
+
seed_everything(seed)
|
59 |
+
prng = np.random.RandomState(seed)
|
60 |
+
start_code = prng.randn(num_samples, model.channels, h , w)
|
61 |
+
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
|
62 |
+
|
63 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
64 |
+
wm = "SDV2"
|
65 |
+
wm_encoder = WatermarkEncoder()
|
66 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
67 |
+
with torch.no_grad(),\
|
68 |
+
torch.autocast("cuda"):
|
69 |
+
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
70 |
+
c = model.cond_stage_model.encode(batch["txt"])
|
71 |
+
c_cat = list()
|
72 |
+
if isinstance(model, LatentUpscaleFinetuneDiffusion):
|
73 |
+
for ck in model.concat_keys:
|
74 |
+
cc = batch[ck]
|
75 |
+
if exists(model.reshuffle_patch_size):
|
76 |
+
assert isinstance(model.reshuffle_patch_size, int)
|
77 |
+
cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
|
78 |
+
p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
|
79 |
+
c_cat.append(cc)
|
80 |
+
c_cat = torch.cat(c_cat, dim=1)
|
81 |
+
# cond
|
82 |
+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
83 |
+
# uncond cond
|
84 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
85 |
+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
86 |
+
elif isinstance(model, LatentUpscaleDiffusion):
|
87 |
+
x_augment, noise_level = make_noise_augmentation(model, batch, noise_level)
|
88 |
+
cond = {"c_concat": [x_augment], "c_crossattn": [c], "c_adm": noise_level}
|
89 |
+
# uncond cond
|
90 |
+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
91 |
+
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
|
92 |
+
else:
|
93 |
+
raise NotImplementedError()
|
94 |
+
|
95 |
+
shape = [model.channels, h, w]
|
96 |
+
samples, intermediates = sampler.sample(
|
97 |
+
steps,
|
98 |
+
num_samples,
|
99 |
+
shape,
|
100 |
+
cond,
|
101 |
+
verbose=False,
|
102 |
+
eta=eta,
|
103 |
+
unconditional_guidance_scale=scale,
|
104 |
+
unconditional_conditioning=uc_full,
|
105 |
+
x_T=start_code,
|
106 |
+
callback=callback
|
107 |
+
)
|
108 |
+
with torch.no_grad():
|
109 |
+
x_samples_ddim = model.decode_first_stage(samples)
|
110 |
+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
111 |
+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
112 |
+
st.text(f"upscaled image shape: {result.shape}")
|
113 |
+
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
114 |
+
|
115 |
+
|
116 |
+
def run():
|
117 |
+
st.title("Stable Diffusion Upscaling")
|
118 |
+
# run via streamlit run scripts/demo/depth2img.py <path-tp-config> <path-to-ckpt>
|
119 |
+
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
120 |
+
|
121 |
+
image = st.file_uploader("Image", ["jpg", "png"])
|
122 |
+
if image:
|
123 |
+
image = Image.open(image)
|
124 |
+
w, h = image.size
|
125 |
+
st.text(f"loaded input image of size ({w}, {h})")
|
126 |
+
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
127 |
+
image = image.resize((width, height))
|
128 |
+
st.text(f"resized input image to size ({width}, {height} (w, h))")
|
129 |
+
st.image(image)
|
130 |
+
|
131 |
+
st.write(f"\n Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat'")
|
132 |
+
prompt = st.text_input("Prompt", "a high quality professional photograph")
|
133 |
+
|
134 |
+
seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
|
135 |
+
num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
|
136 |
+
scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
|
137 |
+
steps = st.slider("DDIM Steps", min_value=2, max_value=250, value=50, step=1)
|
138 |
+
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
139 |
+
|
140 |
+
noise_level = None
|
141 |
+
if isinstance(sampler.model, LatentUpscaleDiffusion):
|
142 |
+
# TODO: make this work for all models
|
143 |
+
noise_level = st.sidebar.number_input("Noise Augmentation", min_value=0, max_value=350, value=20)
|
144 |
+
noise_level = torch.Tensor(num_samples * [noise_level]).to(sampler.model.device).long()
|
145 |
+
|
146 |
+
t_progress = st.progress(0)
|
147 |
+
def t_callback(t):
|
148 |
+
t_progress.progress(min((t + 1) / steps, 1.))
|
149 |
+
|
150 |
+
sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
|
151 |
+
if st.button("Sample"):
|
152 |
+
result = paint(
|
153 |
+
sampler=sampler,
|
154 |
+
image=image,
|
155 |
+
prompt=prompt,
|
156 |
+
seed=seed,
|
157 |
+
scale=scale,
|
158 |
+
h=height, w=width, steps=steps,
|
159 |
+
num_samples=num_samples,
|
160 |
+
callback=t_callback,
|
161 |
+
noise_level=noise_level,
|
162 |
+
eta=eta
|
163 |
+
)
|
164 |
+
st.write("Result")
|
165 |
+
for image in result:
|
166 |
+
st.image(image, output_format='PNG')
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == "__main__":
|
170 |
+
run()
|
repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import fire
|
3 |
+
from imwatermark import WatermarkDecoder
|
4 |
+
|
5 |
+
|
6 |
+
def testit(img_path):
|
7 |
+
bgr = cv2.imread(img_path)
|
8 |
+
decoder = WatermarkDecoder('bytes', 136)
|
9 |
+
watermark = decoder.decode(bgr, 'dwtDct')
|
10 |
+
try:
|
11 |
+
dec = watermark.decode('utf-8')
|
12 |
+
except:
|
13 |
+
dec = "null"
|
14 |
+
print(dec)
|
15 |
+
|
16 |
+
|
17 |
+
if __name__ == "__main__":
|
18 |
+
fire.Fire(testit)
|
repositories/stable-diffusion-stability-ai/scripts/txt2img.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm, trange
|
8 |
+
from itertools import islice
|
9 |
+
from einops import rearrange
|
10 |
+
from torchvision.utils import make_grid
|
11 |
+
from pytorch_lightning import seed_everything
|
12 |
+
from torch import autocast
|
13 |
+
from contextlib import nullcontext
|
14 |
+
from imwatermark import WatermarkEncoder
|
15 |
+
|
16 |
+
from ldm.util import instantiate_from_config
|
17 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
18 |
+
from ldm.models.diffusion.plms import PLMSSampler
|
19 |
+
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
20 |
+
|
21 |
+
torch.set_grad_enabled(False)
|
22 |
+
|
23 |
+
def chunk(it, size):
|
24 |
+
it = iter(it)
|
25 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
26 |
+
|
27 |
+
|
28 |
+
def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
|
29 |
+
print(f"Loading model from {ckpt}")
|
30 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
31 |
+
if "global_step" in pl_sd:
|
32 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
33 |
+
sd = pl_sd["state_dict"]
|
34 |
+
model = instantiate_from_config(config.model)
|
35 |
+
m, u = model.load_state_dict(sd, strict=False)
|
36 |
+
if len(m) > 0 and verbose:
|
37 |
+
print("missing keys:")
|
38 |
+
print(m)
|
39 |
+
if len(u) > 0 and verbose:
|
40 |
+
print("unexpected keys:")
|
41 |
+
print(u)
|
42 |
+
|
43 |
+
if device == torch.device("cuda"):
|
44 |
+
model.cuda()
|
45 |
+
elif device == torch.device("cpu"):
|
46 |
+
model.cpu()
|
47 |
+
model.cond_stage_model.device = "cpu"
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Incorrect device name. Received: {device}")
|
50 |
+
model.eval()
|
51 |
+
return model
|
52 |
+
|
53 |
+
|
54 |
+
def parse_args():
|
55 |
+
parser = argparse.ArgumentParser()
|
56 |
+
parser.add_argument(
|
57 |
+
"--prompt",
|
58 |
+
type=str,
|
59 |
+
nargs="?",
|
60 |
+
default="a professional photograph of an astronaut riding a triceratops",
|
61 |
+
help="the prompt to render"
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--outdir",
|
65 |
+
type=str,
|
66 |
+
nargs="?",
|
67 |
+
help="dir to write results to",
|
68 |
+
default="outputs/txt2img-samples"
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--steps",
|
72 |
+
type=int,
|
73 |
+
default=50,
|
74 |
+
help="number of ddim sampling steps",
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--plms",
|
78 |
+
action='store_true',
|
79 |
+
help="use plms sampling",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--dpm",
|
83 |
+
action='store_true',
|
84 |
+
help="use DPM (2) sampler",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--fixed_code",
|
88 |
+
action='store_true',
|
89 |
+
help="if enabled, uses the same starting code across all samples ",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--ddim_eta",
|
93 |
+
type=float,
|
94 |
+
default=0.0,
|
95 |
+
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--n_iter",
|
99 |
+
type=int,
|
100 |
+
default=3,
|
101 |
+
help="sample this often",
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--H",
|
105 |
+
type=int,
|
106 |
+
default=512,
|
107 |
+
help="image height, in pixel space",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--W",
|
111 |
+
type=int,
|
112 |
+
default=512,
|
113 |
+
help="image width, in pixel space",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--C",
|
117 |
+
type=int,
|
118 |
+
default=4,
|
119 |
+
help="latent channels",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--f",
|
123 |
+
type=int,
|
124 |
+
default=8,
|
125 |
+
help="downsampling factor, most often 8 or 16",
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--n_samples",
|
129 |
+
type=int,
|
130 |
+
default=3,
|
131 |
+
help="how many samples to produce for each given prompt. A.k.a batch size",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--n_rows",
|
135 |
+
type=int,
|
136 |
+
default=0,
|
137 |
+
help="rows in the grid (default: n_samples)",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--scale",
|
141 |
+
type=float,
|
142 |
+
default=9.0,
|
143 |
+
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
144 |
+
)
|
145 |
+
parser.add_argument(
|
146 |
+
"--from-file",
|
147 |
+
type=str,
|
148 |
+
help="if specified, load prompts from this file, separated by newlines",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--config",
|
152 |
+
type=str,
|
153 |
+
default="configs/stable-diffusion/v2-inference.yaml",
|
154 |
+
help="path to config which constructs model",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--ckpt",
|
158 |
+
type=str,
|
159 |
+
help="path to checkpoint of model",
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--seed",
|
163 |
+
type=int,
|
164 |
+
default=42,
|
165 |
+
help="the seed (for reproducible sampling)",
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--precision",
|
169 |
+
type=str,
|
170 |
+
help="evaluate at this precision",
|
171 |
+
choices=["full", "autocast"],
|
172 |
+
default="autocast"
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--repeat",
|
176 |
+
type=int,
|
177 |
+
default=1,
|
178 |
+
help="repeat each prompt in file this often",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--device",
|
182 |
+
type=str,
|
183 |
+
help="Device on which Stable Diffusion will be run",
|
184 |
+
choices=["cpu", "cuda"],
|
185 |
+
default="cpu"
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--torchscript",
|
189 |
+
action='store_true',
|
190 |
+
help="Use TorchScript",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--ipex",
|
194 |
+
action='store_true',
|
195 |
+
help="Use Intel® Extension for PyTorch*",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--bf16",
|
199 |
+
action='store_true',
|
200 |
+
help="Use bfloat16",
|
201 |
+
)
|
202 |
+
opt = parser.parse_args()
|
203 |
+
return opt
|
204 |
+
|
205 |
+
|
206 |
+
def put_watermark(img, wm_encoder=None):
|
207 |
+
if wm_encoder is not None:
|
208 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
209 |
+
img = wm_encoder.encode(img, 'dwtDct')
|
210 |
+
img = Image.fromarray(img[:, :, ::-1])
|
211 |
+
return img
|
212 |
+
|
213 |
+
|
214 |
+
def main(opt):
|
215 |
+
seed_everything(opt.seed)
|
216 |
+
|
217 |
+
config = OmegaConf.load(f"{opt.config}")
|
218 |
+
device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
|
219 |
+
model = load_model_from_config(config, f"{opt.ckpt}", device)
|
220 |
+
|
221 |
+
if opt.plms:
|
222 |
+
sampler = PLMSSampler(model, device=device)
|
223 |
+
elif opt.dpm:
|
224 |
+
sampler = DPMSolverSampler(model, device=device)
|
225 |
+
else:
|
226 |
+
sampler = DDIMSampler(model, device=device)
|
227 |
+
|
228 |
+
os.makedirs(opt.outdir, exist_ok=True)
|
229 |
+
outpath = opt.outdir
|
230 |
+
|
231 |
+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
232 |
+
wm = "SDV2"
|
233 |
+
wm_encoder = WatermarkEncoder()
|
234 |
+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
235 |
+
|
236 |
+
batch_size = opt.n_samples
|
237 |
+
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
238 |
+
if not opt.from_file:
|
239 |
+
prompt = opt.prompt
|
240 |
+
assert prompt is not None
|
241 |
+
data = [batch_size * [prompt]]
|
242 |
+
|
243 |
+
else:
|
244 |
+
print(f"reading prompts from {opt.from_file}")
|
245 |
+
with open(opt.from_file, "r") as f:
|
246 |
+
data = f.read().splitlines()
|
247 |
+
data = [p for p in data for i in range(opt.repeat)]
|
248 |
+
data = list(chunk(data, batch_size))
|
249 |
+
|
250 |
+
sample_path = os.path.join(outpath, "samples")
|
251 |
+
os.makedirs(sample_path, exist_ok=True)
|
252 |
+
sample_count = 0
|
253 |
+
base_count = len(os.listdir(sample_path))
|
254 |
+
grid_count = len(os.listdir(outpath)) - 1
|
255 |
+
|
256 |
+
start_code = None
|
257 |
+
if opt.fixed_code:
|
258 |
+
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
259 |
+
|
260 |
+
if opt.torchscript or opt.ipex:
|
261 |
+
transformer = model.cond_stage_model.model
|
262 |
+
unet = model.model.diffusion_model
|
263 |
+
decoder = model.first_stage_model.decoder
|
264 |
+
additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
|
265 |
+
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
266 |
+
|
267 |
+
if opt.bf16 and not opt.torchscript and not opt.ipex:
|
268 |
+
raise ValueError('Bfloat16 is supported only for torchscript+ipex')
|
269 |
+
if opt.bf16 and unet.dtype != torch.bfloat16:
|
270 |
+
raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
|
271 |
+
"you'd like to use bfloat16 with CPU.")
|
272 |
+
if unet.dtype == torch.float16 and device == torch.device("cpu"):
|
273 |
+
raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
|
274 |
+
|
275 |
+
if opt.ipex:
|
276 |
+
import intel_extension_for_pytorch as ipex
|
277 |
+
bf16_dtype = torch.bfloat16 if opt.bf16 else None
|
278 |
+
transformer = transformer.to(memory_format=torch.channels_last)
|
279 |
+
transformer = ipex.optimize(transformer, level="O1", inplace=True)
|
280 |
+
|
281 |
+
unet = unet.to(memory_format=torch.channels_last)
|
282 |
+
unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
|
283 |
+
|
284 |
+
decoder = decoder.to(memory_format=torch.channels_last)
|
285 |
+
decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
|
286 |
+
|
287 |
+
if opt.torchscript:
|
288 |
+
with torch.no_grad(), additional_context:
|
289 |
+
# get UNET scripted
|
290 |
+
if unet.use_checkpoint:
|
291 |
+
raise ValueError("Gradient checkpoint won't work with tracing. " +
|
292 |
+
"Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
|
293 |
+
|
294 |
+
img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
|
295 |
+
t_in = torch.ones(2, dtype=torch.int64)
|
296 |
+
context = torch.ones(2, 77, 1024, dtype=torch.float32)
|
297 |
+
scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
|
298 |
+
scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
|
299 |
+
print(type(scripted_unet))
|
300 |
+
model.model.scripted_diffusion_model = scripted_unet
|
301 |
+
|
302 |
+
# get Decoder for first stage model scripted
|
303 |
+
samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
|
304 |
+
scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
|
305 |
+
scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
|
306 |
+
print(type(scripted_decoder))
|
307 |
+
model.first_stage_model.decoder = scripted_decoder
|
308 |
+
|
309 |
+
prompts = data[0]
|
310 |
+
print("Running a forward pass to initialize optimizations")
|
311 |
+
uc = None
|
312 |
+
if opt.scale != 1.0:
|
313 |
+
uc = model.get_learned_conditioning(batch_size * [""])
|
314 |
+
if isinstance(prompts, tuple):
|
315 |
+
prompts = list(prompts)
|
316 |
+
|
317 |
+
with torch.no_grad(), additional_context:
|
318 |
+
for _ in range(3):
|
319 |
+
c = model.get_learned_conditioning(prompts)
|
320 |
+
samples_ddim, _ = sampler.sample(S=5,
|
321 |
+
conditioning=c,
|
322 |
+
batch_size=batch_size,
|
323 |
+
shape=shape,
|
324 |
+
verbose=False,
|
325 |
+
unconditional_guidance_scale=opt.scale,
|
326 |
+
unconditional_conditioning=uc,
|
327 |
+
eta=opt.ddim_eta,
|
328 |
+
x_T=start_code)
|
329 |
+
print("Running a forward pass for decoder")
|
330 |
+
for _ in range(3):
|
331 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
332 |
+
|
333 |
+
precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
|
334 |
+
with torch.no_grad(), \
|
335 |
+
precision_scope(opt.device), \
|
336 |
+
model.ema_scope():
|
337 |
+
all_samples = list()
|
338 |
+
for n in trange(opt.n_iter, desc="Sampling"):
|
339 |
+
for prompts in tqdm(data, desc="data"):
|
340 |
+
uc = None
|
341 |
+
if opt.scale != 1.0:
|
342 |
+
uc = model.get_learned_conditioning(batch_size * [""])
|
343 |
+
if isinstance(prompts, tuple):
|
344 |
+
prompts = list(prompts)
|
345 |
+
c = model.get_learned_conditioning(prompts)
|
346 |
+
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
347 |
+
samples, _ = sampler.sample(S=opt.steps,
|
348 |
+
conditioning=c,
|
349 |
+
batch_size=opt.n_samples,
|
350 |
+
shape=shape,
|
351 |
+
verbose=False,
|
352 |
+
unconditional_guidance_scale=opt.scale,
|
353 |
+
unconditional_conditioning=uc,
|
354 |
+
eta=opt.ddim_eta,
|
355 |
+
x_T=start_code)
|
356 |
+
|
357 |
+
x_samples = model.decode_first_stage(samples)
|
358 |
+
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
359 |
+
|
360 |
+
for x_sample in x_samples:
|
361 |
+
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
362 |
+
img = Image.fromarray(x_sample.astype(np.uint8))
|
363 |
+
img = put_watermark(img, wm_encoder)
|
364 |
+
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
|
365 |
+
base_count += 1
|
366 |
+
sample_count += 1
|
367 |
+
|
368 |
+
all_samples.append(x_samples)
|
369 |
+
|
370 |
+
# additionally, save as grid
|
371 |
+
grid = torch.stack(all_samples, 0)
|
372 |
+
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
373 |
+
grid = make_grid(grid, nrow=n_rows)
|
374 |
+
|
375 |
+
# to image
|
376 |
+
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
377 |
+
grid = Image.fromarray(grid.astype(np.uint8))
|
378 |
+
grid = put_watermark(grid, wm_encoder)
|
379 |
+
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
380 |
+
grid_count += 1
|
381 |
+
|
382 |
+
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
383 |
+
f" \nEnjoy.")
|
384 |
+
|
385 |
+
|
386 |
+
if __name__ == "__main__":
|
387 |
+
opt = parse_args()
|
388 |
+
main(opt)
|
repositories/stable-diffusion-stability-ai/setup.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='stable-diffusion',
|
5 |
+
version='0.0.1',
|
6 |
+
description='',
|
7 |
+
packages=find_packages(),
|
8 |
+
install_requires=[
|
9 |
+
'torch',
|
10 |
+
'numpy',
|
11 |
+
'tqdm',
|
12 |
+
],
|
13 |
+
)
|
requirements-test.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pytest-base-url~=2.0
|
2 |
+
pytest-cov~=4.0
|
3 |
+
pytest~=7.3
|
requirements.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GitPython
|
2 |
+
Pillow
|
3 |
+
accelerate
|
4 |
+
|
5 |
+
basicsr
|
6 |
+
blendmodes
|
7 |
+
clean-fid
|
8 |
+
einops
|
9 |
+
gfpgan
|
10 |
+
gradio==3.32.0
|
11 |
+
inflection
|
12 |
+
jsonmerge
|
13 |
+
kornia
|
14 |
+
lark
|
15 |
+
numpy
|
16 |
+
omegaconf
|
17 |
+
open-clip-torch
|
18 |
+
|
19 |
+
piexif
|
20 |
+
psutil
|
21 |
+
pytorch_lightning
|
22 |
+
realesrgan
|
23 |
+
requests
|
24 |
+
resize-right
|
25 |
+
|
26 |
+
safetensors
|
27 |
+
scikit-image>=0.19
|
28 |
+
timm
|
29 |
+
tomesd
|
30 |
+
torch
|
31 |
+
torchdiffeq
|
32 |
+
torchsde
|
33 |
+
transformers==4.25.1
|
requirements_versions.txt
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GitPython==3.1.30
|
2 |
+
Pillow==9.5.0
|
3 |
+
accelerate==0.18.0
|
4 |
+
basicsr==1.4.2
|
5 |
+
blendmodes==2022
|
6 |
+
clean-fid==0.1.35
|
7 |
+
einops==0.4.1
|
8 |
+
fastapi==0.94.0
|
9 |
+
gfpgan==1.3.8
|
10 |
+
gradio==3.32.0
|
11 |
+
httpcore==0.15
|
12 |
+
inflection==0.5.1
|
13 |
+
jsonmerge==1.8.0
|
14 |
+
kornia==0.6.7
|
15 |
+
lark==1.1.2
|
16 |
+
numpy==1.23.5
|
17 |
+
omegaconf==2.2.3
|
18 |
+
open-clip-torch==2.20.0
|
19 |
+
piexif==1.1.3
|
20 |
+
psutil==5.9.5
|
21 |
+
pytorch_lightning==1.9.4
|
22 |
+
realesrgan==0.3.0
|
23 |
+
resize-right==0.0.2
|
24 |
+
safetensors==0.3.1
|
25 |
+
scikit-image==0.20.0
|
26 |
+
timm==0.6.7
|
27 |
+
tomesd==0.1.2
|
28 |
+
torch
|
29 |
+
torchdiffeq==0.2.3
|
30 |
+
torchsde==0.2.5
|
31 |
+
transformers==4.25.1
|
screenshot.png
ADDED
![]() |
script.js
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function gradioApp() {
|
2 |
+
const elems = document.getElementsByTagName('gradio-app');
|
3 |
+
const elem = elems.length == 0 ? document : elems[0];
|
4 |
+
|
5 |
+
if (elem !== document) {
|
6 |
+
elem.getElementById = function(id) {
|
7 |
+
return document.getElementById(id);
|
8 |
+
};
|
9 |
+
}
|
10 |
+
return elem.shadowRoot ? elem.shadowRoot : elem;
|
11 |
+
}
|
12 |
+
|
13 |
+
/**
|
14 |
+
* Get the currently selected top-level UI tab button (e.g. the button that says "Extras").
|
15 |
+
*/
|
16 |
+
function get_uiCurrentTab() {
|
17 |
+
return gradioApp().querySelector('#tabs > .tab-nav > button.selected');
|
18 |
+
}
|
19 |
+
|
20 |
+
/**
|
21 |
+
* Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI).
|
22 |
+
*/
|
23 |
+
function get_uiCurrentTabContent() {
|
24 |
+
return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])');
|
25 |
+
}
|
26 |
+
|
27 |
+
var uiUpdateCallbacks = [];
|
28 |
+
var uiAfterUpdateCallbacks = [];
|
29 |
+
var uiLoadedCallbacks = [];
|
30 |
+
var uiTabChangeCallbacks = [];
|
31 |
+
var optionsChangedCallbacks = [];
|
32 |
+
var uiAfterUpdateTimeout = null;
|
33 |
+
var uiCurrentTab = null;
|
34 |
+
|
35 |
+
/**
|
36 |
+
* Register callback to be called at each UI update.
|
37 |
+
* The callback receives an array of MutationRecords as an argument.
|
38 |
+
*/
|
39 |
+
function onUiUpdate(callback) {
|
40 |
+
uiUpdateCallbacks.push(callback);
|
41 |
+
}
|
42 |
+
|
43 |
+
/**
|
44 |
+
* Register callback to be called soon after UI updates.
|
45 |
+
* The callback receives no arguments.
|
46 |
+
*
|
47 |
+
* This is preferred over `onUiUpdate` if you don't need
|
48 |
+
* access to the MutationRecords, as your function will
|
49 |
+
* not be called quite as often.
|
50 |
+
*/
|
51 |
+
function onAfterUiUpdate(callback) {
|
52 |
+
uiAfterUpdateCallbacks.push(callback);
|
53 |
+
}
|
54 |
+
|
55 |
+
/**
|
56 |
+
* Register callback to be called when the UI is loaded.
|
57 |
+
* The callback receives no arguments.
|
58 |
+
*/
|
59 |
+
function onUiLoaded(callback) {
|
60 |
+
uiLoadedCallbacks.push(callback);
|
61 |
+
}
|
62 |
+
|
63 |
+
/**
|
64 |
+
* Register callback to be called when the UI tab is changed.
|
65 |
+
* The callback receives no arguments.
|
66 |
+
*/
|
67 |
+
function onUiTabChange(callback) {
|
68 |
+
uiTabChangeCallbacks.push(callback);
|
69 |
+
}
|
70 |
+
|
71 |
+
/**
|
72 |
+
* Register callback to be called when the options are changed.
|
73 |
+
* The callback receives no arguments.
|
74 |
+
* @param callback
|
75 |
+
*/
|
76 |
+
function onOptionsChanged(callback) {
|
77 |
+
optionsChangedCallbacks.push(callback);
|
78 |
+
}
|
79 |
+
|
80 |
+
function executeCallbacks(queue, arg) {
|
81 |
+
for (const callback of queue) {
|
82 |
+
try {
|
83 |
+
callback(arg);
|
84 |
+
} catch (e) {
|
85 |
+
console.error("error running callback", callback, ":", e);
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
/**
|
91 |
+
* Schedule the execution of the callbacks registered with onAfterUiUpdate.
|
92 |
+
* The callbacks are executed after a short while, unless another call to this function
|
93 |
+
* is made before that time. IOW, the callbacks are executed only once, even
|
94 |
+
* when there are multiple mutations observed.
|
95 |
+
*/
|
96 |
+
function scheduleAfterUiUpdateCallbacks() {
|
97 |
+
clearTimeout(uiAfterUpdateTimeout);
|
98 |
+
uiAfterUpdateTimeout = setTimeout(function() {
|
99 |
+
executeCallbacks(uiAfterUpdateCallbacks);
|
100 |
+
}, 200);
|
101 |
+
}
|
102 |
+
|
103 |
+
var executedOnLoaded = false;
|
104 |
+
|
105 |
+
document.addEventListener("DOMContentLoaded", function() {
|
106 |
+
var mutationObserver = new MutationObserver(function(m) {
|
107 |
+
if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {
|
108 |
+
executedOnLoaded = true;
|
109 |
+
executeCallbacks(uiLoadedCallbacks);
|
110 |
+
}
|
111 |
+
|
112 |
+
executeCallbacks(uiUpdateCallbacks, m);
|
113 |
+
scheduleAfterUiUpdateCallbacks();
|
114 |
+
const newTab = get_uiCurrentTab();
|
115 |
+
if (newTab && (newTab !== uiCurrentTab)) {
|
116 |
+
uiCurrentTab = newTab;
|
117 |
+
executeCallbacks(uiTabChangeCallbacks);
|
118 |
+
}
|
119 |
+
});
|
120 |
+
mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
|
121 |
+
});
|
122 |
+
|
123 |
+
/**
|
124 |
+
* Add a ctrl+enter as a shortcut to start a generation
|
125 |
+
*/
|
126 |
+
document.addEventListener('keydown', function(e) {
|
127 |
+
var handled = false;
|
128 |
+
if (e.key !== undefined) {
|
129 |
+
if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
130 |
+
} else if (e.keyCode !== undefined) {
|
131 |
+
if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
132 |
+
}
|
133 |
+
if (handled) {
|
134 |
+
var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
135 |
+
if (button) {
|
136 |
+
button.click();
|
137 |
+
}
|
138 |
+
e.preventDefault();
|
139 |
+
}
|
140 |
+
});
|
141 |
+
|
142 |
+
/**
|
143 |
+
* checks that a UI element is not in another hidden element or tab content
|
144 |
+
*/
|
145 |
+
function uiElementIsVisible(el) {
|
146 |
+
if (el === document) {
|
147 |
+
return true;
|
148 |
+
}
|
149 |
+
|
150 |
+
const computedStyle = getComputedStyle(el);
|
151 |
+
const isVisible = computedStyle.display !== 'none';
|
152 |
+
|
153 |
+
if (!isVisible) return false;
|
154 |
+
return uiElementIsVisible(el.parentNode);
|
155 |
+
}
|
156 |
+
|
157 |
+
function uiElementInSight(el) {
|
158 |
+
const clRect = el.getBoundingClientRect();
|
159 |
+
const windowHeight = window.innerHeight;
|
160 |
+
const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight;
|
161 |
+
|
162 |
+
return isOnScreen;
|
163 |
+
}
|
scripts/__pycache__/custom_code.cpython-310.pyc
ADDED
Binary file (2.73 kB). View file
|
|
scripts/__pycache__/img2imgalt.cpython-310.pyc
ADDED
Binary file (6.37 kB). View file
|
|
scripts/__pycache__/loopback.cpython-310.pyc
ADDED
Binary file (3.53 kB). View file
|
|
scripts/__pycache__/outpainting_mk_2.cpython-310.pyc
ADDED
Binary file (8.31 kB). View file
|
|
scripts/__pycache__/poor_mans_outpainting.cpython-310.pyc
ADDED
Binary file (4.12 kB). View file
|
|
scripts/__pycache__/postprocessing_codeformer.cpython-310.pyc
ADDED
Binary file (1.61 kB). View file
|
|
scripts/__pycache__/postprocessing_gfpgan.cpython-310.pyc
ADDED
Binary file (1.39 kB). View file
|
|
scripts/__pycache__/postprocessing_upscale.cpython-310.pyc
ADDED
Binary file (6.44 kB). View file
|
|
scripts/__pycache__/prompt_matrix.cpython-310.pyc
ADDED
Binary file (4.2 kB). View file
|
|