from tqdm import tqdm
import torch
from transformers import CLIPModel, CLIPTokenizer
from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
import os
import gradio as gr
from image_utils import *
from watermark import *


# Initialize the parameters
model_path = 'stabilityai/stable-diffusion-2-1-base'
channel_copy = 1
hw_copy = 8
fpr = 0.000001
user_number = 1000000
guidance_scale = 7.5
num_inference_steps = 50
image_length = 512


# """ ---------------------- Initialization ---------------------- """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
pipe = InversableStableDiffusionPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        revision='fp16',
)
pipe.safety_checker = None
pipe = pipe.to(device)

# A simple Gaussian Shading watermark instance
watermark = Gaussian_Shading(channel_copy, hw_copy, fpr, user_number)

# Assume the original prompt is unknown at detection time, so use an empty prompt
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)

# Generate a watermarked image alongside an unwatermarked reference
def generate_with_watermark(seed, prompt, guidance_scale=7.5, num_inference_steps=50):
    set_random_seed(seed)

    init_latents_w, key, wk = watermark.create_watermark_and_return_w()
    torch.save(key, 'key.pt')

    # Append the new watermark to the list stored in watermark.pt (create it if absent)
    if os.path.exists('watermark.pt'):
        watermark_list = torch.load('watermark.pt')
        if not isinstance(watermark_list, list):
            watermark_list = [watermark_list]
    else:
        watermark_list = []
    watermark_list.append(wk)
    torch.save(watermark_list, 'watermark.pt')

    outputs = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length,
        latents=init_latents_w,
    )
    image_w = outputs.images[0]
    # Generate an unwatermarked reference image from the same prompt
    outputs_original = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length
    )
    image_original = outputs_original.images[0]
    
    # Save the watermarked image to disk so it can be offered as a download,
    # removing any stale copy first
    image_path = 'output_image.png'
    if os.path.exists(image_path):
        os.remove(image_path)

    image_w.save(image_path, format='PNG')
    return image_original, image_w, image_path

# Invert the diffusion process on a (possibly distorted) image and check for the watermark
def reverse_watermark(image, *args, **kwargs):
    image_attacked = image_distortion(image, *args, **kwargs)
    image_w_distortion = transform_img(image_attacked).unsqueeze(0).to(text_embeddings.dtype).to(device)
    image_latents_w = pipe.get_image_latents(image_w_distortion, sample=False)
    reversed_latents_w = pipe.forward_diffusion(
        latents=image_latents_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=50,
    )
    try:
        bit, accuracy = watermark.eval_watermark(reversed_latents_w)
    except FileNotFoundError:
        raise gr.Error("The watermark database is empty. Please generate an image first!", duration=8)
    if accuracy > 0.7:
        output = 'This image has a watermark'
    else:
        output = "This image doesn't have a watermark"
    return output, bit, accuracy, image_attacked
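

# ---------------------------------------------------------------------------
# The original file stops here without building a UI, even though gradio is
# imported. The block below is a minimal sketch of how the two functions above
# could be wired into a Gradio app; the component choices and the assumption
# that image_distortion needs no extra distortion arguments are illustrative
# guesses, not part of the original script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown('## Gaussian Shading watermark demo')
        with gr.Tab('Generate'):
            seed_in = gr.Number(value=42, precision=0, label='Seed')
            prompt_in = gr.Textbox(label='Prompt')
            gen_btn = gr.Button('Generate')
            original_out = gr.Image(label='Original image')
            watermarked_out = gr.Image(label='Watermarked image')
            file_out = gr.File(label='Download watermarked image')
            gen_btn.click(
                generate_with_watermark,
                inputs=[seed_in, prompt_in],
                outputs=[original_out, watermarked_out, file_out],
            )
        with gr.Tab('Detect'):
            image_in = gr.Image(type='pil', label='Image to test')
            detect_btn = gr.Button('Detect watermark')
            verdict_out = gr.Textbox(label='Verdict')
            bit_out = gr.Textbox(label='Recovered bits')
            acc_out = gr.Textbox(label='Bit accuracy')
            attacked_out = gr.Image(label='Image after distortion')
            detect_btn.click(
                reverse_watermark,
                inputs=[image_in],
                outputs=[verdict_out, bit_out, acc_out, attacked_out],
            )
    demo.launch()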