import gc
import spaces
from safetensors.torch import load_file
from autoregressive.models.gpt_t2i import GPT_models
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
import torch
import numpy as np
import PIL
from PIL import Image
from condition.canny import CannyDetector
import time
from autoregressive.models.generate import generate
from condition.midas.depth import MidasDetector
# from controlnet_aux import (
#     MidasDetector,
# )

models = {
    "canny": "checkpoints/canny_MR.safetensors",
    "depth": "checkpoints/depth_MR.safetensors",
}

def resize_image_to_16_multiple(image, condition_type='canny'):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # image = Image.open(image_path)
    width, height = image.size

    if condition_type == 'depth':  # The depth model requires side lengths that are multiples of 32
        new_width = (width + 31) // 32 * 32
        new_height = (height + 31) // 32 * 32
    else:
        new_width = (width + 15) // 16 * 16
        new_height = (height + 15) // 16 * 16

    resized_image = image.resize((new_width, new_height))
    return resized_image
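
# Example: a 770x513 input is resized to 784x528 under the 16-multiple rule
# ((770 + 15) // 16 * 16 = 784, (513 + 15) // 16 * 16 = 528) and to 800x544
# under the 32-multiple rule used for the depth model.
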
class Model:

    def __init__(self):
        self.device = torch.device("cuda")
        self.base_model_id = ""
        self.task_name = ""
        self.vq_model = self.load_vq()
        self.t5_model = self.load_t5()
        self.gpt_model_canny = self.load_gpt(condition_type='canny')
        # The depth branch is disabled here; process_depth expects
        # self.gpt_model_depth and self.get_control_depth, so the lines below
        # must be re-enabled before the depth pipeline can be used.
        # self.gpt_model_depth = self.load_gpt(condition_type='depth')
        self.get_control_canny = CannyDetector()
        # self.get_control_depth = MidasDetector('cuda')
        # self.get_control_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device):
        self.gpt_model_canny.to(device)
        # print(next(self.gpt_model_canny.adapter.parameters()).device)
        # print(self.gpt_model_canny.device)

    def load_vq(self):
        vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                      codebook_embed_dim=8)
        # vq_model.to('cuda')
        vq_model.eval()
        checkpoint = torch.load("checkpoints/vq_ds16_t2i.pt",
                                map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        del checkpoint
        print("image tokenizer is loaded")
        return vq_model

    def load_gpt(self, condition_type='canny'):
        gpt_ckpt = models[condition_type]
        # precision = torch.bfloat16
        precision = torch.float32
        latent_size = 768 // 16  # 48x48 latent grid for 768px images with 16x downsampling
        gpt_model = GPT_models["GPT-XL"](
            block_size=latent_size**2,
            cls_token_num=120,
            model_type='t2i',
            condition_type=condition_type,
        ).to(device='cpu', dtype=precision)

        model_weight = load_file(gpt_ckpt)
        print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
        gpt_model.load_state_dict(model_weight, strict=True)
        gpt_model.eval()
        print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
        print("gpt model is loaded")
        return gpt_model

    def load_t5(self):
        # precision = torch.bfloat16
        precision = torch.float32
        t5_model = T5Embedder(
            device=self.device,
            local_cache=True,
            cache_dir='checkpoints/flan-t5-xl',
            dir_or_name='flan-t5-xl',
            torch_dtype=precision,
            model_max_length=120,
        )
        return t5_model

    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        print(image)
        image = resize_image_to_16_multiple(image, 'canny')
        W, H = image.size
        print(W, H)
        torch.manual_seed(seed)  # apply the user-provided seed so sampling is reproducible

        # self.gpt_model_depth.to('cpu')
        self.t5_model.model.to('cuda').to(torch.bfloat16)
        self.gpt_model_canny.to('cuda').to(torch.bfloat16)
        self.vq_model.to('cuda')
        # print("after cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)

        # Extract the Canny edge map, replicate it to 3 channels, and duplicate it
        # along the batch dimension so two samples are generated for the prompt.
        condition_img = self.get_control_canny(np.array(image), low_threshold,
                                               high_threshold)
        condition_img = torch.from_numpy(condition_img[None, None, ...]).repeat(2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
        condition_img = 2 * (condition_img / 255 - 0.5)  # map [0, 255] to [-1, 1]

        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print("processing left-padding...")
        # Re-pack the T5 embeddings so that padding sits on the left: the
        # attention mask is flipped, and each prompt's valid tokens are rotated
        # to the end of the sequence.
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f'  prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
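
        # Toy illustration of the left-padding above: for a length-5 mask
        # [1, 1, 1, 0, 0] (3 valid tokens), valid_num = 3, the flipped mask is
        # [0, 0, 1, 1, 1], and cat([emb[3:], emb[:3]]) moves the valid
        # embeddings to the end, so padding now precedes the prompt tokens.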

        qzshape = [len(c_indices), 8, H // 16, W // 16]  # (batch, codebook_embed_dim, latent H, latent W)
        t1 = time.time()
        print(caption_embs.device)
        index_sample = generate(
            self.gpt_model_canny,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(
            index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # Returned images: the resized input, the Canny condition map, and the
        # two generated samples.
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(
                    0, 255).astype(np.uint8)) for sample in samples
        ]

        del condition_img
        torch.cuda.empty_cache()
        return samples

    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        # NOTE: this path relies on self.gpt_model_depth and self.get_control_depth,
        # which are commented out in __init__ above and must be re-enabled first.
        image = resize_image_to_16_multiple(image, 'depth')
        W, H = image.size
        print(W, H)
        torch.manual_seed(seed)  # apply the user-provided seed so sampling is reproducible

        self.gpt_model_canny.to('cpu')
        self.t5_model.model.to(self.device)
        self.gpt_model_depth.to(self.device)
        self.get_control_depth.model.to(self.device)
        self.vq_model.to(self.device)

        image_tensor = torch.from_numpy(np.array(image)).to(self.device)
        # condition_img = torch.from_numpy(
        #     self.get_control_depth(image_tensor)).unsqueeze(0)
        # condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
        # condition_img = condition_img.to(self.device)
        # condition_img = 2 * (condition_img / 255 - 0.5)

        # The uploaded image itself is used as the condition here (the detector
        # call above is commented out), normalized from [0, 255] to [-1, 1] and
        # duplicated along the batch so two samples are generated.
        condition_img = 2 * (image_tensor / 255 - 0.5)
        print(condition_img.shape)
        condition_img = condition_img.permute(2, 0, 1).unsqueeze(0).repeat(2, 1, 1, 1)
        # control_image = self.get_control_depth(
        #     image=image,
        #     image_resolution=512,
        #     detect_resolution=512,
        # )

        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print("processing left-padding...")
        # Same left-padding re-packing of the T5 embeddings as in process_canny.
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f'  prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks

        qzshape = [len(c_indices), 8, H // 16, W // 16]
        t1 = time.time()
        index_sample = generate(
            self.gpt_model_depth,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        condition_img = condition_img.cpu()
        samples = samples.cpu()
        # Returned images: the resized input, the depth condition map, and the
        # two generated samples.
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(0, 255).astype(np.uint8))
            for sample in samples
        ]

        del image_tensor
        del condition_img
        torch.cuda.empty_cache()
        return samples
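

# A minimal usage sketch, assuming the checkpoints referenced above are present,
# a CUDA device is available, and "example.jpg" stands in for a real input image;
# the prompt and sampling parameters below are illustrative placeholders, not
# values prescribed by this repository.
if __name__ == "__main__":
    model = Model()
    input_image = np.array(Image.open("example.jpg").convert("RGB"))
    outputs = model.process_canny(
        image=input_image,
        prompt="a photo of a house in a forest",
        cfg_scale=4.0,
        temperature=1.0,
        top_k=2000,
        top_p=1.0,
        seed=0,
        low_threshold=100,
        high_threshold=200,
    )
    # outputs holds PIL images: the resized input, the Canny edge map,
    # and the two generated samples.
    for i, img in enumerate(outputs):
        img.save(f"output_{i}.png")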