import math
import time
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange


class PixelFlowPipeline:
    def __init__(
        self,
        scheduler,
        transformer,
        text_encoder=None,
        tokenizer=None,
        max_token_length=512,
    ):
        super().__init__()
        self.class_cond = text_encoder is None or tokenizer is None
        self.scheduler = scheduler
        self.transformer = transformer
        self.patch_size = transformer.patch_size
        self.head_dim = transformer.attention_head_dim
        self.num_stages = scheduler.num_stages
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Union[str, List[str]] = "",
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        use_attention_mask: bool = False,
        max_length: int = 512,
    ):
        # Determine the batch size and normalize prompt input to a list
        if prompt is not None:
            if isinstance(prompt, str):
                prompt = [prompt]
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Process prompt embeddings if not provided
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            prompt_attention_mask = text_inputs.attention_mask.to(device)
            prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=prompt_attention_mask if use_attention_mask else None,
            )[0]

        # Determine dtype from available encoder
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        # Move prompt embeddings to desired dtype and device
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)

        # Handle classifier-free guidance for negative prompts
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Normalize negative prompt to list and validate length
            if isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif isinstance(negative_prompt, list):
                if len(negative_prompt) != batch_size:
                    raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
                uncond_tokens = negative_prompt
            else:
                raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")

            # Tokenize and encode negative prompts
            uncond_inputs = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            negative_input_ids = uncond_inputs.input_ids.to(device)
            negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
            negative_prompt_embeds = self.text_encoder(
                negative_input_ids,
                attention_mask=negative_prompt_attention_mask if use_attention_mask else None,
            )[0]

        if do_classifier_free_guidance:
            # Duplicate negative prompt embeddings and attention mask for each generation
            seq_len_neg = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None
        # Concatenate negative and positive embeddings and their masks so that a single
        # forward pass covers both the unconditional and conditional branches
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        return prompt_embeds, prompt_attention_mask
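
    # Shape note (illustrative, assuming CFG is enabled and a text encoder is set):
    # `encode_prompt` returns `prompt_embeds` of shape
    #     [2 * batch_size * num_images_per_prompt, max_length, hidden_dim]
    # and `prompt_attention_mask` of shape
    #     [2 * batch_size * num_images_per_prompt, max_length],
    # with the unconditional (negative) half stacked before the conditional half.
    # Hypothetical usage, where `pipe` is a constructed PixelFlowPipeline:
    #
    #     embeds, mask = pipe.encode_prompt(["a cat"], device=pipe.device, max_length=512)
    #     # embeds.shape == (2, 512, hidden_dim); mask.shape == (2, 512)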

    def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
        gamma = self.scheduler.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(4), torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4)
        )
        block_number = bs * ch * (height // 2) * (width // 2)
        noise = torch.stack([dist.sample() for _ in range(block_number)])  # [block number, 4]
        noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)', b=bs, c=ch, h=height // 2, w=width // 2, p=2, q=2)
        return noise
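
    # Note on the block noise above (explanatory): the scheduler's `gamma` is expected
    # to be negative (the stage-transition math in `__call__` uses both sqrt(1 - 1/gamma)
    # and sqrt(-gamma)). The 4x4 covariance (1 - gamma) * I + gamma * J + eps * I gives
    # each element unit variance (up to eps) and pairwise correlation `gamma`, so the
    # variance of each 2x2 block's sum is 4 + 12 * gamma (plus the eps term). At
    # gamma = -1/3 this vanishes, i.e. every 2x2 block of the returned noise sums to
    # roughly zero, so adding it does not perturb the 2x-downsampled mean of the latents
    # it is mixed into. Hypothetical sanity check, with `pipe` a constructed pipeline:
    #
    #     noise = pipe.sample_block_noise(1, 3, 64, 64)
    #     block_sums = noise.unfold(2, 2, 2).unfold(3, 2, 2).sum(dim=(-1, -2))
    #     print(block_sums.abs().max())  # ~0 when gamma == -1/3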

    def __call__(
        self,
        prompt,
        height,
        width,
        num_inference_steps=30,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        device=None,
        shift=1.0,
        use_ode_dopri5=False,
    ):
        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * self.num_stages

        if use_ode_dopri5:
            assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
            from pixelflow.solver_ode_wrapper import ODE
            sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
        else:
            # default Euler
            sample_fn = None

        self._guidance_scale = guidance_scale
        batch_size = len(prompt)

        if self.class_cond:
            # `prompt` holds integer class labels; 1000 is used as the null (unconditional) class id
            prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
            negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        else:
            prompt_embeds, prompt_attention_mask = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                guidance_scale > 1,
                "",
                prompt_embeds=None,
                negative_prompt_embeds=None,
                use_attention_mask=True,
                max_length=self.max_token_length,
            )

        # Start sampling at the lowest-resolution stage; each later stage doubles height and width
        init_factor = 2 ** (self.num_stages - 1)
        height, width = height // init_factor, width // init_factor
        shape = (batch_size * num_images_per_prompt, 3, height, width)
        latents = randn_tensor(shape, device=device, dtype=torch.float32)

        for stage_idx in range(self.num_stages):
            stage_start = time.time()
            # Set the number of inference steps for the current stage
            self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
            Timesteps = self.scheduler.Timesteps

            if stage_idx > 0:
                # Upsample the previous stage's output and re-noise it onto this stage's
                # starting point of the flow trajectory
                height, width = height * 2, width * 2
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                original_start_t = self.scheduler.original_start_t[stage_idx]
                gamma = self.scheduler.gamma  # expected to be negative (see sample_block_noise)
                alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
                beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
                # bs, ch, height, width = latents.shape
                noise = self.sample_block_noise(*latents.shape)
                noise = noise.to(device=device, dtype=latents.dtype)
                latents = alpha * latents + beta * noise

            # 2D rotary position embedding for the current latent grid (assumes square latents)
            size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
            pos_embed = get_2d_rotary_pos_embed(
                embed_dim=self.head_dim,
                crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
                grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
            )
            rope_pos = torch.stack(pos_embed, -1)

            if sample_fn is not None:
                # dopri5
                model_kwargs = dict(
                    class_labels=prompt_embeds,
                    cfg_scale=self.guidance_scale(None, stage_idx),
                    latent_size=size_tensor,
                    pos_embed=rope_pos,
                )
                if stage_idx == 0:
                    latents = torch.cat([latents] * 2)
                stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
                stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
                latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
                if stage_idx == self.num_stages - 1:
                    latents = latents[:latents.shape[0] // 2]
            else:
                # euler
                for T in Timesteps:
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
                    if self.class_cond:
                        noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
                    else:
                        encoder_hidden_states = prompt_embeds
                        encoder_attention_mask = prompt_attention_mask
                        noise_pred = self.transformer(
                            latent_model_input,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_attention_mask=encoder_attention_mask,
                            timestep=timestep,
                            latent_size=size_tensor,
                            pos_embed=rope_pos,
                        )
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
                    latents = self.scheduler.step(model_output=noise_pred, sample=latents)
            stage_end = time.time()

        # Map pixels from [-1, 1] to [0, 1] and return HWC numpy images
        samples = (latents / 2 + 0.5).clamp(0, 1)
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        return samples

    @property
    def device(self):
        return next(self.transformer.parameters()).device

    @property
    def dtype(self):
        return next(self.transformer.parameters()).dtype

    def guidance_scale(self, step=None, stage_idx=None):
        # Text-conditional sampling uses a constant CFG scale; class-conditional
        # sampling ramps the scale up across the resolution stages
        if not self.class_cond:
            return self._guidance_scale
        scale_dict = {0: 0, 1: 1 / 6, 2: 2 / 3, 3: 1}
        return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1
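
    # Worked example (illustrative): for a class-conditional model with guidance_scale=4.0
    # and 4 stages, the effective per-stage CFG scales are
    #     stage 0: (4 - 1) * 0   + 1 = 1.0   (reduces to the conditional prediction)
    #     stage 1: (4 - 1) * 1/6 + 1 = 1.5
    #     stage 2: (4 - 1) * 2/3 + 1 = 3.0
    #     stage 3: (4 - 1) * 1   + 1 = 4.0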

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 0
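

# Minimal usage sketch (illustrative only): `scheduler`, `transformer`, `text_encoder`
# and `tokenizer` are assumed to be built/loaded elsewhere in the codebase; the names
# below are placeholders rather than part of this file.
#
#     pipeline = PixelFlowPipeline(scheduler, transformer, text_encoder, tokenizer)
#     with torch.no_grad():
#         samples = pipeline(
#             prompt=["a photo of a corgi wearing sunglasses"],
#             height=1024,
#             width=1024,
#             num_inference_steps=[10, 10, 10, 10],  # one entry per resolution stage
#             guidance_scale=7.5,
#             device=pipeline.device,
#         )
#     # `samples` is a float numpy array of shape (batch, height, width, 3) in [0, 1].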