from tqdm import tqdm
import torch
from transformers import CLIPModel, CLIPTokenizer
from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
import os
import gradio as gr
from image_utils import *
from watermark import *

# Initialize the parameter:
model_path = 'stabilityai/stable-diffusion-2-1-base'
# The next four values are passed straight to Gaussian_Shading(...); their exact
# meaning is defined in the `watermark` module (presumably channel / spatial
# replication factors, target false-positive rate and user capacity -- verify there).
channel_copy = 1
hw_copy = 8
fpr = 0.000001
user_number = 1000000
guidance_scale = 7.5
num_inference_steps = 50
image_length = 512
# Bit accuracy above which reverse_watermark() reports the image as watermarked.
watermark_threshold = 0.7

# Files written by generate_with_watermark().
key_path = 'key.pt'
watermark_path = 'watermark.pt'
image_path = 'output_image.png'

# """ ---------------------- Initialization ---------------------- """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
# NOTE(review): fp16 weights are requested even when device == 'cpu'; many CPU
# kernels lack Half support -- confirm whether CPU execution is actually supported.
pipe = InversableStableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision='fp16',
)
pipe.safety_checker = None
pipe = pipe.to(device)

# a simple implement watermark
watermark = Gaussian_Shading(channel_copy, hw_copy, fpr, user_number)

# assume at the detection time, the original prompt is unknown
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)


# generate with watermark
def generate_with_watermark(seed, prompt, guidance_scale=7.5, num_inference_steps=50):
    """Generate a watermarked image plus an unwatermarked reference for `prompt`.

    Side effects:
        * overwrites `key_path` with the key returned by the watermark object;
        * appends the new watermark to the list stored in `watermark_path`
          (creating the file on first use);
        * writes the watermarked image to `image_path` as PNG.

    Args:
        seed: seed passed to `set_random_seed` before the watermark is created.
        prompt: text prompt used for both generations.
        guidance_scale: classifier-free guidance scale for both generations.
        num_inference_steps: number of denoising steps for both generations.

    Returns:
        Tuple ``(image_original, image_w, image_path)``: the reference image,
        the watermarked image, and the path of the saved watermarked PNG.
    """
    set_random_seed(seed)
    init_latents_w, key, wk = watermark.create_watermark_and_return_w()

    torch.save(key, key_path)

    # The watermark database accumulates one entry per generated image. Always
    # persist a list so the on-disk format does not depend on whether this is
    # the first call (older files may still hold a single bare entry).
    watermark_list = []
    if os.path.exists(watermark_path):
        watermark_list = torch.load(watermark_path)
        if not isinstance(watermark_list, list):
            watermark_list = [watermark_list]
    watermark_list.append(wk)
    torch.save(watermark_list, watermark_path)

    # Watermarked generation: start denoising from the watermark-carrying latents.
    outputs = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length,
        latents=init_latents_w,
    )
    image_w = outputs.images[0]

    # From original: same prompt, but the pipeline samples its own initial latents.
    outputs_original = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=image_length,
        width=image_length,
    )
    image_original = outputs_original.images[0]

    # save file, download and remove
    if os.path.exists(image_path):
        os.remove(image_path)
    image_w.save(image_path, format='PNG')
    return image_original, image_w, image_path


# reverse img
def reverse_watermark(image, *args, **kwargs):
    """Apply a distortion to `image`, invert it to latents and test for the watermark.

    Args:
        image: input image handed to `image_distortion`.
        *args, **kwargs: forwarded unchanged to `image_distortion` (presumably
            the attack/distortion settings -- see image_utils).

    Returns:
        Tuple ``(output, bit, accuracy, image_attacked)``: a human-readable
        verdict, the two values returned by `watermark.eval_watermark`, and
        the distorted image that was actually analysed.

    Raises:
        gr.Error: if `watermark.eval_watermark` raises FileNotFoundError
            (no watermark database has been generated yet).
    """
    image_attacked = image_distortion(image, *args, **kwargs)
    image_w_distortion = (
        transform_img(image_attacked)
        .unsqueeze(0)  # add a batch dimension
        .to(text_embeddings.dtype)
        .to(device)
    )
    image_latents_w = pipe.get_image_latents(image_w_distortion, sample=False)

    # Invert with the empty-prompt embedding because the original prompt is
    # unknown at detection time; guidance_scale=1 presumably disables CFG here.
    reversed_latents_w = pipe.forward_diffusion(
        latents=image_latents_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=num_inference_steps,
    )

    try:
        bit, accuracy = watermark.eval_watermark(reversed_latents_w)
    except FileNotFoundError as err:
        raise gr.Error("Database is empty. Please generate Image first!", duration=8) from err

    if accuracy > watermark_threshold:
        output = 'This Image have watermark'
    else:
        output = "This Image doesn't have watermark"
    return output, bit, accuracy, image_attacked