import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import jax
from PIL import Image
from typing import List, Tuple
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops
from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
)
from text_to_animation.models.unet_2d_condition_flax import (
    FlaxUNet2DConditionModel,
)
from diffusers import FlaxControlNetModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)
import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

# Collapse the leading device axis produced by pmap back into the batch axis.
unshard = lambda x: einops.rearrange(x, "d b ... -> (d b) ...")
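# For example:
#   unshard(jnp.zeros((8, 4, 64, 64))).shape  # -> (32, 4, 64, 64)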


class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    """Add a leading device axis and copy `array` once per device, for pmap."""
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
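
# For example, on a host with 8 devices (a sketch; the axis length depends
# on the host):
#   tokens = jnp.zeros((4, 77))          # (batch, sequence)
#   replicate_devices(tokens).shape      # -> (8, 4, 77), one copy per device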


class ControlAnimationModel:
    def __init__(self, dtype, **kwargs):
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""

    def set_model(
        self,
        model_id: str,
        **kwargs,
    ):
        # Free any previously loaded pipeline before loading a new checkpoint.
        if hasattr(self, "pipe") and self.pipe is not None:
            del self.pipe
            self.pipe = None
        gc.collect()
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-openpose",
            from_pt=True,
            dtype=jnp.float16,
        )
        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        # Build a second UNet from the same config; from_config instantiates
        # the module without loading pretrained weights.
        unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            unet_vanilla=unet_vanilla,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        # Replicate the parameter tree across local devices for pmap'd calls.
        self.p_params = jax_utils.replicate(self.params)
        self.model_name = model_id
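
    # A sketch of what the replication above does: every leaf of self.params
    # gains a leading axis of length jax.local_device_count() in self.p_params,
    # so each device receives its own copy of the weights.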

    def generate_initial_frames(
        self,
        prompt: str,
        video_path: str,
        n_prompt: str = "",
        seed: int = 0,
        num_imgs: int = 4,
        resolution: int = 512,
        model_id: str = "runwayml/stable-diffusion-v1-5",
    ) -> Tuple[np.ndarray, List[np.ndarray]]:
        self.set_model(model_id=model_id)
        video_path = gradio_utils.motion_to_video_path(video_path)

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt
        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + n_prompt

        video, fps = utils.prepare_video(
            video_path, resolution, None, self.dtype, False, output_fps=4
        )
        control = utils.pre_process_pose(video, apply_pose_detect=False)

        # seeds = [seed for seed in jax.random.randint(self.rng, [num_imgs], 0, 65536)]
        # NOTE: as written, every candidate frame shares the same PRNG key;
        # jax.random.split would yield a distinct sample per image.
        prngs = [jax.random.PRNGKey(seed)] * num_imgs
        images = self.pipe.generate_starting_frames(
            params=self.p_params,
            prngs=prngs,
            controlnet_image=control,
            prompt=prompts,
            neg_prompt=negative_prompts,
        )
        images = [np.array(images[i]) for i in range(images.shape[0])]
        return video, images

    def generate_video_from_frame(self, controlnet_video, prompt, n_prompt, seed):
        # Generate a video using the seed provided.
        prng_seed = jax.random.PRNGKey(seed)
        len_vid = controlnet_video.shape[0]
        # print(f"Generating video from prompt {'<aardman> style ' + prompt}, with {controlnet_video.shape[0]} frames and prng seed {seed}")

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt
        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + n_prompt

        # prompt_ids = self.pipe.prepare_text_inputs(["aardman style " + prompt] * len_vid)
        # n_prompt_ids = self.pipe.prepare_text_inputs([neg_prompt] * len_vid)
        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)

        # Replicate every input across devices for the pmap'd pipeline call.
        prng = replicate_devices(prng_seed)  # jax.random.split(prng, jax.device_count())
        image = replicate_devices(controlnet_video)
        prompt_ids = replicate_devices(prompt_ids)
        n_prompt_ids = replicate_devices(n_prompt_ids)
        motion_field_strength_x = replicate_devices(jnp.array(3))
        motion_field_strength_y = replicate_devices(jnp.array(4))
        smooth_bg_strength = replicate_devices(jnp.array(0.8))

        vid = self.pipe(
            image=image,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.p_params,
            prng_seed=prng,
            jit=True,
            smooth_bg_strength=smooth_bg_strength,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
        ).images[0]
        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
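

if __name__ == "__main__":
    # Minimal end-to-end usage sketch, not part of the original module. The
    # prompt and pose-video path below are hypothetical placeholders.
    model = ControlAnimationModel(dtype=jnp.float16)
    video, frames = model.generate_initial_frames(
        prompt="a clay figure waving",       # hypothetical prompt
        video_path="__assets__/dance1.mp4",  # hypothetical motion clip
        seed=0,
        num_imgs=4,
    )
    # Reuse the pose sequence extracted from the same clip to drive the video.
    control = utils.pre_process_pose(video, apply_pose_detect=False)
    gif = model.generate_video_from_frame(
        controlnet_video=control,
        prompt="a clay figure waving",
        n_prompt="",
        seed=0,
    )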