import copy
import time

import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# Checkpoint providing the full (original) Stable Diffusion pipeline.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Checkpoint whose "unet" subfolder replaces the U-Net in the compressed pipeline.
COMPRESSED_UNET_ID = "nota-ai/bk-sdm-small"

DEVICE = 'cuda'
# DEVICE='cpu'


class SdmCompressionDemo:
    """Hold an original and a U-Net-swapped Stable Diffusion pipeline side by side.

    Both pipelines are built once in ``__init__``; the ``infer_*`` methods run
    a single prompt through one of them and return the image, a status
    message and the elapsed time.
    """

    def __init__(self, device) -> None:
        """Load both pipelines and, for CUDA devices, move them to ``device``.

        Args:
            device: Device string. Any value containing ``'cuda'`` selects
                float16 weights and triggers the move to that device;
                anything else selects float32 and leaves the pipelines
                where ``from_pretrained`` put them.
        """
        self.device = device
        on_cuda = 'cuda' in self.device
        self.torch_dtype = torch.float16 if on_cuda else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID,
            torch_dtype=self.torch_dtype,
        )

        # The compressed pipeline is a deep copy of the original one whose
        # U-Net is then replaced by the compressed checkpoint's U-Net.
        self.pipe_compressed = copy.deepcopy(self.pipe_original)
        self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_ID,
            subfolder="unet",
            torch_dtype=self.torch_dtype,
        )

        if on_cuda:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)

        self.device_msg = 'Tested on GPU.' if on_cuda else 'Tested on CPU.'

    def _count_params(self, model):
        """Return the total number of elements over ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def get_sdm_params(self, pipe):
        """Summarize parameter counts of ``pipe`` as a display string.

        The total is the sum of the U-Net, the text encoder and the VAE
        decoder (``pipe.vae.decoder`` only); values are reported in millions
        with one decimal place.
        """
        params_unet = self._count_params(pipe.unet)
        params_text_enc = self._count_params(pipe.text_encoder)
        params_image_dec = self._count_params(pipe.vae.decoder)
        params_total = params_unet + params_text_enc + params_image_dec
        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Run ``pipe`` on one prompt and time the call.

        Args:
            pipe: Pipeline to call.
            text: Prompt passed as the first positional argument.
            negative: Forwarded as ``negative_prompt``.
            guidance_scale: Forwarded as ``guidance_scale``.
            steps: Forwarded as ``num_inference_steps``.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.

        Returns:
            A tuple ``(image, nsfw_detected, elapsed)`` where ``image`` is the
            first generated image, ``nsfw_detected`` is the first NSFW flag
            (``False`` when the pipeline reports no flags) and ``elapsed`` is
            the wall-clock duration formatted with two decimals, as a string.
        """
        generator = torch.Generator(self.device).manual_seed(seed)

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start = time.perf_counter()
        result = pipe(
            text,
            negative_prompt=negative,
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
        )
        test_time = time.perf_counter() - start

        image = result.images[0]
        # The flag list is documented as optional in diffusers; treat a
        # missing/empty list as "nothing detected" instead of crashing on
        # the index.
        nsfw_flags = result.nsfw_content_detected
        nsfw_detected = nsfw_flags[0] if nsfw_flags else False

        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")

        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the status message: the device note, plus an NSFW hint if flagged."""
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        return self.device_msg

    def check_invalid_input(self, text) -> bool:
        """Return ``True`` when ``text`` is not a usable prompt.

        A prompt is rejected when it is ``None``, empty, or consists only of
        whitespace. Always returns a real ``bool``.
        """
        return not text or not text.strip()

    def _infer(self, pipe, label, text, negative, guidance_scale, steps, seed):
        """Shared implementation of the two public ``infer_*`` entry points.

        Logs the run header using ``label``, rejects invalid prompts, and
        otherwise delegates to ``generate_image``.

        Returns:
            ``(image, message, elapsed)``; for an invalid prompt this is
            ``(None, "Please enter the input prompt.", None)``.
        """
        print(f"=== {label} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed
        )
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate an image with the original pipeline.

        Returns ``(image, message, elapsed)``; see ``_infer``.
        """
        return self._infer(
            self.pipe_original, "ORIG", text, negative, guidance_scale, steps, seed
        )

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate an image with the compressed-U-Net pipeline.

        Returns ``(image, message, elapsed)``; see ``_infer``.
        """
        return self._infer(
            self.pipe_compressed, "COMPRESSED", text, negative, guidance_scale, steps, seed
        )

    def get_example_list(self):
        """Return a fresh list of example prompt strings."""
        return [
            'a tropical bird sitting on a branch of a tree',
            'many decorative umbrellas hanging up',
            'an orange cat staring off with pretty eyes',
            'beautiful woman face with fancy makeup',
            'a decorated living room with a stylish feel',
            'a black vase holding a bouquet of roses',
            'very elegant bedroom featuring natural wood',
            'buffet-style food including cake and cheese',
            'a tall castle sitting under a cloudy sky',
            'closeup of a brown bear sitting in a grassy area',
            'a large basket with many fresh vegetables',
            'house being built with lots of wood',
            'a close up of a pizza with several toppings',
            'a golden vase with many different flows',
            'a statue of a lion face attached to brick wall',
            'something that looks particularly interesting',
            'table filled with a variety of different dishes',
            'a cinematic view of a large snowy peak',
            'a grand city in the year 2100, hyper realistic',
            'a blue eyed baby girl looking at the camera',
        ]