import gc
import time

import numpy as np
import PIL
from PIL import Image

import spaces
import torch
from safetensors.torch import load_file

from autoregressive.models.generate import generate
from autoregressive.models.gpt_t2i import GPT_models
from condition.canny import CannyDetector
from condition.midas.depth import MidasDetector
from language.t5 import T5Embedder
from tokenizer.tokenizer_image.vq_model import VQ_models

models = {
    "canny": "checkpoints/canny_MR.safetensors",
    "depth": "checkpoints/depth_MR.safetensors",
}


def resize_image_to_16_multiple(image, condition_type='canny'):
    """Resize an image so both sides are multiples of 16 (32 for 'depth')."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    width, height = image.size

    if condition_type == 'depth':
        new_width = (width + 31) // 32 * 32
        new_height = (height + 31) // 32 * 32
    else:
        new_width = (width + 15) // 16 * 16
        new_height = (height + 15) // 16 * 16

    resized_image = image.resize((new_width, new_height))
    return resized_image
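# Example of the rounding (illustrative only): a 100x100 input becomes 112x112
# for 'canny' (next multiple of 16) and 128x128 for 'depth' (next multiple of 32).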


class Model:

    def __init__(self):
        self.device = torch.device("cuda")
        self.base_model_id = ""
        self.task_name = ""
        self.vq_model = self.load_vq()
        self.t5_model = self.load_t5()
        self.gpt_model_canny = self.load_gpt(condition_type='canny')
        # process_depth() also needs a depth GPT and a depth detector; load
        # them here so that method does not fail with an AttributeError.
        # (The MidasDetector constructor arguments are an assumption.)
        self.gpt_model_depth = self.load_gpt(condition_type='depth')
        self.get_control_canny = CannyDetector()
        self.get_control_depth = MidasDetector(device=self.device)

    def to(self, device):
        self.gpt_model_canny.to(device)
        self.gpt_model_depth.to(device)

    def load_vq(self):
        # VQ-16 image tokenizer: decodes generated token indices back to pixels.
        vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                      codebook_embed_dim=8)
        vq_model.eval()
        checkpoint = torch.load("checkpoints/vq_ds16_t2i.pt",
                                map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        del checkpoint
        print("image tokenizer is loaded")
        return vq_model

    def load_gpt(self, condition_type='canny'):
        gpt_ckpt = models[condition_type]
        precision = torch.float32
        # 768-px base resolution with a 16x-downsampling tokenizer -> 48x48 token grid.
        latent_size = 768 // 16
        gpt_model = GPT_models["GPT-XL"](
            block_size=latent_size**2,
            cls_token_num=120,
            model_type='t2i',
            condition_type=condition_type,
        ).to(device='cpu', dtype=precision)

        model_weight = load_file(gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=True)
        gpt_model.eval()
        print("gpt model is loaded")
        return gpt_model

    def load_t5(self):
        # Flan-T5-XL text encoder; caption embeddings are capped at 120 tokens.
        precision = torch.float32
        t5_model = T5Embedder(
            device=self.device,
            local_cache=True,
            cache_dir='checkpoints/flan-t5-xl',
            dir_or_name='flan-t5-xl',
            torch_dtype=precision,
            model_max_length=120,
        )
        return t5_model

    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        image = resize_image_to_16_multiple(image, 'canny')
        W, H = image.size
        print(W, H)

        self.t5_model.model.to('cuda').to(torch.bfloat16)
        self.gpt_model_canny.to('cuda').to(torch.bfloat16)
        self.vq_model.to('cuda')

        # Extract the Canny edge map, normalize it to [-1, 1], and duplicate it
        # so that two samples are generated per call.
        condition_img = self.get_control_canny(np.array(image), low_threshold,
                                               high_threshold)
        condition_img = torch.from_numpy(condition_img[None, None,
                                                       ...]).repeat(2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        # Left-pad the T5 embeddings: valid tokens are moved to the end of the
        # sequence and the attention mask is flipped to match.
        print("processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f' prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), 8, H // 16, W // 16]

        t1 = time.time()
        index_sample = generate(
            self.gpt_model_canny,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        # Decode the sampled token indices back into images with the VQ decoder.
        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # Return the input image, the edge map, and the generated samples,
        # mapped back from [-1, 1] to [0, 255].
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(
                    0, 255).astype(np.uint8)) for sample in samples
        ]
        del condition_img
        torch.cuda.empty_cache()
        return samples

    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        image = resize_image_to_16_multiple(image, 'depth')
        W, H = image.size
        print(W, H)

        # Free GPU memory held by the canny model, then move the depth pipeline over.
        self.gpt_model_canny.to('cpu')
        self.t5_model.model.to(self.device)
        self.gpt_model_depth.to(self.device)
        self.get_control_depth.model.to(self.device)
        self.vq_model.to(self.device)
        image_tensor = torch.from_numpy(np.array(image)).to(self.device)

        # The incoming image itself is used as the spatial condition here,
        # normalized to [-1, 1] and duplicated for the two-sample batch.
        condition_img = 2 * (image_tensor / 255 - 0.5)
        print(condition_img.shape)
        condition_img = condition_img.permute(2, 0, 1).unsqueeze(0).repeat(2, 1, 1, 1)

        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        # Left-pad the T5 embeddings: valid tokens are moved to the end of the
        # sequence and the attention mask is flipped to match.
        print("processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f' prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)

        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), 8, H // 16, W // 16]

        t1 = time.time()
        index_sample = generate(
            self.gpt_model_depth,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        # Decode token indices to images, then assemble [input, condition, samples].
        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        condition_img = condition_img.cpu()
        samples = samples.cpu()
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(0, 255).astype(np.uint8))
            for sample in samples
        ]
        del image_tensor
        del condition_img
        torch.cuda.empty_cache()
        return samples
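

# Minimal usage sketch: assumes a CUDA machine with the checkpoints referenced
# above already in place. The image path, prompt, and sampling parameters below
# are illustrative assumptions, not values taken from the demo itself.
if __name__ == "__main__":
    demo_model = Model()
    input_image = np.array(Image.open("example_input.jpg").convert("RGB"))  # hypothetical file
    results = demo_model.process_canny(
        image=input_image,
        prompt="a photo of a house in the forest",
        cfg_scale=4.0,
        temperature=1.0,
        top_k=2000,
        top_p=1.0,
        seed=0,
        low_threshold=100,
        high_threshold=200,
    )
    for i, out in enumerate(results):
        out.save(f"canny_result_{i}.png")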