# filename: ip_adapter_multi_mode.py
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,  # deprecated in newer diffusers releases
    DDIMScheduler,
    AutoencoderKL,
)
from PIL import Image

from ip_adapter import IPAdapter


class IPAdapterRunner:
    """Run IP-Adapter image-prompted generation in text2img, img2img, or inpaint mode.

    Only one Stable Diffusion pipeline is kept on the GPU at a time; switching
    modes releases the previous pipeline before loading the next one.
    """

    def __init__(
        self,
        base_model_path="runwayml/stable-diffusion-v1-5",
        vae_model_path="stabilityai/sd-vae-ft-mse",
        image_encoder_path="models/image_encoder/",
        ip_ckpt="models/ip-adapter_sd15.bin",
        device="cuda",
    ):
        self.base_model_path = base_model_path
        self.vae_model_path = vae_model_path
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.device = device

        # The VAE and scheduler are shared by all three pipeline modes.
        self.vae = self._load_vae()
        self.scheduler = self._create_scheduler()
        self.pipe = None
        self.ip_model = None

    def _create_scheduler(self):
        # DDIM configured with the standard SD 1.5 noise schedule.
        return DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

    def _load_vae(self):
        # The ft-MSE VAE produces cleaner reconstructions than the SD 1.5 default.
        return AutoencoderKL.from_pretrained(self.vae_model_path).to(dtype=torch.float16)

    def _clear_previous_pipe(self):
        # Drop references to the old pipeline and IP-Adapter wrapper, then
        # release cached GPU memory so the next pipeline fits.
        if self.pipe is not None:
            self.pipe = None
            self.ip_model = None
            torch.cuda.empty_cache()

    def _load_pipeline(self, mode):
        self._clear_previous_pipe()
        pipeline_classes = {
            "text2img": StableDiffusionPipeline,
            "img2img": StableDiffusionImg2ImgPipeline,
            "inpaint": StableDiffusionInpaintPipelineLegacy,
        }
        if mode not in pipeline_classes:
            raise ValueError(f"Unsupported mode: {mode}")
        self.pipe = pipeline_classes[mode].from_pretrained(
            self.base_model_path,
            torch_dtype=torch.float16,
            scheduler=self.scheduler,
            vae=self.vae,
            feature_extractor=None,
            safety_checker=None,
        )
        # IPAdapter injects the image-prompt cross-attention layers and moves
        # the pipeline to the target device.
        self.ip_model = IPAdapter(self.pipe, self.image_encoder_path, self.ip_ckpt, self.device)

    def generate_text2img(self, pil_image, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("text2img")
        # The CLIP image encoder works on a small square input; 256x256 matches
        # the reference IP-Adapter demos.
        pil_image = pil_image.resize((256, 256))
        return self.ip_model.generate(
            pil_image=pil_image,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_img2img(self, pil_image, reference_image, strength=0.6,
                         num_samples=4, num_inference_steps=50, seed=42):
        # `pil_image` is the image prompt; `reference_image` is the init image,
        # with `strength` controlling how far generation drifts from it.
        self._load_pipeline("img2img")
        return self.ip_model.generate(
            pil_image=pil_image,
            image=reference_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_inpaint(self, pil_image, image, mask_image, strength=0.7,
                         num_samples=4, num_inference_steps=50, seed=42):
        # White regions of `mask_image` are repainted; black regions are preserved.
        self._load_pipeline("inpaint")
        return self.ip_model.generate(
            pil_image=pil_image,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    @staticmethod
    def image_grid(imgs, rows, cols):
        # Paste rows * cols equally sized images into a single contact sheet.
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new("RGB", size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid
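
# --- Usage sketch ---
# A minimal example of driving the runner across the three modes, assuming the
# default weight paths above are populated. The input image paths
# ("prompt.png", "init.png", "mask.png") are hypothetical placeholders for
# your own files.
if __name__ == "__main__":
    runner = IPAdapterRunner()

    # Image-prompted text2img: generate variations of the prompt image.
    prompt_image = Image.open("prompt.png")  # hypothetical path
    images = runner.generate_text2img(prompt_image, num_samples=4)
    IPAdapterRunner.image_grid(images, rows=1, cols=4).save("text2img_grid.png")

    # img2img: steer an init image toward the prompt image.
    init_image = Image.open("init.png")  # hypothetical path
    images = runner.generate_img2img(prompt_image, init_image, strength=0.6)
    IPAdapterRunner.image_grid(images, rows=1, cols=4).save("img2img_grid.png")

    # Inpaint: repaint only the white region of the mask.
    mask = Image.open("mask.png")  # hypothetical path; white = repaint
    images = runner.generate_inpaint(prompt_image, init_image, mask, strength=0.7)
    IPAdapterRunner.image_grid(images, rows=1, cols=4).save("inpaint_grid.png")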