Update mimicmotion/pipelines/pipeline_mimicmotion.py
mimicmotion/pipelines/pipeline_mimicmotion.py
CHANGED
@@ -16,11 +16,12 @@ from diffusers.schedulers import EulerDiscreteScheduler
 from diffusers.utils import BaseOutput, logging
 from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+import threading

 from ..modules.pose_net import PoseNet

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
+import concurrent.futures

 def _append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""

@@ -221,29 +222,37 @@ class MimicMotionPipeline(DiffusionPipeline):
             decode_chunk_size: int = 8):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
-
         latents = 1 / self.vae.config.scaling_factor * latents
-
+
         forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
         accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
-
-
-
-        for i in range(0, latents.shape[0], decode_chunk_size):
-            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
+
+        # Helper function to process a chunk of frames
+        def process_chunk(start, end, frames_list):
             decode_kwargs = {}
             if accepts_num_frames:
-
-
-
-
-
-        frames = torch.cat(frames, dim=0)
-
+                decode_kwargs["num_frames"] = end - start
+            frame = self.vae.decode(latents[start:end], **decode_kwargs).sample
+            frames_list.append(frame.cpu())
+
+        threads = []
+        frames = []
+
+        # Dividing the work into chunks and creating threads to process them
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            t = threading.Thread(target=process_chunk, args=(i, i + decode_chunk_size, frames))
+            threads.append(t)
+            t.start()
+
+        # Waiting for all threads to finish
+        for t in threads:
+            t.join()
+
         # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+        frames = torch.cat(frames, dim=0)
         frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
-
-
+
+        # Cast to float32 for compatibility with bfloat16
         frames = frames.float()
         return frames

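A note on the design: the committed version starts one threading.Thread per chunk and lets each thread append its decoded chunk to the shared frames list as soon as it finishes, so the chunks are not guaranteed to arrive in frame order before the reshape, and the num_frames it passes for the last chunk (end - start) can disagree with the actual slice length when latents.shape[0] is not a multiple of decode_chunk_size. The sketch below is a minimal, self-contained illustration (not the repository's code) of an order-preserving variant built on concurrent.futures, which the commit imports but never uses; fake_vae_decode, decode_latents_threaded and the tensor shapes are hypothetical stand-ins for self.vae.decode(...).sample and the pipeline's decode_latents.

# Minimal sketch, not the committed code: order-preserving chunked decode via concurrent.futures.
# executor.map returns results in submission order, so chunks are concatenated in frame order
# even when threads finish out of order. fake_vae_decode is a hypothetical stand-in for
# self.vae.decode(...).sample; thread-level parallelism only pays off if the real decode
# releases the GIL (GPU-bound work generally does).
import concurrent.futures

import torch


def fake_vae_decode(chunk: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Stand-in "VAE": pretend to decode latent frames [n, c, h, w] into RGB frames [n, 3, 64, 64].
    assert chunk.shape[0] == num_frames
    return torch.zeros(chunk.shape[0], 3, 64, 64)


def decode_latents_threaded(latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 8) -> torch.Tensor:
    # latents: [batch*frames, channels, height, width], as in the pipeline after flatten(0, 1).
    starts = range(0, latents.shape[0], decode_chunk_size)

    def decode_chunk(start: int) -> torch.Tensor:
        chunk = latents[start: start + decode_chunk_size]
        # Take num_frames from the actual slice so the last, possibly shorter, chunk is handled correctly.
        return fake_vae_decode(chunk, num_frames=chunk.shape[0]).cpu()

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        frames = list(executor.map(decode_chunk, starts))

    frames = torch.cat(frames, dim=0)
    # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
    return frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4).float()


if __name__ == "__main__":
    latents = torch.randn(2 * 16, 4, 8, 8)   # 2 clips x 16 frames of 4-channel latents
    video = decode_latents_threaded(latents, num_frames=16)
    print(video.shape)                        # torch.Size([2, 3, 16, 64, 64])

If the plain threading.Thread approach is kept instead, the same ordering guarantee can be had by preallocating the results list and having each thread write its chunk to its own index rather than appending.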