# ARPG / pipeline.py
# Committed by hp-l33 — 0d85cab (verified): "Update pipeline.py"
from diffusers import DiffusionPipeline
import torch
import random
import numpy as np
import importlib.util
import sys
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os
from torchvision.utils import save_image, make_grid
from PIL import Image
from safetensors.torch import load_file
from .vq_model import VQ_models
from .arpg import ARPG_models
# inheriting from DiffusionPipeline for HF
class ARPGModel(DiffusionPipeline):
    """Hugging Face ``DiffusionPipeline`` wrapper for class-conditional ARPG sampling.

    Calling an instance downloads an ARPG checkpoint and a VQ tokenizer from
    the Hugging Face Hub, samples image tokens for the requested class labels,
    decodes them to pixels, writes an image grid to disk and returns it as a
    PIL image.
    """

    # Checkpoint filename in the ARPG Hub repo for each supported model size.
    _MODEL_FILENAMES = {
        "ARPG-L": "arpg_300m.pt",
        "ARPG-XL": "arpg_700m.pt",
        "ARPG-XXL": "arpg_1b.pt",
    }

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """Download the model and VQ tokenizer, then generate an image grid.

        Positional ``args`` are accepted for pipeline compatibility but ignored.

        Keyword Args:
            model_type: "ARPG-L", "ARPG-XL" or "ARPG-XXL" (default "ARPG-XXL").
            repo_id: Hub repo of the ARPG checkpoint (default "hp-l33/ARPG").
            model_filename: checkpoint filename override (default depends on
                ``model_type``).
            vae_repo_id: Hub repo of the VQ tokenizer
                (default "FoundationVision/LlamaGen").
            vae_filename: VQ tokenizer filename (default "vq_ds16_c2i.pt").
            seed: seed for torch, numpy and ``random`` (default 6).
            num_steps, cfg_scale, cfg_schedule, sample_schedule, temperature,
            top_k: forwarded to ``model.generate`` (defaults 64, 4,
                "constant", "arccos", 1.0, 600).
            class_labels: integer class indices, one generated image per entry.
            output_dir: directory that receives "sampled_image.png"
                (default "./", created if missing).
            samples_per_row: number of images per grid row (default 4).

        Returns:
            PIL.Image.Image: the image grid, also saved to
            ``<output_dir>/sampled_image.png``.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = self._load_arpg(kwargs, device)
        vae = self._load_vae(kwargs, device)

        # Seed after loading so that generation, not weight init, is reproducible
        # (same ordering as before).
        seed = kwargs.get("seed", 6)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        num_steps = kwargs.get("num_steps", 64)
        cfg_scale = kwargs.get("cfg_scale", 4)
        cfg_schedule = kwargs.get("cfg_schedule", "constant")
        sample_schedule = kwargs.get("sample_schedule", "arccos")
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 600)
        class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
        condition = torch.as_tensor(class_labels, dtype=torch.long, device=device)

        # Mixed precision only where it is supported; the deprecated
        # torch.cuda.amp.autocast merely warns and disables itself on CPU.
        with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            sampled_tokens = model.generate(
                condition=condition,
                num_iter=num_steps,
                guidance_scale=cfg_scale,
                cfg_schedule=cfg_schedule,
                sample_schedule=sample_schedule,
                temperature=temperature,
                top_k=top_k,
            )
            # (batch, 8, 16, 16) is the code shape decode_code expects for the
            # VQ-16 tokenizer — hard-coded as in the original; TODO confirm if
            # other tokenizers are ever plugged in.
            sampled_images = vae.decode_code(
                sampled_tokens, shape=(condition.shape[0], 8, 16, 16)
            )

        return self._save_grid(
            sampled_images,
            kwargs.get("output_dir", "./"),
            kwargs.get("samples_per_row", 4),
        )

    def _load_arpg(self, kwargs, device):
        """Download the ARPG checkpoint selected by ``kwargs``; return the model in eval mode on ``device``."""
        model_type = kwargs.get("model_type", "ARPG-XXL")
        if model_type not in self._MODEL_FILENAMES:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; "
                f"expected one of {sorted(self._MODEL_FILENAMES)}"
            )
        checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("repo_id", "hp-l33/ARPG"),
            filename=kwargs.get("model_filename", self._MODEL_FILENAMES[model_type]),
        )
        model = ARPG_models[model_type](num_classes=1000, vocab_size=16384).to(device)
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto the model's device.
        # NOTE(review): torch.load unpickles the file — only point repo_id at
        # trusted repositories.
        state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
        model.load_state_dict(state_dict)
        return model.eval()

    def _load_vae(self, kwargs, device):
        """Download the VQ-16 tokenizer; return it in eval mode on ``device``."""
        # Separate repo kwarg so overriding the ARPG ``repo_id`` does not also
        # redirect the tokenizer download.
        vae_checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("vae_repo_id", "FoundationVision/LlamaGen"),
            filename=kwargs.get("vae_filename", "vq_ds16_c2i.pt"),
        )
        vae = VQ_models["VQ-16"]()
        vae.load_state_dict(torch.load(vae_checkpoint_path, map_location="cpu")["model"])
        return vae.to(device).eval()

    @staticmethod
    def _save_grid(images, output_dir, samples_per_row):
        """Tile ``images`` into a grid, save it as ``sampled_image.png`` in ``output_dir`` and return it."""
        os.makedirs(output_dir, exist_ok=True)
        # Affine map to [0, 255]; presumably the decoder emits roughly [-1, 1].
        pixels = torch.clamp(127.5 * images + 128.0, 0, 255)
        grid = make_grid(pixels, nrow=int(samples_per_row))
        ndarr = grid.permute(1, 2, 0).to("cpu", dtype=torch.uint8).numpy()
        image = Image.fromarray(ndarr)
        image.save(os.path.join(output_dir, "sampled_image.png"))
        # Return the in-memory image rather than re-opening the file, which
        # would leave a lazily-read file handle open.
        return image