Spaces:
Runtime error
Runtime error
| import logging | |
| import torch | |
| import torch.utils.checkpoint | |
| from diffusers.models import AutoencoderKLTemporalDecoder | |
| from diffusers.schedulers import EulerDiscreteScheduler | |
| from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
| from ..modules.unet import UNetSpatioTemporalConditionModel | |
| from ..modules.pose_net import PoseNet | |
| from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline | |
| logger = logging.getLogger(__name__) | |
class MimicMotionModel(torch.nn.Module):
    """Container module bundling the components used by the MimicMotion pipeline.

    Holds the UNet, VAE, CLIP image encoder, noise scheduler, CLIP feature
    extractor and pose net as attributes.
    """

    def __init__(self, base_model_path):
        """Construct the base model components from a pretrained SVD model directory.

        The VAE, image encoder, scheduler and feature extractor are loaded with
        ``from_pretrained``. The UNet is only instantiated from its config, and
        the pose net is created fresh.

        Args:
            base_model_path (str): pretrained svd model path
        """
        super().__init__()

        # Shared options for the components that are loaded in half precision.
        fp16_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}

        # UNet: built from the stored config rather than via from_pretrained.
        # NOTE(review): presumably its weights come from a separately loaded
        # checkpoint -- confirm against the caller.
        unet_config = UNetSpatioTemporalConditionModel.load_config(
            base_model_path, subfolder="unet")
        self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config)

        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
            base_model_path, subfolder="vae", **fp16_kwargs)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            base_model_path, subfolder="image_encoder", **fp16_kwargs)

        # Scheduler and feature extractor are loaded with default options.
        self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            base_model_path, subfolder="scheduler")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(
            base_model_path, subfolder="feature_extractor")

        # pose_net: its latent channel count follows the UNet's first block width.
        first_block_channels = self.unet.config.block_out_channels[0]
        self.pose_net = PoseNet(noise_latent_channels=first_block_channels)
def create_pipeline(infer_config, device):
    """Create the MimicMotion pipeline and load the pretrained checkpoint.

    Args:
        infer_config: inference configuration object; must expose the
            attributes ``base_model_path`` and ``ckpt_path``.
        device (str or torch.device): "cpu" or "cuda:{device_id}".
            NOTE(review): currently unused -- the checkpoint is always mapped
            to CPU and the pipeline is not moved here; the parameter is kept
            for interface compatibility.

    Returns:
        MimicMotionPipeline: pipeline assembled from the loaded components.
    """
    mimicmotion_models = MimicMotionModel(infer_config.base_model_path)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(infer_config.ckpt_path, map_location="cpu")

    # strict=False tolerates key mismatches; report them instead of silently
    # discarding the result so a wrong or partial checkpoint is noticeable.
    load_result = mimicmotion_models.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        # Missing keys may be expected for components already loaded via
        # from_pretrained, hence info level rather than warning.
        logger.info(
            "%d model keys not found in checkpoint %s",
            len(load_result.missing_keys), infer_config.ckpt_path)
        logger.debug("missing keys: %s", load_result.missing_keys)
    if load_result.unexpected_keys:
        logger.warning(
            "%d unexpected keys in checkpoint %s were ignored",
            len(load_result.unexpected_keys), infer_config.ckpt_path)
        logger.debug("unexpected keys: %s", load_result.unexpected_keys)

    return MimicMotionPipeline(
        vae=mimicmotion_models.vae,
        image_encoder=mimicmotion_models.image_encoder,
        unet=mimicmotion_models.unet,
        scheduler=mimicmotion_models.noise_scheduler,
        feature_extractor=mimicmotion_models.feature_extractor,
        pose_net=mimicmotion_models.pose_net,
    )