TruongScotl committed
Commit f961e67 · verified · 1 Parent(s): 0c4598f

Upload 7 files

app.py ADDED
@@ -0,0 +1,205 @@
+ import numpy as np
+ import gradio as gr
+ from run_gaussian_shading import *
+ 
+ examples = [
+     "A photo of a cat",
+     "A pizza with pineapple on it",
+     "A photo of a dog",
+ ]
+ 
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 700px;
+ }
+ """
+ 
+ MAX_SEED = np.iinfo(np.int32).max
+ 
+ 
+ # ---------------------------------------------------------------------------------------------------
+ 
+ 
+ with gr.Blocks(css=css) as demo:
+ 
+     # ---------------------------------- Add Watermark -----------------------------------------
+ 
+     with gr.Tab("Add watermark"):
+         with gr.Column(elem_id="col-container"):
+             gr.Markdown(" # Text-to-Image Watermark")
+             with gr.Accordion("Instruction", open=False):
+                 gr.Markdown("""
+ # Embedding Watermark
+ ## 1. Generate watermarked image
+ * Enter your prompt in the text box.
+ * Click **Run** to generate an image with a random binary watermark.
+ 
+ ## 2. Save Image
+ Click **Download** to save the watermarked image in PNG format.
+ 
+ ## 3. Advanced Settings
+ - **Seed**: Different seeds produce different images.
+ - **Guidance Scale**: Higher values make the image follow the prompt more closely, usually at some cost to diversity.
+ - **Num Inference Steps**: More steps enhance image detail and quality but increase computational cost.
+ 
+ Source code: [Gaussian Shading](https://github.com/bsmhmmlf/Gaussian-Shading)""")
+             with gr.Row():
+                 prompt = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
+                 )
+                 run_button = gr.Button("Run", scale=0, variant="primary")
+                 download_button = gr.DownloadButton(visible=True)
+             with gr.Row():
+                 result_original = gr.Image(label="Image without watermark", show_label=True)
+                 result = gr.Image(label="Watermarked Image", show_label=True)
+ 
+             with gr.Accordion("Advanced Settings", open=False):
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         minimum=1.5,
+                         maximum=10,
+                         step=0.1,
+                         value=7.5,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Num inference steps",
+                         minimum=10,
+                         maximum=100,
+                         step=1,
+                         value=50,
+                     )
+             gr.Examples(examples=examples, inputs=[prompt])
+ 
+     # ---------------------------------- Extract Watermark -----------------------------------------
+     with gr.Tab("Extract watermark"):
+         with gr.Column(elem_id="col-container"):
+             gr.Markdown(" # Watermark Extraction")
+             with gr.Accordion("Instruction", open=False):
+                 gr.Markdown("""
+ # Extracting Watermark
+ **Note**: Generate an image in the **Add watermark** tab first so its watermark is stored in the database.
+ ## 1. Upload Image
+ - Upload the image to the Image box.
+ - Click the **Extract** button to extract the watermark.
+ ## 2. Advanced Settings
+ These settings are **optional** and simulate real-world attacks that try to erase the watermark.
+ Click the **Attack** button to generate a distorted image.
+ * **Seed**: initializes the random number generator so the attack is reproducible.
+ * **Random crop ratio**: the proportion of the image kept after a random crop; a lower ratio removes more of the image.
+ * **Random drop ratio**: the relative size of a randomly placed region that is blacked out; a higher ratio drops more of the image.
+ * **Resize ratio**: the image is downscaled by this factor and then resized back; a lower ratio degrades it more.
+ * **Gaussian blur r**: the radius of the Gaussian blur applied to the image; a larger radius gives a blurrier image.
+ * **Gaussian std**: the standard deviation of the Gaussian noise added to the image; a higher value gives stronger noise.
+ * **Sp prob**: the probability of each pixel being replaced with black or white (salt-and-pepper) noise; a higher probability adds more noise.
+ ## Output Explanation
+ - **Output**: whether a watermark was detected in the image.
+ - **Bit watermark**: the binary bits extracted from the image.
+ - **Accuracy bit**: the fraction of extracted bits that match a watermark stored in the database.
+ """)
+             with gr.Row():
+                 input_image = gr.Image(type='pil')
+                 extract_button = gr.Button("Extract", scale=0, variant="primary")
+ 
+             with gr.Accordion("Advanced Settings", open=False):
+                 with gr.Row():
+                     seed_attack = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=MAX_SEED,
+                         step=1,
+                         value=0,
+                     )
+                     attack_button = gr.Button("Attack!", scale=0, variant="primary")
+                 with gr.Row():
+                     random_crop_ratio = gr.Slider(
+                         label="Random crop ratio",
+                         minimum=0.5,
+                         maximum=1,
+                         step=0.1,
+                         value=1,
+                     )
+                     random_drop_ratio = gr.Slider(
+                         label="Random drop ratio",
+                         minimum=0,
+                         maximum=1,
+                         step=0.1,
+                         value=0,
+                     )
+                 with gr.Row():
+                     resize_ratio = gr.Slider(
+                         label="Resize ratio",
+                         minimum=0.2,
+                         maximum=1,
+                         step=0.1,
+                         value=1,
+                     )
+                     gaussian_blur_r = gr.Slider(
+                         label="Gaussian blur r",
+                         minimum=0,
+                         maximum=1,
+                         step=0.1,
+                         value=0,
+                     )
+                 with gr.Row():
+                     gaussian_std = gr.Slider(
+                         label="Gaussian std",
+                         minimum=0,
+                         maximum=0.01,
+                         step=0.0001,
+                         value=0,
+                     )
+                     sp_prob = gr.Slider(
+                         label="Sp prob",
+                         minimum=0,
+                         maximum=0.1,
+                         step=0.001,
+                         value=0,
+                     )
+             attack_image = gr.Image(label="Attacked Image")
+             output = gr.Textbox(label="Output")
+             with gr.Accordion("More Details", open=False):
+                 result_extract = gr.Textbox(label="Bit watermark")
+                 accuracy_bit = gr.Textbox(label="Accuracy bit")
+ 
+     # ----------------------------- Embedding watermark -------------------------
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=generate_with_watermark,
+         inputs=[
+             seed,
+             prompt,
+             guidance_scale,
+             num_inference_steps
+         ],
+         outputs=[result_original, result, download_button],
+     )
+ 
+     # ----------------------------- Extract watermark -------------------------
+     gr.on(
+         triggers=[extract_button.click, attack_button.click],
+         fn=reverse_watermark,
+         inputs=[
+             input_image,
+             seed_attack,
+             random_crop_ratio,
+             random_drop_ratio,
+             resize_ratio,
+             gaussian_blur_r,
+             gaussian_std,
+             sp_prob,
+         ],
+         outputs=[output, result_extract, accuracy_bit, attack_image],
+     )
+ demo.launch(share=True)
image_utils.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ from PIL import Image, ImageFilter
+ import random
+ 
+ 
+ def set_random_seed(seed=0):
+     torch.manual_seed(seed + 0)
+     torch.cuda.manual_seed(seed + 1)
+     torch.cuda.manual_seed_all(seed + 2)
+     np.random.seed(seed + 3)
+     torch.cuda.manual_seed_all(seed + 4)
+     random.seed(seed + 5)
+ 
+ 
+ def transform_img(image, target_size=512):
+     tform = transforms.Compose(
+         [
+             transforms.Resize(target_size),
+             transforms.CenterCrop(target_size),
+             transforms.ToTensor(),
+         ]
+     )
+     image = tform(image)
+     return 2.0 * image - 1.0
+ 
+ 
+ def latents_to_imgs(pipe, latents):
+     x = pipe.decode_image(latents)
+     x = pipe.torch_to_numpy(x)
+     x = pipe.numpy_to_pil(x)
+     return x
+ 
+ 
+ def image_distortion(img,
+                      seed: int = 42,
+                      random_crop_ratio: float = None,
+                      random_drop_ratio: float = None,
+                      resize_ratio: float = None,
+                      gaussian_blur_r: int = None,
+                      gaussian_std: float = None,
+                      sp_prob: float = None):
+ 
+     if random_crop_ratio is not None:
+         # keep a randomly placed crop of the image and zero-pad everything else
+         set_random_seed(seed)
+         height, width, c = np.array(img).shape
+         img = np.array(img)
+         new_width = int(width * random_crop_ratio)
+         new_height = int(height * random_crop_ratio)
+         start_x = np.random.randint(0, width - new_width + 1)
+         start_y = np.random.randint(0, height - new_height + 1)
+         end_x = start_x + new_width
+         end_y = start_y + new_height
+         padded_image = np.zeros_like(img)
+         padded_image[start_y:end_y, start_x:end_x] = img[start_y:end_y, start_x:end_x]
+         img = Image.fromarray(padded_image)
+ 
+     if random_drop_ratio is not None:
+         # black out a randomly placed rectangular region of the image
+         set_random_seed(seed)
+         height, width, c = np.array(img).shape
+         img = np.array(img)
+         new_width = int(width * random_drop_ratio)
+         new_height = int(height * random_drop_ratio)
+         start_x = np.random.randint(0, width - new_width + 1)
+         start_y = np.random.randint(0, height - new_height + 1)
+         padded_image = np.zeros_like(img[start_y:start_y + new_height, start_x:start_x + new_width])
+         img[start_y:start_y + new_height, start_x:start_x + new_width] = padded_image
+         img = Image.fromarray(img)
+ 
+     if resize_ratio is not None:
+         # downscale and then upscale back to the original size
+         img_shape = np.array(img).shape
+         resize_size = int(img_shape[0] * resize_ratio)
+         img = transforms.Resize(size=resize_size)(img)
+         img = transforms.Resize(size=img_shape[0])(img)
+ 
+     if gaussian_blur_r is not None:
+         img = img.filter(ImageFilter.GaussianBlur(radius=gaussian_blur_r))
+ 
+     if gaussian_std is not None:
+         # additive Gaussian noise
+         img_shape = np.array(img).shape
+         g_noise = np.random.normal(0, gaussian_std, img_shape) * 255
+         g_noise = g_noise.astype(np.uint8)
+         img = Image.fromarray(np.clip(np.array(img) + g_noise, 0, 255))
+ 
+     if sp_prob is not None:
+         # salt-and-pepper noise
+         h, w, c = np.array(img).shape
+         prob_zero = sp_prob / 2
+         prob_one = 1 - prob_zero
+         rdn = np.random.rand(h, w, c)
+         img = np.where(rdn > prob_one, np.zeros_like(img), img)
+         img = np.where(rdn < prob_zero, np.ones_like(img) * 255, img)
+         img = Image.fromarray(img)
+ 
+     return img
+ 
+ 
+ def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
+     with torch.no_grad():
+         img_batch = [clip_preprocess(i).unsqueeze(0) for i in images]
+         img_batch = torch.concatenate(img_batch).to(device)
+         image_features = model.encode_image(img_batch)
+ 
+         text = tokenizer([prompt]).to(device)
+         text_features = model.encode_text(text)
+ 
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+ 
+         return (image_features @ text_features.T).mean(-1)
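
For reference, a minimal sketch of calling image_distortion outside the app. The file names are only illustrations; any RGB PIL image works, and distortions left as None are skipped.

from PIL import Image
from image_utils import image_distortion

img = Image.open("watermarked.png").convert("RGB")
# keep 70% of the image after a random crop, then blur with radius 4
attacked = image_distortion(img, seed=0, random_crop_ratio=0.7, gaussian_blur_r=4)
attacked.save("attacked.png")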
inverse_stable_diffusion.py ADDED
@@ -0,0 +1,205 @@
+ from functools import partial
+ from typing import Callable, List, Optional, Union, Tuple
+ 
+ import torch
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+ 
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ # from diffusers import StableDiffusionPipeline
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
+     StableDiffusionSafetyChecker
+ from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
+ 
+ from modified_stable_diffusion import ModifiedStableDiffusionPipeline
+ from torchvision.transforms import ToPILImage
+ 
+ 
+ ### credit to: https://github.com/cccntu/efficient-prompt-to-prompt
+ def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
+     """ from noise to image """
+     return (
+         alpha_tm1**0.5
+         * (
+             (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
+             + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
+         )
+         + x_t
+     )
+ 
+ 
+ def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt):
+     """ from image to noise, it's the same as backward_ddim """
+     return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt)
+ 
+ 
+ class InversableStableDiffusionPipeline(ModifiedStableDiffusionPipeline):
+     def __init__(self,
+                  vae,
+                  text_encoder,
+                  tokenizer,
+                  unet,
+                  scheduler,
+                  safety_checker,
+                  feature_extractor,
+                  requires_safety_checker: bool = False,
+                  ):
+ 
+         super(InversableStableDiffusionPipeline, self).__init__(vae,
+                                                                 text_encoder,
+                                                                 tokenizer,
+                                                                 unet,
+                                                                 scheduler,
+                                                                 safety_checker,
+                                                                 feature_extractor,
+                                                                 requires_safety_checker)
+ 
+         self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)
+         self.count = 0
+ 
+     def get_random_latents(self, latents=None, height=512, width=512, generator=None):
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+ 
+         batch_size = 1
+         device = self._execution_device
+ 
+         num_channels_latents = self.unet.in_channels
+ 
+         latents = self.prepare_latents(
+             batch_size,
+             num_channels_latents,
+             height,
+             width,
+             self.text_encoder.dtype,
+             device,
+             generator,
+             latents,
+         )
+ 
+         return latents
+ 
+     @torch.inference_mode()
+     def get_text_embedding(self, prompt):
+         text_input_ids = self.tokenizer(
+             prompt,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         ).input_ids
+         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+         return text_embeddings
+ 
+     @torch.inference_mode()
+     def get_image_latents(self, image, sample=True, rng_generator=None):
+         encoding_dist = self.vae.encode(image).latent_dist
+         if sample:
+             encoding = encoding_dist.sample(generator=rng_generator)
+         else:
+             encoding = encoding_dist.mode()
+         latents = encoding * 0.18215
+         return latents
+ 
+     @torch.inference_mode()
+     def backward_diffusion(
+         self,
+         use_old_emb_i=25,
+         text_embeddings=None,
+         old_text_embeddings=None,
+         new_text_embeddings=None,
+         latents: Optional[torch.FloatTensor] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: Optional[int] = 1,
+         reverse_process: bool = False,
+         **kwargs,
+     ):
+         """ Generate image latents from text embeddings and initial latents;
+         with `reverse_process=True` the same loop inverts image latents back to noise.
+         """
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # set timesteps
+         self.scheduler.set_timesteps(num_inference_steps)
+         # Some schedulers like PNDM have timesteps as arrays
+         # It's more optimized to move all timesteps to correct device beforehand
+         timesteps_tensor = self.scheduler.timesteps.to(self.device)
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+ 
+         if old_text_embeddings is not None and new_text_embeddings is not None:
+             prompt_to_prompt = True
+         else:
+             prompt_to_prompt = False
+ 
+         for i, t in enumerate(self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
+             if prompt_to_prompt:
+                 if i < use_old_emb_i:
+                     text_embeddings = old_text_embeddings
+                 else:
+                     text_embeddings = new_text_embeddings
+ 
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = (
+                 torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             )
+             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ 
+             # predict the noise residual
+             noise_pred = self.unet(
+                 latent_model_input, t, encoder_hidden_states=text_embeddings
+             ).sample
+ 
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (
+                     noise_pred_text - noise_pred_uncond
+                 )
+ 
+             prev_timestep = (
+                 t
+                 - self.scheduler.config.num_train_timesteps
+                 // self.scheduler.num_inference_steps
+             )
+             # call the callback, if provided
+             if callback is not None and i % callback_steps == 0:
+                 callback(i, t, latents)
+ 
+             # ddim
+             alpha_prod_t = self.scheduler.alphas_cumprod[t]
+             alpha_prod_t_prev = (
+                 self.scheduler.alphas_cumprod[prev_timestep]
+                 if prev_timestep >= 0
+                 else self.scheduler.final_alpha_cumprod
+             )
+             if reverse_process:
+                 alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
+             latents = backward_ddim(
+                 x_t=latents,
+                 alpha_t=alpha_prod_t,
+                 alpha_tm1=alpha_prod_t_prev,
+                 eps_xt=noise_pred,
+             )
+         return latents
+ 
+     @torch.inference_mode()
+     def decode_image(self, latents: torch.FloatTensor, **kwargs):
+         scaled_latents = 1 / 0.18215 * latents
+         image = [
+             self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))
+         ]
+         image = torch.cat(image, dim=0)
+         return image
+ 
+     @torch.inference_mode()
+     def torch_to_numpy(self, image):
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         return image
modified_stable_diffusion.py ADDED
@@ -0,0 +1,235 @@
+ from typing import Callable, List, Optional, Union, Any, Dict
+ import copy
+ import numpy as np
+ import PIL
+ 
+ import torch
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
+ from diffusers.utils import logging, BaseOutput
+ from torchvision.transforms import ToPILImage
+ 
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+ 
+ 
+ class ModifiedStableDiffusionPipelineOutput(BaseOutput):
+     images: Union[List[PIL.Image.Image], np.ndarray]
+     nsfw_content_detected: Optional[List[bool]]
+     init_latents: Optional[torch.FloatTensor]
+ 
+ 
+ class ModifiedStableDiffusionPipeline(StableDiffusionPipeline):
+     def __init__(self,
+                  vae,
+                  text_encoder,
+                  tokenizer,
+                  unet,
+                  scheduler,
+                  safety_checker,
+                  feature_extractor,
+                  requires_safety_checker: bool = False,
+                  ):
+         super(ModifiedStableDiffusionPipeline, self).__init__(vae,
+                                                               text_encoder,
+                                                               tokenizer,
+                                                               unet,
+                                                               scheduler,
+                                                               safety_checker,
+                                                               feature_extractor,
+                                                               requires_safety_checker)
+ 
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: Optional[int] = 1,
+         watermarking_gamma: float = None,
+         watermarking_delta: float = None,
+         watermarking_mask: Optional[torch.BoolTensor] = None,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+ 
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+ 
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+         self.count = 0
+ 
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(prompt, height, width, callback_steps)
+ 
+         # 2. Define call parameters
+         batch_size = 1 if isinstance(prompt, str) else len(prompt)
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+ 
+         # 3. Encode input prompt
+         text_embeddings = self._encode_prompt(
+             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+         )
+ 
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+ 
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             text_embeddings.dtype,
+             device,
+             generator,
+             latents,
+         )
+ 
+         init_latents = copy.deepcopy(latents)
+ 
+         # watermarking mask
+         if watermarking_gamma is not None:
+             watermarking_mask = torch.rand(latents.shape, device=device) < watermarking_gamma
+ 
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ 
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # add watermark
+                 if watermarking_mask is not None:
+                     # latents[watermarking_mask] += watermarking_delta
+                     latents[watermarking_mask] += watermarking_delta * torch.sign(latents[watermarking_mask])
+ 
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ 
+                 # predict the noise residual
+                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+ 
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ 
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ 
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+ 
+         # 8. Post-processing
+         image = self.decode_latents(latents)
+         # 9. Run safety checker
+         image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ 
+         # 10. Convert to PIL
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+ 
+         if not return_dict:
+             return (image, has_nsfw_concept)
+ 
+         return ModifiedStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents)
+ 
+     @torch.inference_mode()
+     def decode_image(self, latents: torch.FloatTensor, **kwargs):
+         scaled_latents = 1 / 0.18215 * latents
+         image = [
+             self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))
+         ]
+         image = torch.cat(image, dim=0)
+         return image
+ 
+     @torch.inference_mode()
+     def torch_to_numpy(self, image):
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         return image
+ 
+     @torch.inference_mode()
+     def get_image_latents(self, image, sample=True, rng_generator=None):
+         encoding_dist = self.vae.encode(image).latent_dist
+         if sample:
+             encoding = encoding_dist.sample(generator=rng_generator)
+         else:
+             encoding = encoding_dist.mode()
+         latents = encoding * 0.18215
+         return latents
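
The extra watermarking_gamma / watermarking_delta arguments are not used by the Gaussian Shading flow in this repo, but they can be exercised directly. A hedged sketch, assuming `pipe` is the pipeline loaded as in run_gaussian_shading.py and that the output file name is only an illustration:

# perturb a random ~10% of latent entries by a sign-aligned delta at every denoising step
out = pipe(
    "A photo of a cat",
    num_inference_steps=50,
    guidance_scale=7.5,
    watermarking_gamma=0.1,
    watermarking_delta=1.0,
)
out.images[0].save("watermarked_delta.png")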
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ albumentations
+ diffusers
+ einops
+ huggingface_hub
+ natsort
+ pillow
+ PyYAML
+ regex
+ requests
+ scipy
+ timm
+ torch
+ torchvision
+ tqdm
+ transformers
run_gaussian_shading.py ADDED
@@ -0,0 +1,106 @@
+ from tqdm import tqdm
+ import torch
+ from transformers import CLIPModel, CLIPTokenizer
+ from inverse_stable_diffusion import InversableStableDiffusionPipeline
+ from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
+ import os
+ import gradio as gr
+ from image_utils import *
+ from watermark import *
+ 
+ 
+ # Initialize the parameters:
+ model_path = 'stabilityai/stable-diffusion-2-1-base'
+ channel_copy = 1
+ hw_copy = 8
+ fpr = 0.000001
+ user_number = 1000000
+ guidance_scale = 7.5
+ num_inference_steps = 50
+ image_length = 512
+ 
+ 
+ # ---------------------- Initialization ----------------------
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
+ pipe = InversableStableDiffusionPipeline.from_pretrained(
+     model_path,
+     scheduler=scheduler,
+     torch_dtype=torch.float16,
+     revision='fp16',
+ )
+ pipe.safety_checker = None
+ pipe = pipe.to(device)
+ 
+ # a simple watermark implementation
+ watermark = Gaussian_Shading(channel_copy, hw_copy, fpr, user_number)
+ 
+ # assume that at detection time the original prompt is unknown
+ tester_prompt = ''
+ text_embeddings = pipe.get_text_embedding(tester_prompt)
+ 
+ 
+ # generate with watermark
+ def generate_with_watermark(seed, prompt, guidance_scale=7.5, num_inference_steps=50):
+     set_random_seed(seed)
+ 
+     init_latents_w, key, wk = watermark.create_watermark_and_return_w()
+     watermark_list = []
+     torch.save(key, 'key.pt')
+     if not os.path.exists('watermark.pt'):
+         torch.save(wk, 'watermark.pt')
+     else:
+         watermark_list = torch.load('watermark.pt')
+         if not isinstance(watermark_list, list):
+             watermark_list = [watermark_list]
+         watermark_list.append(wk)
+         torch.save(watermark_list, 'watermark.pt')
+ 
+     outputs = pipe(
+         prompt,
+         num_images_per_prompt=1,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         height=image_length,
+         width=image_length,
+         latents=init_latents_w,
+     )
+     image_w = outputs.images[0]
+     # generate the same prompt without the watermarked latents for comparison
+     outputs_original = pipe(
+         prompt,
+         num_images_per_prompt=1,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         height=image_length,
+         width=image_length
+     )
+     image_original = outputs_original.images[0]
+ 
+     # save the file for the download button, overwriting any previous output
+     image_path = 'output_image.png'
+     if os.path.exists(image_path):
+         os.remove(image_path)
+ 
+     image_w.save(image_path, format='PNG')
+     return image_original, image_w, image_path
+ 
+ 
+ # invert the diffusion and check for a watermark
+ def reverse_watermark(image, *args, **kwargs):
+     image_attacked = image_distortion(image, *args, **kwargs)
+     image_w_distortion = transform_img(image_attacked).unsqueeze(0).to(text_embeddings.dtype).to(device)
+     image_latents_w = pipe.get_image_latents(image_w_distortion, sample=False)
+     reversed_latents_w = pipe.forward_diffusion(
+         latents=image_latents_w,
+         text_embeddings=text_embeddings,
+         guidance_scale=1,
+         num_inference_steps=50,
+     )
+     try:
+         bit, accuracy = watermark.eval_watermark(reversed_latents_w)
+     except FileNotFoundError:
+         raise gr.Error("Database is empty. Please generate an image first!", duration=8)
+     if accuracy > 0.7:
+         output = 'This image has a watermark'
+     else:
+         output = "This image doesn't have a watermark"
+     return output, bit, accuracy, image_attacked
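
A sketch of calling these helpers directly, outside the Gradio UI. It assumes the Stable Diffusion weights above can be downloaded and a GPU is available; the prompt is only an example, and passing None for every distortion tests extraction on the clean image.

from run_gaussian_shading import generate_with_watermark, reverse_watermark

original, watermarked, png_path = generate_with_watermark(
    seed=0,
    prompt="A photo of a cat",
    guidance_scale=7.5,
    num_inference_steps=50,
)
verdict, bits, accuracy, attacked = reverse_watermark(
    watermarked, 0, None, None, None, None, None, None
)
print(verdict, accuracy)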
watermark.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ from scipy.stats import norm, truncnorm
+ from functools import reduce
+ from scipy.special import betainc
+ import numpy as np
+ 
+ 
+ class Gaussian_Shading:
+     def __init__(self, ch_factor, hw_factor, fpr, user_number):
+         self.ch = ch_factor
+         self.hw = hw_factor
+         self.key = None
+         self.watermark = None
+         self.latentlength = 4 * 64 * 64
+         self.marklength = self.latentlength // (self.ch * self.hw * self.hw)
+ 
+         self.threshold = 1 if self.hw == 1 and self.ch == 1 else self.ch * self.hw * self.hw // 2
+         self.tp_onebit_count = 0
+         self.tp_bits_count = 0
+         self.tau_onebit = None
+         self.tau_bits = None
+ 
+         for i in range(self.marklength):
+             fpr_onebit = betainc(i + 1, self.marklength - i, 0.5)
+             fpr_bits = betainc(i + 1, self.marklength - i, 0.5) * user_number
+             if fpr_onebit <= fpr and self.tau_onebit is None:
+                 self.tau_onebit = i / self.marklength
+             if fpr_bits <= fpr and self.tau_bits is None:
+                 self.tau_bits = i / self.marklength
+ 
+     def truncSampling(self, message):
+         z = np.zeros(self.latentlength)
+         denominator = 2.0
+         ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
+         for i in range(self.latentlength):
+             dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1])
+             dec_mes = int(dec_mes)
+             z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
+         z = torch.from_numpy(z).reshape(1, 4, 64, 64).half()
+         return z.cuda()
+ 
+     def create_watermark_and_return_w(self):
+         rng_state = torch.get_rng_state()
+         torch.manual_seed(42)
+         self.key = torch.randint(0, 2, [1, 4, 64, 64]).cuda()
+         torch.set_rng_state(rng_state)
+ 
+         self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]).cuda()
+         sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
+         m = ((sd + self.key) % 2).flatten().cpu().numpy()
+         w = self.truncSampling(m)
+         return w, self.key, self.watermark
+ 
+     def diffusion_inverse(self, watermark_sd):
+         ch_stride = 4 // self.ch
+         hw_stride = 64 // self.hw
+         ch_list = [ch_stride] * self.ch
+         hw_list = [hw_stride] * self.hw
+         split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
+         split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
+         split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
+         vote = torch.sum(split_dim3, dim=0).clone()
+         vote[vote <= self.threshold] = 0
+         vote[vote > self.threshold] = 1
+         return vote
+ 
+     def sequence_binary_watermark(self, watermark):
+         ls = watermark.view(-1).tolist()
+         sequence = ''.join(str(i) for i in ls)
+         return sequence
+ 
+     def eval_watermark(self, reversed_m):
+         key = torch.load('key.pt')
+         reversed_m = (reversed_m > 0).int()
+         # reversed_sd = (reversed_m + self.key) % 2
+         reversed_sd = (reversed_m + key) % 2
+         reversed_watermark = self.diffusion_inverse(reversed_sd)
+         print(f"The extracted watermark is {self.sequence_binary_watermark(reversed_watermark)}")
+ 
+         watermark = torch.load('watermark.pt')
+         ls_accurate = []
+         for i in watermark:
+             ls_accurate.append((reversed_watermark == i).float().mean().item())
+ 
+         correct = max(ls_accurate)
+         if correct >= self.tau_onebit:
+             self.tp_onebit_count = self.tp_onebit_count + 1
+         if correct >= self.tau_bits:
+             self.tp_bits_count = self.tp_bits_count + 1
+         return self.sequence_binary_watermark(reversed_watermark), correct
+ 
+     def get_tpr(self):
+         return self.tp_onebit_count, self.tp_bits_count
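
A quick round-trip sanity check of the class, as a sketch; it assumes a CUDA device, mirroring the hard-coded .cuda() calls above, and skips the diffusion and attack steps entirely.

from watermark import Gaussian_Shading

wm = Gaussian_Shading(ch_factor=1, hw_factor=8, fpr=1e-6, user_number=1000000)
w, key, bits = wm.create_watermark_and_return_w()

# with no diffusion or attack in between, thresholding the latents at zero
# recovers the embedded message exactly
recovered_sd = ((w > 0).int() + key) % 2
recovered_bits = wm.diffusion_inverse(recovered_sd)
print((recovered_bits == bits).float().mean().item())  # expected: 1.0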