alex committed
Commit 542f3d9 · 1 Parent(s): 1886860

progress bar fixed

Files changed (2)
  1. OmniAvatar/wan_video.py +6 -2
  2. app.py +46 -42
OmniAvatar/wan_video.py CHANGED
@@ -223,7 +223,7 @@ class WanVideoPipeline(BasePipeline):
         tile_stride=(15, 26),
         tea_cache_l1_thresh=None,
         tea_cache_model_id="",
-        progress_bar_cmd=tqdm,
+        progress_bar_cmd=None,
         return_latent=False,
     ):
         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -249,7 +249,7 @@ class WanVideoPipeline(BasePipeline):
 
         # Denoise
         self.load_models_to_device(["dit"])
-        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+        for progress_id, timestep in enumerate(tqdm(self.scheduler.timesteps) if progress_bar_cmd is None else self.scheduler.timesteps):
             if fixed_frame > 0:  # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
@@ -273,6 +273,10 @@ class WanVideoPipeline(BasePipeline):
                 noise_pred = noise_pred_posi
             # Scheduler
             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+            if progress_bar_cmd is not None:
+                progress_bar_cmd.update(1)
+
 
             if fixed_frame > 0:  # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
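
The change above replaces the hard-wired tqdm default with an optional progress_bar_cmd hook: when no bar is supplied, the denoising loop wraps the timesteps in its own tqdm as before; when a bar is passed in, the loop iterates plainly and calls update(1) once per step. A minimal sketch of that pattern, using hypothetical names (denoise, step_fn) that are not part of the OmniAvatar API:

from tqdm import tqdm

def denoise(timesteps, step_fn, progress_bar_cmd=None):
    # Wrap with a local tqdm only when the caller did not supply a bar.
    iterator = tqdm(timesteps) if progress_bar_cmd is None else timesteps
    for i, t in enumerate(iterator):
        step_fn(i, t)                      # one denoising step (caller-provided)
        if progress_bar_cmd is not None:
            progress_bar_cmd.update(1)     # advance the caller-owned shared bar
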
app.py CHANGED
@@ -11,6 +11,7 @@ import librosa
 import numpy as np
 import uuid
 import shutil
+from tqdm import tqdm
 
 import importlib, site, sys
 from huggingface_hub import hf_hub_download, snapshot_download
@@ -443,51 +444,54 @@ class WanInferencePipeline(nn.Module):
         msk[:, :, 1:] = 1
         image_emb["y"] = torch.cat([image_cat, msk], dim=1)
 
-        for t in range(times):
-            print(f"[{t+1}/{times}]")
-            audio_emb = {}
-            if t == 0:
-                overlap = first_fixed_frame
-            else:
-                overlap = fixed_frame
-                image_emb["y"][:, -1:, :prefix_lat_frame] = 0  # the first pass masks only frame 1; later passes mask the overlap frames
-            prefix_overlap = (3 + overlap) // 4
-            if audio_embeddings is not None:
-                if t == 0:
-                    audio_tensor = audio_embeddings[
-                        :min(L - overlap, audio_embeddings.shape[0])
-                    ]
-                else:
-                    audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
-                    audio_tensor = audio_embeddings[
-                        audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
-                    ]
-
-                audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
-                audio_prefix = audio_tensor[-fixed_frame:]
-                audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
-                audio_emb["audio_emb"] = audio_tensor
-            else:
-                audio_prefix = None
-            if image is not None and img_lat is None:
-                self.pipe.load_models_to_device(['vae'])
-                img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
-            assert img_lat.shape[2] == prefix_overlap
-            img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
-            frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
-                                                     negative_prompt, num_inference_steps=num_steps,
-                                                     cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
-                                                     return_latent=True,
-                                                     tea_cache_l1_thresh=self.args.tea_cache_l1_thresh, tea_cache_model_id="Wan2.1-T2V-14B")
-
-            torch.cuda.empty_cache()
-            img_lat = None
-            image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
-
-            if t == 0:
-                video.append(frames)
-            else:
-                video.append(frames[:, overlap:])
+        total_iterations = times * num_steps
+
+        with tqdm(total=total_iterations) as pbar:
+            for t in range(times):
+                print(f"[{t+1}/{times}]")
+                audio_emb = {}
+                if t == 0:
+                    overlap = first_fixed_frame
+                else:
+                    overlap = fixed_frame
+                    image_emb["y"][:, -1:, :prefix_lat_frame] = 0  # the first pass masks only frame 1; later passes mask the overlap frames
+                prefix_overlap = (3 + overlap) // 4
+                if audio_embeddings is not None:
+                    if t == 0:
+                        audio_tensor = audio_embeddings[
+                            :min(L - overlap, audio_embeddings.shape[0])
+                        ]
+                    else:
+                        audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
+                        audio_tensor = audio_embeddings[
+                            audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
+                        ]
+
+                    audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
+                    audio_prefix = audio_tensor[-fixed_frame:]
+                    audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+                    audio_emb["audio_emb"] = audio_tensor
+                else:
+                    audio_prefix = None
+                if image is not None and img_lat is None:
+                    self.pipe.load_models_to_device(['vae'])
+                    img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
+                assert img_lat.shape[2] == prefix_overlap
+                img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
+                frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
+                                                         negative_prompt, num_inference_steps=num_steps,
+                                                         cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
+                                                         return_latent=True,
+                                                         tea_cache_l1_thresh=self.args.tea_cache_l1_thresh, tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
+
+                torch.cuda.empty_cache()
+                img_lat = None
+                image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
+
+                if t == 0:
+                    video.append(frames)
+                else:
+                    video.append(frames[:, overlap:])
         video = torch.cat(video, dim=1)
         video = video[:, :ori_audio_len + 1]
 
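
On the app.py side, the segment loop is now wrapped in a single tqdm bar whose total is times * num_steps, and that bar is handed down via the new progress_bar_cmd argument, so every denoising step of every segment advances the same bar instead of restarting a bar per segment. A simplified caller-side sketch, assuming a hypothetical run_segment callable (not the actual app.py interface) that forwards pbar as progress_bar_cmd:

from tqdm import tqdm

def generate(times, num_steps, run_segment):
    # One bar sized segments * steps, shared across all video segments.
    with tqdm(total=times * num_steps) as pbar:
        for t in range(times):
            # run_segment is expected to call pbar.update(1) once per denoising
            # step, mirroring the progress_bar_cmd hook added in wan_video.py.
            run_segment(t, progress_bar_cmd=pbar)
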