Spaces: Running on Zero

alex committed · 542f3d9
Parent(s): 1886860

progress bar fixed

Files changed:
- OmniAvatar/wan_video.py +6 -2
- app.py +46 -42
OmniAvatar/wan_video.py
CHANGED
@@ -223,7 +223,7 @@ class WanVideoPipeline(BasePipeline):
         tile_stride=(15, 26),
         tea_cache_l1_thresh=None,
         tea_cache_model_id="",
-        progress_bar_cmd=
+        progress_bar_cmd=None,
         return_latent=False,
     ):
         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -249,7 +249,7 @@ class WanVideoPipeline(BasePipeline):
 
         # Denoise
         self.load_models_to_device(["dit"])
-        for progress_id, timestep in enumerate(
+        for progress_id, timestep in enumerate(tqdm(self.scheduler.timesteps) if progress_bar_cmd is None else self.scheduler.timesteps ):
             if fixed_frame > 0: # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
@@ -273,6 +273,10 @@ class WanVideoPipeline(BasePipeline):
             noise_pred = noise_pred_posi
             # Scheduler
             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+            if progress_bar_cmd is not None:
+                progress_bar_cmd.update(1)
+
 
             if fixed_frame > 0: # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
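Taken together, the three hunks above switch the pipeline from always driving its own progress bar to accepting an optional caller-owned one: with progress_bar_cmd=None it still wraps its timesteps in tqdm, and when a bar is supplied it only calls update(1) once per denoising step. Below is a minimal, self-contained sketch of that callee-side pattern; ToyPipeline and its step count are made up for illustration and are not OmniAvatar code.

from tqdm import tqdm

class ToyPipeline:
    """Stand-in for a denoising pipeline that accepts an optional external progress bar."""

    def __init__(self, num_steps=25):
        self.timesteps = list(range(num_steps))

    def __call__(self, progress_bar_cmd=None):
        # No external bar supplied: fall back to a local tqdm (the new default).
        steps = tqdm(self.timesteps) if progress_bar_cmd is None else self.timesteps
        for progress_id, timestep in enumerate(steps):
            _ = (progress_id, timestep)  # denoising work for this step would go here
            # A shared bar was supplied by the caller: advance it one tick per step.
            if progress_bar_cmd is not None:
                progress_bar_cmd.update(1)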
app.py
CHANGED
@@ -11,6 +11,7 @@ import librosa
 import numpy as np
 import uuid
 import shutil
+from tqdm import tqdm
 
 import importlib, site, sys
 from huggingface_hub import hf_hub_download, snapshot_download
@@ -443,51 +444,54 @@ class WanInferencePipeline(nn.Module):
         msk[:, :, 1:] = 1
         image_emb["y"] = torch.cat([image_cat, msk], dim=1)
 
-
-
-
-
-
-
-                overlap = fixed_frame
-                image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
-            prefix_overlap = (3 + overlap) // 4
-            if audio_embeddings is not None:
+        total_iterations = times * num_steps
+
+        with tqdm(total=total_iterations) as pbar:
+            for t in range(times):
+                print(f"[{t+1}/{times}]")
+                audio_emb = {}
                 if t == 0:
-
-
+                    overlap = first_fixed_frame
+                else:
+                    overlap = fixed_frame
+                    image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
+                prefix_overlap = (3 + overlap) // 4
+                if audio_embeddings is not None:
+                    if t == 0:
+                        audio_tensor = audio_embeddings[
+                            :min(L - overlap, audio_embeddings.shape[0])
+                        ]
+                    else:
+                        audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
+                        audio_tensor = audio_embeddings[
+                            audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
                         ]
+
+                    audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
+                    audio_prefix = audio_tensor[-fixed_frame:]
+                    audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+                    audio_emb["audio_emb"] = audio_tensor
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-            img_lat =
-
-
-
-
-
-
-                tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
-
-            torch.cuda.empty_cache()
-            img_lat = None
-            image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
-
-            if t == 0:
-                video.append(frames)
-            else:
-                video.append(frames[:, overlap:])
+                    audio_prefix = None
+                if image is not None and img_lat is None:
+                    self.pipe.load_models_to_device(['vae'])
+                    img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
+                assert img_lat.shape[2] == prefix_overlap
+                img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
+                frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
+                                                         negative_prompt, num_inference_steps=num_steps,
+                                                         cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
+                                                         return_latent=True,
+                                                         tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
+
+                torch.cuda.empty_cache()
+                img_lat = None
+                image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
+
+                if t == 0:
+                    video.append(frames)
+                else:
+                    video.append(frames[:, overlap:])
         video = torch.cat(video, dim=1)
         video = video[:, :ori_audio_len + 1]
 
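On the app.py side, the chunked inference loop now sizes a single bar to times * num_steps and hands it to every log_video call, so progress advances across all chunks rather than relying on the pipeline's own per-call bar. A caller-side sketch under the same assumptions; denoise_chunk is a hypothetical stand-in for one pipeline call, not the real OmniAvatar API.

from tqdm import tqdm

def denoise_chunk(num_steps, progress_bar_cmd=None):
    """Hypothetical stand-in for one pipe.log_video(...) call running num_steps denoising steps."""
    steps = tqdm(range(num_steps)) if progress_bar_cmd is None else range(num_steps)
    for _ in steps:
        if progress_bar_cmd is not None:
            progress_bar_cmd.update(1)

times, num_steps = 4, 25                      # number of video chunks and steps per chunk (example values)
with tqdm(total=times * num_steps) as pbar:   # one bar covering every chunk
    for t in range(times):
        denoise_chunk(num_steps, progress_bar_cmd=pbar)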