Upload 7 files
- app.py +205 -0
- image_utils.py +108 -0
- inverse_stable_diffusion.py +205 -0
- modified_stable_diffusion.py +235 -0
- requirements.txt +14 -0
- run_gaussian_shading.py +106 -0
- watermark.py +95 -0
app.py
ADDED
@@ -0,0 +1,205 @@
import numpy as np
import gradio as gr
from run_gaussian_shading import *

examples = [
    "A photo of a cat",
    "A pizza with pineapple on it",
    "A photo of a dog",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
"""

MAX_SEED = np.iinfo(np.int32).max


# ---------------------------------------------------------------------------------------------------


with gr.Blocks(css=css) as demo:

    # ---------------------------------- Add Watermark -----------------------------------------

    with gr.Tab("Add watermark"):
        with gr.Column(elem_id="col-container"):
            gr.Markdown(" # Text-to-Image Watermark")
            with gr.Accordion("Instruction", open=False):
                gr.Markdown("""
                # Embedding Watermark
                ## 1. Generate watermarked image
                * Enter your prompt in the text box.
                * Click **Run** to generate an image with a random binary watermark.

                ## 2. Save Image
                Click **Download** to save the watermarked image in PNG format.

                ## 3. Advanced Settings
                - **Seed**: Different seeds produce different images.
                - **Guidance Scale**: Higher values make the image follow the prompt more closely.
                - **Num Inference Steps**: More steps enhance image detail and quality but increase computational cost.

                Source code: [Gaussian Shading](https://github.com/bsmhmmlf/Gaussian-Shading)""")
            with gr.Row():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", scale=0, variant="primary")
                download_button = gr.DownloadButton(visible=True)
            with gr.Row():
                result_original = gr.Image(label="Image without watermark", show_label=True)
                result = gr.Image(label="Watermarked Image", show_label=True)

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.5,
                        maximum=10,
                        step=0.1,
                        value=7.5,
                    )
                    num_inference_steps = gr.Slider(
                        label="Num inference steps",
                        minimum=10,
                        maximum=100,
                        step=1,
                        value=50,
                    )
            gr.Examples(examples=examples, inputs=[prompt])

    # ---------------------------------- Extract Watermark -----------------------------------------
    with gr.Tab("Extract watermark"):
        with gr.Column(elem_id="col-container"):
            gr.Markdown(" # Watermark Extraction")
            with gr.Accordion("Instruction", open=False):
                gr.Markdown("""
                # Extracting Watermark
                **Note**: Ensure you generate an image first so that its watermark is added to the database.
                ## 1. Upload Image
                - Upload the image to the Image box.
                - Click the **Extract** button to extract the watermark.
                ## 2. Advanced Settings
                These settings are **optional** and can be used to simulate real-world attacks that try to erase the watermark.
                Click the **Attack** button to generate a distorted image.
                * **Seed**: Initializes the random number generator, ensuring reproducibility of the attack.
                * **Random crop ratio**: The proportion of the image that is kept; everything outside the crop is blacked out. A lower ratio removes more of the image.
                * **Random drop ratio**: The fraction of pixels to be randomly dropped. A higher ratio increases the number of dropped pixels.
                * **Resize ratio**: How much the image is downscaled before being restored to its original size. A lower ratio degrades the image more.
                * **Gaussian blur r**: The radius of the Gaussian blur applied to the image. A larger radius results in a more blurred image.
                * **Gaussian std**: The standard deviation of the Gaussian noise added to the image. A higher value results in stronger noise.
                * **Sp prob**: The probability of each pixel being replaced with either black or white noise. A higher probability adds more salt-and-pepper noise.
                ## Output Explanation
                - **Output watermark**: The binary bits extracted from the image.
                - **Accuracy bit**: The fraction of extracted bits that match the binary watermark stored in the database.
                """)
            with gr.Row():
                input_image = gr.Image(type='pil')
                extract_button = gr.Button("Extract", scale=0, variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    # named attack_seed so it does not shadow the generation seed in the first tab
                    attack_seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    attack_button = gr.Button("Attack!", scale=0, variant="primary")
                with gr.Row():
                    random_crop_ratio = gr.Slider(
                        label="Random crop ratio",
                        minimum=0.5,
                        maximum=1,
                        step=0.1,
                        value=1,
                    )
                    random_drop_ratio = gr.Slider(
                        label="Random drop ratio",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0,
                    )
                with gr.Row():
                    resize_ratio = gr.Slider(
                        label="Resize ratio",
                        minimum=0.2,
                        maximum=1,
                        step=0.1,
                        value=1,
                    )
                    gaussian_blur_r = gr.Slider(
                        label="Gaussian blur r",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0,
                    )
                with gr.Row():
                    gaussian_std = gr.Slider(
                        label="Gaussian std",
                        minimum=0,
                        maximum=0.01,
                        step=0.0001,
                        value=0,
                    )
                    sp_prob = gr.Slider(
                        label="Sp prob",
                        minimum=0,
                        maximum=0.1,
                        step=0.001,
                        value=0,
                    )
                attack_image = gr.Image(label="Attacked Image")
            output = gr.Textbox(label="Output")
            with gr.Accordion("More Details", open=False):
                result_extract = gr.Textbox(label="Bit watermark")
                accuracy_bit = gr.Textbox(label="Accuracy bit")

    # ----------------------------- Embedding watermark -------------------------
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_with_watermark,
        inputs=[
            seed,
            prompt,
            guidance_scale,
            num_inference_steps
        ],
        outputs=[result_original, result, download_button],
    )

    # ----------------------------- Extract watermark -------------------------
    gr.on(
        triggers=[extract_button.click, attack_button.click],
        fn=reverse_watermark,
        inputs=[
            input_image,
            attack_seed,
            random_crop_ratio,
            random_drop_ratio,
            resize_ratio,
            gaussian_blur_r,
            gaussian_std,
            sp_prob,
        ],
        outputs=[output, result_extract, accuracy_bit, attack_image],
    )

demo.launch(share=True)
image_utils.py
ADDED
@@ -0,0 +1,108 @@
import torch
import numpy as np
from torchvision import transforms
from PIL import Image, ImageFilter
import random


def set_random_seed(seed=0):
    torch.manual_seed(seed + 0)
    torch.cuda.manual_seed(seed + 1)
    torch.cuda.manual_seed_all(seed + 2)
    np.random.seed(seed + 3)
    torch.cuda.manual_seed_all(seed + 4)
    random.seed(seed + 5)


def transform_img(image, target_size=512):
    tform = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
        ]
    )
    image = tform(image)
    return 2.0 * image - 1.0


def latents_to_imgs(pipe, latents):
    x = pipe.decode_image(latents)
    x = pipe.torch_to_numpy(x)
    x = pipe.numpy_to_pil(x)
    return x


def image_distortion(img,
                     seed: int = 42,
                     random_crop_ratio: float = None,
                     random_drop_ratio: float = None,
                     resize_ratio: float = None,
                     gaussian_blur_r: int = None,
                     gaussian_std: float = None,
                     sp_prob: float = None):

    if random_crop_ratio is not None:
        set_random_seed(seed)
        width, height, c = np.array(img).shape
        img = np.array(img)
        new_width = int(width * random_crop_ratio)
        new_height = int(height * random_crop_ratio)
        start_x = np.random.randint(0, width - new_width + 1)
        start_y = np.random.randint(0, height - new_height + 1)
        end_x = start_x + new_width
        end_y = start_y + new_height
        padded_image = np.zeros_like(img)
        padded_image[start_y:end_y, start_x:end_x] = img[start_y:end_y, start_x:end_x]
        img = Image.fromarray(padded_image)

    if random_drop_ratio is not None:
        set_random_seed(seed)
        width, height, c = np.array(img).shape
        img = np.array(img)
        new_width = int(width * random_drop_ratio)
        new_height = int(height * random_drop_ratio)
        start_x = np.random.randint(0, width - new_width + 1)
        start_y = np.random.randint(0, height - new_height + 1)
        padded_image = np.zeros_like(img[start_y:start_y + new_height, start_x:start_x + new_width])
        img[start_y:start_y + new_height, start_x:start_x + new_width] = padded_image
        img = Image.fromarray(img)

    if resize_ratio is not None:
        img_shape = np.array(img).shape
        resize_size = int(img_shape[0] * resize_ratio)
        img = transforms.Resize(size=resize_size)(img)
        img = transforms.Resize(size=img_shape[0])(img)

    if gaussian_blur_r is not None:
        img = img.filter(ImageFilter.GaussianBlur(radius=gaussian_blur_r))

    if gaussian_std is not None:
        img_shape = np.array(img).shape
        g_noise = np.random.normal(0, gaussian_std, img_shape) * 255
        g_noise = g_noise.astype(np.uint8)
        img = Image.fromarray(np.clip(np.array(img) + g_noise, 0, 255))

    if sp_prob is not None:
        c, h, w = np.array(img).shape
        prob_zero = sp_prob / 2
        prob_one = 1 - prob_zero
        rdn = np.random.rand(c, h, w)
        img = np.where(rdn > prob_one, np.zeros_like(img), img)
        img = np.where(rdn < prob_zero, np.ones_like(img) * 255, img)
        img = Image.fromarray(img)

    return img


def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
    with torch.no_grad():
        img_batch = [clip_preprocess(i).unsqueeze(0) for i in images]
        img_batch = torch.concatenate(img_batch).to(device)
        image_features = model.encode_image(img_batch)

        text = tokenizer([prompt]).to(device)
        text_features = model.encode_text(text)

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        return (image_features @ text_features.T).mean(-1)
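Note (illustrative, not part of the upload): a minimal sketch of how these helpers compose, assuming the dependencies above are installed and an arbitrary local RGB image exists at example.png (hypothetical path). It applies a blur plus salt-and-pepper noise, then maps the result to the [-1, 1] tensor range the VAE encoder expects.

from PIL import Image
from image_utils import image_distortion, transform_img

# assumption: example.png is any local RGB image, e.g. one saved by the app
img = Image.open("example.png").convert("RGB")

# apply two of the distortions; unset arguments default to None and are skipped
attacked = image_distortion(img, seed=0, gaussian_blur_r=4, sp_prob=0.05)
attacked.save("example_attacked.png")

# PIL image -> tensor of shape (3, 512, 512) scaled to [-1, 1]
x = transform_img(attacked)
print(x.shape, x.min().item(), x.max().item())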
inverse_stable_diffusion.py
ADDED
@@ -0,0 +1,205 @@
from functools import partial
from typing import Callable, List, Optional, Union, Tuple

import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.models import AutoencoderKL, UNet2DConditionModel
# from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler

from modified_stable_diffusion import ModifiedStableDiffusionPipeline
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt


### credit to: https://github.com/cccntu/efficient-prompt-to-prompt
def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
    """from noise to image"""
    return (
        alpha_tm1**0.5
        * (
            (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
            + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
        )
        + x_t
    )


def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt):
    """from image to noise; it is the same update as backward_ddim"""
    return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt)


class InversableStableDiffusionPipeline(ModifiedStableDiffusionPipeline):
    def __init__(self,
                 vae,
                 text_encoder,
                 tokenizer,
                 unet,
                 scheduler,
                 safety_checker,
                 feature_extractor,
                 requires_safety_checker: bool = False,
                 ):
        super(InversableStableDiffusionPipeline, self).__init__(vae,
                                                                text_encoder,
                                                                tokenizer,
                                                                unet,
                                                                scheduler,
                                                                safety_checker,
                                                                feature_extractor,
                                                                requires_safety_checker)

        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)
        self.count = 0

    def get_random_latents(self, latents=None, height=512, width=512, generator=None):
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        batch_size = 1
        device = self._execution_device

        num_channels_latents = self.unet.in_channels

        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            self.text_encoder.dtype,
            device,
            generator,
            latents,
        )

        return latents

    @torch.inference_mode()
    def get_text_embedding(self, prompt):
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        return text_embeddings

    @torch.inference_mode()
    def get_image_latents(self, image, sample=True, rng_generator=None):
        encoding_dist = self.vae.encode(image).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        latents = encoding * 0.18215
        return latents

    @torch.inference_mode()
    def backward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        reverse_process: bool = False,
        **kwargs,
    ):
        """Generate an image from text prompt and latents; with reverse_process=True the loop runs backwards (image to noise)."""
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to the correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        if old_text_embeddings is not None and new_text_embeddings is not None:
            prompt_to_prompt = True
        else:
            prompt_to_prompt = False

        for i, t in enumerate(self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
            if prompt_to_prompt:
                if i < use_old_emb_i:
                    text_embeddings = old_text_embeddings
                else:
                    text_embeddings = new_text_embeddings

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            prev_timestep = (
                t
                - self.scheduler.config.num_train_timesteps
                // self.scheduler.num_inference_steps
            )
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

            # ddim
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else self.scheduler.final_alpha_cumprod
            )
            if reverse_process:
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
            latents = backward_ddim(
                x_t=latents,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=noise_pred,
            )
        return latents

    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs):
        scaled_latents = 1 / 0.18215 * latents
        image = [
            self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))
        ]
        image = torch.cat(image, dim=0)
        return image

    @torch.inference_mode()
    def torch_to_numpy(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image
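Note (illustrative, not part of the upload): the reverse_process=True path relies on the deterministic DDIM update being exactly invertible when the same noise estimate is reused. A minimal sketch of that property on dummy tensors, with no model weights involved (assuming the dependencies above are installed so the module imports):

import torch
from inverse_stable_diffusion import backward_ddim

torch.manual_seed(0)
x_t = torch.randn(1, 4, 64, 64)   # a latent at timestep t
eps = torch.randn(1, 4, 64, 64)   # a fixed noise estimate, standing in for the UNet output
alpha_t, alpha_tm1 = 0.5, 0.7     # cumulative alphas at t and t-1

# one deterministic DDIM step from t to t-1
x_tm1 = backward_ddim(x_t, alpha_t, alpha_tm1, eps)

# the same update with the alphas swapped undoes it, as used when reverse_process=True
x_t_rec = backward_ddim(x_tm1, alpha_tm1, alpha_t, eps)

print(torch.allclose(x_t, x_t_rec, atol=1e-5))  # True

In the real forward_diffusion the noise estimate is re-predicted by the UNet at every step, so inverting a decoded image back to its initial latent is only approximate.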
modified_stable_diffusion.py
ADDED
@@ -0,0 +1,235 @@
from typing import Callable, List, Optional, Union, Any, Dict
import copy
import numpy as np
import PIL

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.utils import logging, BaseOutput
from torchvision.transforms import ToPILImage

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ModifiedStableDiffusionPipelineOutput(BaseOutput):
    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    init_latents: Optional[torch.FloatTensor]


class ModifiedStableDiffusionPipeline(StableDiffusionPipeline):
    def __init__(self,
                 vae,
                 text_encoder,
                 tokenizer,
                 unet,
                 scheduler,
                 safety_checker,
                 feature_extractor,
                 requires_safety_checker: bool = False,
                 ):
        super(ModifiedStableDiffusionPipeline, self).__init__(vae,
                                                              text_encoder,
                                                              tokenizer,
                                                              unet,
                                                              scheduler,
                                                              safety_checker,
                                                              feature_extractor,
                                                              requires_safety_checker)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        watermarking_gamma: float = None,
        watermarking_delta: float = None,
        watermarking_mask: Optional[torch.BoolTensor] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        self.count = 0

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        init_latents = copy.deepcopy(latents)

        # watermarking mask
        if watermarking_gamma is not None:
            watermarking_mask = torch.rand(latents.shape, device=device) < watermarking_gamma

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # add watermark
                if watermarking_mask is not None:
                    # latents[watermarking_mask] += watermarking_delta
                    latents[watermarking_mask] += watermarking_delta * torch.sign(latents[watermarking_mask])

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return ModifiedStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents)

    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs):
        scaled_latents = 1 / 0.18215 * latents
        image = [
            self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))
        ]
        image = torch.cat(image, dim=0)
        return image

    @torch.inference_mode()
    def torch_to_numpy(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image

    @torch.inference_mode()
    def get_image_latents(self, image, sample=True, rng_generator=None):
        encoding_dist = self.vae.encode(image).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        latents = encoding * 0.18215
        return latents
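Note (illustrative, not part of the upload): the optional watermarking_gamma / watermarking_delta arguments are not used by the Gaussian Shading demo, which injects its watermark through the initial latents instead. A toy sketch of what that path does at each denoising step, on a dummy tensor:

import torch

torch.manual_seed(0)
latents = torch.randn(1, 4, 64, 64)
gamma, delta = 0.1, 0.05  # example values, not taken from the upload

# Bernoulli(gamma) mask over latent entries, as built once before the denoising loop
mask = torch.rand(latents.shape) < gamma
# push the selected entries further from zero, in the direction of their sign
latents[mask] += delta * torch.sign(latents[mask])
print(mask.float().mean().item())  # roughly 0.1 of the entries were perturbed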
requirements.txt
ADDED
@@ -0,0 +1,14 @@
albumentations
diffusers
einops
huggingface_hub
natsort
pillow
PyYAML
regex
requests
timm
torch
torchvision
tqdm
transformers
run_gaussian_shading.py
ADDED
@@ -0,0 +1,106 @@
from tqdm import tqdm
import torch
from transformers import CLIPModel, CLIPTokenizer
from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
import os
import gradio as gr
from image_utils import *
from watermark import *


# Initialize the parameters
model_path = 'stabilityai/stable-diffusion-2-1-base'
channel_copy = 1
hw_copy = 8
fpr = 0.000001
user_number = 1000000
guidance_scale = 7.5
num_inference_steps = 50
image_length = 512


# """ ---------------------- Initialization ---------------------- """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
pipe = InversableStableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision='fp16',
)
pipe.safety_checker = None
pipe = pipe.to(device)

# a simple implementation of the watermark
watermark = Gaussian_Shading(channel_copy, hw_copy, fpr, user_number)

# assume that at detection time the original prompt is unknown
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)


# generate with watermark
def generate_with_watermark(seed, prompt, guidance_scale=7.5, num_inference_steps=50):
    set_random_seed(seed)

    init_latents_w, key, wk = watermark.create_watermark_and_return_w()
    watermark_list = []
    torch.save(key, 'key.pt')
    if not os.path.exists('watermark.pt'):
        torch.save(wk, 'watermark.pt')
    else:
        watermark_list = torch.load('watermark.pt')
        if not isinstance(watermark_list, list):
            watermark_list = [watermark_list]
        watermark_list.append(wk)
        torch.save(watermark_list, 'watermark.pt')

    outputs = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length,
        latents=init_latents_w,
    )
    image_w = outputs.images[0]

    # generate the same prompt without the watermarked latents, for comparison
    outputs_original = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length
    )
    image_original = outputs_original.images[0]

    # save the file for the download button, removing any stale copy first
    image_path = 'output_image.png'
    if os.path.exists(image_path):
        os.remove(image_path)

    image_w.save('output_image.png', format='PNG')
    return image_original, image_w, 'output_image.png'


# reverse the image back to latents and evaluate the watermark
def reverse_watermark(image, *args, **kwargs):
    image_attacked = image_distortion(image, *args, **kwargs)
    image_w_distortion = transform_img(image_attacked).unsqueeze(0).to(text_embeddings.dtype).to(device)
    image_latents_w = pipe.get_image_latents(image_w_distortion, sample=False)
    reversed_latents_w = pipe.forward_diffusion(
        latents=image_latents_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=50,
    )
    try:
        bit, accuracy = watermark.eval_watermark(reversed_latents_w)
    except FileNotFoundError:
        raise gr.Error("Database is empty. Please generate an image first!", duration=8)
    if accuracy > 0.7:
        output = 'This image has a watermark'
    else:
        output = "This image doesn't have a watermark"
    return output, bit, accuracy, image_attacked
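Note (illustrative, not part of the upload): outside Gradio, the two handlers can be called directly. This sketch assumes a CUDA GPU and that importing the module downloads stabilityai/stable-diffusion-2-1-base; the positional arguments passed to reverse_watermark follow the same order as the slider inputs wired up in app.py.

from run_gaussian_shading import generate_with_watermark, reverse_watermark

# generate: returns the unwatermarked image, the watermarked image, and the saved PNG path
image_original, image_w, png_path = generate_with_watermark(
    0, "A photo of a cat", guidance_scale=7.5, num_inference_steps=50
)

# extract: lightly blur the watermarked image, invert it, and compare against the stored watermark
verdict, bits, accuracy, attacked = reverse_watermark(
    image_w,
    0,                  # seed for the attack
    1, 0, 1, 1, 0, 0,   # crop, drop, resize ratios; blur radius; gaussian std; salt-and-pepper prob
)
print(verdict, accuracy)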
watermark.py
ADDED
@@ -0,0 +1,95 @@
import torch
from scipy.stats import norm, truncnorm
from functools import reduce
from scipy.special import betainc
import numpy as np


class Gaussian_Shading:
    def __init__(self, ch_factor, hw_factor, fpr, user_number):
        self.ch = ch_factor
        self.hw = hw_factor
        self.key = None
        self.watermark = None
        self.latentlength = 4 * 64 * 64
        self.marklength = self.latentlength // (self.ch * self.hw * self.hw)

        self.threshold = 1 if self.hw == 1 and self.ch == 1 else self.ch * self.hw * self.hw // 2
        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        self.tau_onebit = None
        self.tau_bits = None

        for i in range(self.marklength):
            fpr_onebit = betainc(i + 1, self.marklength - i, 0.5)
            fpr_bits = betainc(i + 1, self.marklength - i, 0.5) * user_number
            if fpr_onebit <= fpr and self.tau_onebit is None:
                self.tau_onebit = i / self.marklength
            if fpr_bits <= fpr and self.tau_bits is None:
                self.tau_bits = i / self.marklength

    def truncSampling(self, message):
        z = np.zeros(self.latentlength)
        denominator = 2.0
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i in range(self.latentlength):
            dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1])
            dec_mes = int(dec_mes)
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        z = torch.from_numpy(z).reshape(1, 4, 64, 64).half()
        return z.cuda()

    def create_watermark_and_return_w(self):
        rng_state = torch.get_rng_state()
        torch.manual_seed(42)
        self.key = torch.randint(0, 2, [1, 4, 64, 64]).cuda()
        torch.set_rng_state(rng_state)

        self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]).cuda()
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        m = ((sd + self.key) % 2).flatten().cpu().numpy()
        w = self.truncSampling(m)
        return w, self.key, self.watermark

    def diffusion_inverse(self, watermark_sd):
        ch_stride = 4 // self.ch
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote

    def sequence_binary_watermark(self, watermark):
        ls = watermark.view(-1).tolist()
        sequence = ''.join(str(i) for i in ls)
        return sequence

    def eval_watermark(self, reversed_m):
        key = torch.load('key.pt')
        reversed_m = (reversed_m > 0).int()
        # reversed_sd = (reversed_m + self.key) % 2
        reversed_sd = (reversed_m + key) % 2
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        print(f"The extracted watermark is {self.sequence_binary_watermark(reversed_watermark)}")

        watermark = torch.load('watermark.pt')
        ls_accurate = []
        for i in watermark:
            ls_accurate.append((reversed_watermark == i).float().mean().item())

        correct = max(ls_accurate)
        if correct >= self.tau_onebit:
            self.tp_onebit_count = self.tp_onebit_count + 1
        if correct >= self.tau_bits:
            self.tp_bits_count = self.tp_bits_count + 1
        return self.sequence_binary_watermark(reversed_watermark), correct

    def get_tpr(self):
        return self.tp_onebit_count, self.tp_bits_count
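Note (illustrative, not part of the upload): a small self-test of the watermark codec in isolation. It is a sketch only; it requires a CUDA device and writes key.pt and watermark.pt into the working directory (the same files the app uses), so run it in a scratch folder. With no diffusion, inversion, or attack in between, extraction from the freshly sampled latent should be lossless.

import torch
from watermark import Gaussian_Shading

# same factors as run_gaussian_shading.py: 1 channel copy, 8x8 spatial copies
wm = Gaussian_Shading(ch_factor=1, hw_factor=8, fpr=0.000001, user_number=1000000)

# sample a watermarked latent w of shape (1, 4, 64, 64) and persist key/watermark,
# mirroring what generate_with_watermark does
w, key, wk = wm.create_watermark_and_return_w()
torch.save(key, 'key.pt')
torch.save([wk], 'watermark.pt')

# recover the watermark directly from the sampled latent
bits, accuracy = wm.eval_watermark(w)
print(accuracy)  # expected: 1.0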