Gaussian-Shading-watermark / run_gaussian_shading.py
TruongScotl's picture
Upload 7 files
f961e67 verified
from tqdm import tqdm
import torch
from transformers import CLIPModel, CLIPTokenizer
from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
import os
import gradio as gr
from image_utils import *
from watermark import *
# Initialize the parameters.

# Checkpoint identifier handed to the scheduler / pipeline `from_pretrained` calls.
model_path = 'stabilityai/stable-diffusion-2-1-base'

# Arguments handed to Gaussian_Shading(...) when the watermark object is built.
channel_copy = 1
hw_copy = 8
fpr = 1e-6  # presumably the target false-positive rate -- TODO confirm in watermark.py
user_number = 1_000_000

# Sampling settings; the first two mirror the defaults of generate_with_watermark.
guidance_scale = 7.5
num_inference_steps = 50
image_length = 512  # used as both height and width of the generated images
# """ ---------------------- Initialization ---------------------- """
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the scheduler configuration stored with the checkpoint, then build the
# invertible pipeline around it from the 'fp16' revision of the weights.
# NOTE(review): float16 is requested regardless of `device` -- confirm that the
# CPU fallback path actually works with half-precision weights.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_path,
    subfolder='scheduler',
)
pipe = InversableStableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision='fp16',
)
pipe.safety_checker = None  # no safety checker is attached to this pipeline
pipe = pipe.to(device)

# A simple watermark implementation, configured from the module parameters.
watermark = Gaussian_Shading(channel_copy, hw_copy, fpr, user_number)

# Assume that the original prompt is unknown at detection time, so the text
# embedding used for inversion is computed once from an empty prompt.
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)
#generate with watermark
def generate_with_watermark(seed, prompt, guidance_scale=7.5, num_inference_steps=50):
    """Generate a watermarked image and a plain image for the same prompt.

    Side effects (all relative to the current working directory):
      * ``key.pt`` is overwritten with the key of the watermark just created.
      * ``watermark.pt`` is created holding the new watermark on the first
        call; on later calls the stored value is loaded, wrapped in a list if
        it is not one already, extended with the new watermark and saved back.
      * ``output_image.png`` is replaced by the watermarked image.

    Args:
        seed: Value handed to ``set_random_seed`` before anything is sampled.
        prompt: Text prompt given to the pipeline for both images.
        guidance_scale: Guidance scale forwarded to both pipeline calls.
        num_inference_steps: Number of sampling steps for both pipeline calls.

    Returns:
        A tuple ``(image_original, image_w, path)``: the image generated
        without the watermarked latents, the image generated from them, and
        the path the watermarked image was saved to.
    """
    set_random_seed(seed)
    init_latents_w, key, wk = watermark.create_watermark_and_return_w()

    # Only the most recent key is kept; the file is rewritten on every call.
    torch.save(key, 'key.pt')

    store_path = 'watermark.pt'
    if os.path.exists(store_path):
        stored = torch.load(store_path)
        # The very first call stores a bare watermark rather than a list.
        history = stored if isinstance(stored, list) else [stored]
        history.append(wk)
        torch.save(history, store_path)
    else:
        torch.save(wk, store_path)

    # Both generations share every sampling option except the initial latents.
    sampling_options = dict(
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length,
    )

    # Watermarked image: sampling starts from the watermark-carrying latents.
    image_w = pipe(prompt, latents=init_latents_w, **sampling_options).images[0]

    # Reference image: the pipeline chooses its own initial latents.
    image_original = pipe(prompt, **sampling_options).images[0]

    # Drop any previous output file, then write the new watermarked image.
    image_path = 'output_image.png'
    if os.path.exists(image_path):
        os.remove(image_path)
    image_w.save(image_path, format='PNG')

    return image_original, image_w, image_path
# reverse img
def reverse_watermark(image, *args, **kwargs):
    """Invert an image back to latents and check it for the watermark.

    The image is first passed through ``image_distortion`` (with any extra
    positional/keyword arguments forwarded unchanged), converted to a tensor
    in the dtype of the module-level ``text_embeddings`` on ``device``,
    encoded to latents and run through the pipeline's forward diffusion using
    the empty-prompt embedding. The resulting latents are handed to
    ``watermark.eval_watermark``.

    Args:
        image: The image to inspect. ``None`` (no image supplied) is rejected.
        *args: Extra positional arguments forwarded to ``image_distortion``.
        **kwargs: Extra keyword arguments forwarded to ``image_distortion``.

    Returns:
        A tuple ``(output, bit, accuracy, image_attacked)``: a human-readable
        verdict, the two values returned by ``watermark.eval_watermark``, and
        the image after distortion.

    Raises:
        gr.Error: If no image was supplied, or if ``watermark.eval_watermark``
            raises ``FileNotFoundError`` (nothing has been generated yet).
    """
    # Fail early with a readable message instead of letting the helpers below
    # crash on a missing image.
    if image is None:
        raise gr.Error("No image provided. Please upload an image first!", duration=8)

    image_attacked = image_distortion(image, *args, **kwargs)

    # Match the dtype of the text embeddings so the pipeline sees consistent inputs.
    image_w_distortion = (
        transform_img(image_attacked)
        .unsqueeze(0)
        .to(text_embeddings.dtype)
        .to(device)
    )
    image_latents_w = pipe.get_image_latents(image_w_distortion, sample=False)

    # Inversion uses the empty-prompt embedding with guidance_scale=1 and a
    # fixed 50 steps.
    reversed_latents_w = pipe.forward_diffusion(
        latents=image_latents_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=50,
    )

    try:
        bit, accuracy = watermark.eval_watermark(reversed_latents_w)
    except FileNotFoundError as err:
        # Keep the original cause attached for server-side debugging.
        raise gr.Error("Database is empty. Please generate Image first!", duration=8) from err

    # An accuracy above 0.7 is reported as "watermark present".
    if accuracy > 0.7:
        output = 'This Image have watermark'
    else:
        output = "This Image doesn't have watermark"
    return output, bit, accuracy, image_attacked