from __future__ import annotations

import gc
import pathlib

import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import ModelCard

from blora_utils import BLOCKS, filter_lora, scale_lora
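
# `spaces` provides the @spaces.GPU decorator used below; on Hugging Face
# Spaces it requests GPU (ZeroGPU) hardware for the decorated call.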


class InferencePipeline:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.base_model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base_model_id,
            torch_dtype=torch.float16,
            use_auth_token=self.hf_token)
        self.content_lora_model_id: str | None = None
        self.style_lora_model_id: str | None = None
        # Track the last-used scales as well, so load_pipe can detect when the
        # adapters must be reloaded with new alphas.
        self.content_alpha: float | None = None
        self.style_alpha: float | None = None

    def clear(self) -> None:
        # Drop the pipeline and all cached LoRA state, then release GPU memory.
        self.content_lora_model_id = None
        self.style_lora_model_id = None
        self.content_alpha = None
        self.style_alpha = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()
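
    # A note on the filtering below (an assumption about blora_utils, which is
    # not shown here): BLOCKS is expected to map 'content' and 'style' to the
    # UNet block-name prefixes that B-LoRA assigns to each role, filter_lora
    # to keep only the LoRA keys under those prefixes, and scale_lora to
    # multiply the kept weights by the given alpha.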
    def load_b_lora_to_unet(self, content_lora_model_id: str, style_lora_model_id: str,
                            content_alpha: float, style_alpha: float) -> None:
        try:
            # Load the content adapter, keep its content blocks, and scale them.
            if content_lora_model_id and content_lora_model_id != 'None':
                content_B_LoRA_sd, _ = self.pipe.lora_state_dict(content_lora_model_id,
                                                                 use_auth_token=self.hf_token)
                content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
                content_B_LoRA = scale_lora(content_B_LoRA, content_alpha)
            else:
                content_B_LoRA = {}

            # Load the style adapter, keep its style blocks, and scale them.
            if style_lora_model_id and style_lora_model_id != 'None':
                style_B_LoRA_sd, _ = self.pipe.lora_state_dict(style_lora_model_id,
                                                               use_auth_token=self.hf_token)
                style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
                style_B_LoRA = scale_lora(style_B_LoRA, style_alpha)
            else:
                style_B_LoRA = {}

            # Merge the two filtered state dicts into one combined LoRA and
            # load it into the UNet.
            res_lora = {**content_B_LoRA, **style_B_LoRA}
            self.pipe.load_lora_into_unet(res_lora, None, self.pipe.unet)
        except Exception as e:
            raise type(e)(f'Failed to load B-LoRA weights into the UNet: {e}') from e

    @staticmethod
    def check_if_model_is_local(lora_model_id: str) -> bool:
        return pathlib.Path(lora_model_id).exists()

    @staticmethod
    def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(lora_model_id: str, hf_token: str | None = None) -> str:
        card = InferencePipeline.get_model_card(lora_model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, content_lora_model_id: str, style_lora_model_id: str,
                  content_alpha: float, style_alpha: float) -> None:
        # Reload only when the adapters or their scales actually changed.
        if (content_lora_model_id == self.content_lora_model_id
                and style_lora_model_id == self.style_lora_model_id
                and content_alpha == self.content_alpha
                and style_alpha == self.style_alpha):
            return
        self.pipe.unload_lora_weights()
        self.load_b_lora_to_unet(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
        self.content_lora_model_id = content_lora_model_id
        self.style_lora_model_id = style_lora_model_id
        self.content_alpha = content_alpha
        self.style_alpha = style_alpha

    @spaces.GPU
    def inference(self,
                  prompt: str,
                  seed: int,
                  n_steps: int,
                  guidance_scale: float,
                  num_images_per_prompt: int = 1) -> list[PIL.Image.Image]:
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        self.pipe.to('cuda')
        generator = torch.Generator(device='cuda').manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        )
        return out.images

    def run(self,
            content_lora_model_id: str,
            style_lora_model_id: str,
            prompt: str,
            content_alpha: float,
            style_alpha: float,
            seed: int,
            n_steps: int,
            guidance_scale: float,
            num_images_per_prompt: int = 1) -> list[PIL.Image.Image]:
        self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
        return self.inference(
            prompt=prompt,
            seed=seed,
            n_steps=n_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
        )
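

# A minimal usage sketch, assuming a CUDA GPU, access to the SDXL base weights,
# and two hypothetical B-LoRA repos (the ids below are placeholders, not real
# checkpoints). Pass 'None' for either id to skip that adapter.
if __name__ == '__main__':
    pipeline = InferencePipeline(hf_token=None)
    images = pipeline.run(
        content_lora_model_id='user/example-content-b-lora',  # placeholder id
        style_lora_model_id='user/example-style-b-lora',      # placeholder id
        prompt='A statue of a cat',
        content_alpha=1.0,
        style_alpha=1.0,
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    images[0].save('result.png')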