# filename: ip_adapter_multi_mode.py
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    DDIMScheduler,
    AutoencoderKL,
)
from PIL import Image
from ip_adapter import IPAdapter


class IPAdapterRunner:
    """Wraps an IP-Adapter around text2img, img2img, and inpaint
    Stable Diffusion pipelines, rebuilding the pipeline on each mode switch."""

    def __init__(
        self,
        base_model_path="runwayml/stable-diffusion-v1-5",
        vae_model_path="stabilityai/sd-vae-ft-mse",
        image_encoder_path="models/image_encoder/",
        ip_ckpt="models/ip-adapter_sd15.bin",
        device="cuda",
    ):
        self.base_model_path = base_model_path
        self.vae_model_path = vae_model_path
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.device = device
        self.vae = self._load_vae()
        self.scheduler = self._create_scheduler()
        self.pipe = None
        self.ip_model = None

    def _create_scheduler(self):
        # DDIM scheduler configured to match the SD 1.5 training schedule.
        return DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

    def _load_vae(self):
        # The fine-tuned MSE VAE gives cleaner reconstructions than the stock SD 1.5 VAE.
        return AutoencoderKL.from_pretrained(self.vae_model_path).to(dtype=torch.float16)

    def _clear_previous_pipe(self):
        # Release the previous pipeline before loading a new one to free VRAM.
        # Reset the attributes to None rather than deleting them, so the
        # check below stays valid on the next call.
        if self.pipe is not None:
            self.pipe = None
            self.ip_model = None
            torch.cuda.empty_cache()

    def _load_pipeline(self, mode):
        self._clear_previous_pipe()
        pipeline_classes = {
            "text2img": StableDiffusionPipeline,
            "img2img": StableDiffusionImg2ImgPipeline,
            # Legacy inpaint pipeline, as used by the IP-Adapter examples.
            "inpaint": StableDiffusionInpaintPipelineLegacy,
        }
        if mode not in pipeline_classes:
            raise ValueError(f"Unsupported mode: {mode}")
        self.pipe = pipeline_classes[mode].from_pretrained(
            self.base_model_path,
            torch_dtype=torch.float16,
            scheduler=self.scheduler,
            vae=self.vae,
            feature_extractor=None,
            safety_checker=None,
        )
        # IPAdapter moves the pipeline to the target device itself.
        self.ip_model = IPAdapter(self.pipe, self.image_encoder_path, self.ip_ckpt, self.device)

    def generate_text2img(self, pil_image, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("text2img")
        # Downscale the reference image; the image encoder expects small inputs.
        pil_image = pil_image.resize((256, 256))
        return self.ip_model.generate(
            pil_image=pil_image,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_img2img(self, pil_image, reference_image, strength=0.6, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("img2img")
        # pil_image conditions the IP-Adapter; reference_image seeds the img2img denoising.
        return self.ip_model.generate(
            pil_image=pil_image,
            image=reference_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    def generate_inpaint(self, pil_image, image, mask_image, strength=0.7, num_samples=4, num_inference_steps=50, seed=42):
        self._load_pipeline("inpaint")
        # White regions of mask_image are regenerated; black regions are kept.
        return self.ip_model.generate(
            pil_image=pil_image,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_samples=num_samples,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

    @staticmethod
    def image_grid(imgs, rows, cols):
        # Tile equally sized images into a rows x cols contact sheet.
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new("RGB", size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid
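

# A minimal usage sketch, not part of the original file: it assumes the
# IP-Adapter checkpoint and image encoder exist at the default paths above,
# and "reference.png" is a hypothetical placeholder for a local input image.
if __name__ == "__main__":
    runner = IPAdapterRunner()
    ref = Image.open("reference.png").convert("RGB")

    # Generate variations conditioned only on the reference image.
    images = runner.generate_text2img(ref, num_samples=4)

    # Tile the results into a 1x4 contact sheet and save it.
    grid = IPAdapterRunner.image_grid(images, rows=1, cols=4)
    grid.save("grid.png")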