from diffusers import DiffusionPipeline

import torch
import random
import numpy as np
import os

from huggingface_hub import hf_hub_download
from torchvision.utils import make_grid
from PIL import Image

from .vq_model import VQ_models
from .arpg import ARPG_models


class ARPGModel(DiffusionPipeline):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        Download the ARPG checkpoint and the VQ-VAE tokenizer from the Hugging
        Face Hub, run class-conditional sampling according to the keyword
        arguments, and return the resulting image grid as a PIL image.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Map the requested model size to its checkpoint file on the Hub.
        model_type = kwargs.get("model_type", "ARPG-XXL")
        if model_type == "ARPG-L":
            model_path = "arpg_300m.pt"
        elif model_type == "ARPG-XL":
            model_path = "arpg_700m.pt"
        elif model_type == "ARPG-XXL":
            model_path = "arpg_1b.pt"
        else:
            raise NotImplementedError(f"Unsupported model_type: {model_type}")
        # Download the ARPG checkpoint and instantiate the transformer.
        model_checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("repo_id", "hp-l33/ARPG"),
            filename=kwargs.get("model_filename", model_path)
        )

        model_fn = ARPG_models[model_type]
        model = model_fn(
            num_classes=1000,
            vocab_size=16384
        ).to(device)

        state_dict = torch.load(model_checkpoint_path, map_location="cpu")['state_dict']
        model.load_state_dict(state_dict)
        model.eval()
        # Download the VQ-VAE tokenizer used to decode tokens into pixels.
        # Note: this reuses the "repo_id" kwarg with a different default.
        vae_checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("repo_id", "FoundationVision/LlamaGen"),
            filename=kwargs.get("vae_filename", "vq_ds16_c2i.pt")
        )

        vae = VQ_models['VQ-16']()

        vae_state_dict = torch.load(vae_checkpoint_path, map_location="cpu")['model']
        vae.load_state_dict(vae_state_dict)
        vae = vae.to(device).eval()
        # Seed all RNGs for reproducible sampling.
        seed = kwargs.get("seed", 6)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # Sampling hyper-parameters.
        num_steps = kwargs.get("num_steps", 64)
        cfg_scale = kwargs.get("cfg_scale", 4)
        cfg_schedule = kwargs.get("cfg_schedule", "constant")
        sample_schedule = kwargs.get("sample_schedule", "arccos")
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 600)
        class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
        # Generate discrete tokens with guided sampling, then decode them to images.
        with torch.autocast(device_type=device.type):
            sampled_tokens = model.generate(
                condition=torch.tensor(class_labels, dtype=torch.long, device=device),
                num_iter=num_steps,
                guidance_scale=cfg_scale,
                cfg_schedule=cfg_schedule,
                sample_schedule=sample_schedule,
                temperature=temperature,
                top_k=top_k,
            )
            sampled_images = vae.decode_code(sampled_tokens, shape=(len(class_labels), 8, 16, 16))
        # Arrange the samples into a grid, map them from [-1, 1] to [0, 255],
        # and save the result as a PNG.
        output_dir = kwargs.get("output_dir", "./")
        os.makedirs(output_dir, exist_ok=True)

        image_path = os.path.join(output_dir, "sampled_image.png")
        samples_per_row = kwargs.get("samples_per_row", 4)

        ndarr = make_grid(
            torch.clamp(127.5 * sampled_images + 128.0, 0, 255),
            nrow=int(samples_per_row)
        ).permute(1, 2, 0).to("cpu", dtype=torch.uint8).numpy()

        Image.fromarray(ndarr).save(image_path)

        # Return the saved grid as a PIL image.
        image = Image.open(image_path)

        return image
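

# Minimal usage sketch (illustrative, not part of the original file). It builds
# the pipeline directly and samples a small grid. Because this module uses
# relative imports, run it as a package module (e.g. `python -m <package>.<module>`,
# paths hypothetical); loading through
# DiffusionPipeline.from_pretrained("hp-l33/ARPG", trust_remote_code=True)
# is assumed to be the usual entry point when the repo hosts this as a custom pipeline.
if __name__ == "__main__":
    pipe = ARPGModel()
    grid = pipe(
        model_type="ARPG-XL",       # one of: ARPG-L, ARPG-XL, ARPG-XXL
        seed=0,
        num_steps=64,
        cfg_scale=4,
        class_labels=[207, 360, 388, 113],
        samples_per_row=2,
        output_dir="./outputs",
    )
    grid.save("./outputs/grid.png")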