# This file is adapted from gradio_*.py in https://github.com/lllyasviel/ControlNet/tree/f4748e3630d8141d7765e2bd9b1e348f47847707
# The original license file is LICENSE.ControlNet in this repo.
from __future__ import annotations

import pathlib
import sys

import cv2
import numpy as np
import PIL.Image
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

# Make the vendored ControlNet submodule importable before pulling in its
# annotator modules.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'ControlNet'
sys.path.append(submodule_dir.as_posix())

from annotator.canny import apply_canny
from annotator.hed import apply_hed, nms
from annotator.midas import apply_midas
from annotator.mlsd import apply_mlsd
from annotator.openpose import apply_openpose
from annotator.uniformer import apply_uniformer
from annotator.util import HWC3, resize_image
from share import *

CONTROLNET_MODEL_IDS = {
    'canny': 'lllyasviel/sd-controlnet-canny',
    'hough': 'lllyasviel/sd-controlnet-mlsd',
    'hed': 'lllyasviel/sd-controlnet-hed',
    'scribble': 'lllyasviel/sd-controlnet-scribble',
    'pose': 'lllyasviel/sd-controlnet-openpose',
    'seg': 'lllyasviel/sd-controlnet-seg',
    'depth': 'lllyasviel/sd-controlnet-depth',
    'normal': 'lllyasviel/sd-controlnet-normal',
}
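
# `Model` pairs one of the checkpoints above with a Stable Diffusion base
# model and reloads weights lazily: the pipeline is rebuilt only when the
# requested base model or task actually changes.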

class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'canny'):
        self.base_model_id = ''
        self.task_name = ''
        self.pipe = self.load_pipe(base_model_id, task_name)

    def load_pipe(self, base_model_id: str,
                  task_name: str) -> DiffusionPipeline:
        # Reuse the current pipeline when neither the base model nor the
        # task changed.
        if base_model_id == self.base_model_id and task_name == self.task_name:
            return self.pipe
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=torch.float16)
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_model_cpu_offload()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        self.pipe = self.load_pipe(base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        if task_name == self.task_name:
            return
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        # Keep the freshly loaded ControlNet offloaded to the CPU and move it
        # to the GPU only when it is executed, matching the model-offload
        # behavior of the rest of the pipeline.
        from accelerate import cpu_offload_with_hook
        cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        if not prompt:
            prompt = additional_prompt
        else:
            prompt = f'{prompt}, {additional_prompt}'
        return prompt

    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ):
        generator = torch.Generator().manual_seed(seed)
        return self.pipe(prompt=prompt,
                         negative_prompt=negative_prompt,
                         guidance_scale=guidance_scale,
                         num_images_per_prompt=num_images,
                         num_inference_steps=num_steps,
                         generator=generator,
                         image=control_image)

    def process(
        self,
        task_name: str,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        vis_control_image: PIL.Image.Image,
        num_samples: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ):
        self.load_controlnet_weight(task_name)
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        # The visualization of the control signal comes first, followed by
        # the generated samples.
        return [vis_control_image] + results.images

    def preprocess_canny(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        low_threshold: int,
        high_threshold: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        image = resize_image(HWC3(input_image), image_resolution)
        control_image = apply_canny(image, low_threshold, high_threshold)
        control_image = HWC3(control_image)
        # Invert the edge map (black edges on white) purely for display.
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_canny(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_canny(
            input_image=input_image,
            image_resolution=image_resolution,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
        )
        return self.process(
            task_name='canny',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_hough(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image = apply_mlsd(
            resize_image(input_image, detect_resolution), value_threshold,
            distance_threshold)
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_NEAREST)
        vis_control_image = 255 - cv2.dilate(
            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_hough(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_hough(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            value_threshold=value_threshold,
            distance_threshold=distance_threshold,
        )
        return self.process(
            task_name='hough',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_hed(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image = apply_hed(resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    def process_hed(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_hed(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='hed',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_scribble(
        self,
        input_image: np.ndarray,
        image_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        image = resize_image(HWC3(input_image), image_resolution)
        control_image = np.zeros_like(image, dtype=np.uint8)
        control_image[np.min(image, axis=2) < 127] = 255
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_scribble(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_scribble(
            input_image=input_image,
            image_resolution=image_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_scribble_interactive(
        self,
        input_image: dict[str, np.ndarray],
        image_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        # The interactive sketch input is a dict; the drawing itself lives
        # in the 'mask' entry.
        image = resize_image(HWC3(input_image['mask'][:, :, 0]),
                             image_resolution)
        control_image = np.zeros_like(image, dtype=np.uint8)
        control_image[np.min(image, axis=2) > 127] = 255
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_scribble_interactive(
        self,
        input_image: dict[str, np.ndarray],
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_scribble_interactive(
            input_image=input_image,
            image_resolution=image_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_fake_scribble(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image = apply_hed(resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_LINEAR)
        # Thin the soft HED edges with NMS, blur, and binarize so the result
        # resembles a hand-drawn scribble.
        control_image = nms(control_image, 127, 3.0)
        control_image = cv2.GaussianBlur(control_image, (0, 0), 3.0)
        control_image[control_image > 4] = 255
        control_image[control_image < 255] = 0
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_fake_scribble(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_fake_scribble(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_pose(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image, _ = apply_openpose(
            resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    def process_pose(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_pose(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='pose',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_seg(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image = apply_uniformer(
            resize_image(input_image, detect_resolution))
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    def process_seg(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_seg(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='seg',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_depth(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        control_image, _ = apply_midas(
            resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        # Resize the depth map to the generation resolution, as in the other
        # preprocessors; otherwise image_resolution would be ignored here.
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    def process_depth(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_depth(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='depth',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_normal(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        bg_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
        _, control_image = apply_midas(resize_image(input_image,
                                                    detect_resolution),
                                       bg_th=bg_threshold)
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    def process_normal(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        bg_threshold: float,
    ) -> list[PIL.Image.Image]:
        control_image, vis_control_image = self.preprocess_normal(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            bg_threshold=bg_threshold,
        )
        return self.process(
            task_name='normal',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
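

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original app: assumes a local
    # test image at 'sample.png' and a CUDA-capable GPU. In the Space, the
    # Gradio UI drives `Model` instead of this entry point.
    model = Model()
    image = np.array(PIL.Image.open('sample.png').convert('RGB'))
    results = model.process_canny(
        input_image=image,
        prompt='a sunny landscape',
        additional_prompt='best quality, extremely detailed',
        negative_prompt='lowres, bad anatomy, worst quality',
        num_samples=1,
        image_resolution=512,
        num_steps=20,
        guidance_scale=9.0,
        seed=0,
        low_threshold=100,
        high_threshold=200,
    )
    # The first image is the edge-map visualization; the rest are samples.
    for i, result in enumerate(results):
        result.save(f'result_{i}.png')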