Aduc-sdr committed
Commit f85ca57 · verified · 1 Parent(s): 09b1b8e

Update app.py

Files changed (1)
  1. app.py +87 -301
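The substance of the change, shown as a minimal sketch before the full diff (the repository name, URL, module names, and paths are taken from the hunks below; the standalone framing is illustrative, not the complete app.py): instead of os.chdir()-ing into the cloned Space and importing from '.', the updated script keeps the working directory, puts the clone on sys.path, and builds file paths with os.path.join(repo_dir, ...).

    # Sketch of the setup pattern adopted by this commit (not the complete app.py).
    import os
    import subprocess
    import sys

    repo_dir = "SeedVR2-3B"

    # Fetch the Space once; git-lfs is needed for the large checkpoint files.
    subprocess.run("git lfs install", shell=True, check=True)
    if not os.path.exists(repo_dir):
        subprocess.run(
            "git clone https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B",
            shell=True, check=True,
        )

    # Make the repo's packages (data, common, projects, ...) importable
    # without changing the current working directory.
    sys.path.insert(0, os.path.abspath(repo_dir))

    from common.config import load_config  # resolved from the clone via sys.path

    # File access stays explicit about the repository directory.
    config = load_config(os.path.join(repo_dir, "configs_3b", "main.yaml"))

This is why the hunks below replace './configs_3b', './ckpts', and 'pos_emb.pt' style paths with os.path.join(repo_dir, ...) equivalents.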
app.py CHANGED
@@ -14,73 +14,65 @@
  import spaces
  import subprocess
  import os
- import sys  # <-- ADDED TO MANIPULATE THE PYTHON PATH
 
- # Clone the repository to ensure all files are available
- # Make sure git-lfs is installed
  subprocess.run("git lfs install", shell=True, check=True)
- # Clone the repository only if it doesn't exist
  if not os.path.exists("SeedVR2-3B"):
      subprocess.run("git clone https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B", shell=True, check=True)
 
- # Define the repository directory
- repo_dir = 'SeedVR2-3B'
- # Change the current working directory to the cloned repository
- os.chdir(repo_dir)
- # Add the repository directory to the Python path to allow imports
- sys.path.insert(0, os.path.abspath('.'))  # <-- MAIN FIX HERE
 
  import torch
  import mediapy
  from einops import rearrange
  from omegaconf import OmegaConf
- print(os.getcwd())
  import datetime
  from tqdm import tqdm
  import gc
-
- from data.image.transforms.divisible_crop import DivisibleCrop
- from data.image.transforms.na_resize import NaResize
- from data.video.transforms.rearrange import Rearrange
- if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
-     from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
-     use_colorfix=True
- else:
-     use_colorfix = False
-     print('Note!!!!!! Color fix is not avaliable!')
- from torchvision.transforms import Compose, Lambda, Normalize
- from torchvision.io.video import read_video
- import argparse
  from PIL import Image
-
- from common.distributed import (
-     get_device,
-     init_torch,
- )
-
- from common.distributed.advanced import (
-     get_data_parallel_rank,
-     get_data_parallel_world_size,
-     get_sequence_parallel_rank,
-     get_sequence_parallel_world_size,
-     init_sequence_parallel,
- )
-
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
- from common.config import load_config
- from common.distributed.ops import sync_data
- from common.seed import set_seed
- from common.partition import partition_by_groups, partition_by_size
-
  import gradio as gr
  from pathlib import Path
- from urllib.parse import urlparse
- from torch.hub import download_url_to_file, get_dir
  import shlex
  import uuid
  import mimetypes
  import torchvision.transforms as T
 
  os.environ["MASTER_ADDR"] = "127.0.0.1"
  os.environ["MASTER_PORT"] = "12355"
  os.environ["RANK"] = str(0)
@@ -92,33 +84,34 @@ subprocess.run(
      shell=True,
  )
 
- # Install apex from the local wheel file
- if os.path.exists("apex-0.1-cp310-cp310-linux_x86_64.whl"):
-     subprocess.run(shlex.split("pip install apex-0.1-cp310-cp310-linux_x86_64.whl"))
-     print(f"✅ setup completed Apex")
 
 
  def configure_sequence_parallel(sp_size):
      if sp_size > 1:
          init_sequence_parallel(sp_size)
 
-
  def configure_runner(sp_size):
-     config_path = os.path.join('./configs_3b', 'main.yaml')
      config = load_config(config_path)
      runner = VideoDiffusionInfer(config)
      OmegaConf.set_readonly(runner.config, False)
 
      init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
      configure_sequence_parallel(sp_size)
-     runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
      runner.configure_vae_model()
-     # Set memory limit.
      if hasattr(runner.vae, "set_memory_limit"):
          runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
      return runner
 
-
  def generation_step(runner, text_embeds_dict, cond_latents):
      def _move_to_cuda(x):
          return [i.to(torch.device("cuda")) for i in x]
@@ -127,68 +120,44 @@ def generation_step(runner, text_embeds_dict, cond_latents):
      aug_noises = [torch.randn_like(latent) for latent in cond_latents]
      print(f"Generating with noise shape: {noises[0].size()}.")
      noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
-     noises, aug_noises, cond_latents = list(
-         map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents))
-     )
      cond_noise_scale = 0.1
 
      def _add_noise(x, aug_noise):
-         t = (
-             torch.tensor([1000.0], device=torch.device("cuda"))
-             * cond_noise_scale
-         )
          shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
          t = runner.timestep_transform(t, shape)
-         print(
-             f"Timestep shifting from"
-             f" {1000.0 * cond_noise_scale} to {t}."
-         )
          x = runner.schedule.forward(x, aug_noise, t)
          return x
 
      conditions = [
-         runner.get_condition(
-             noise,
-             task="sr",
-             latent_blur=_add_noise(latent_blur, aug_noise),
-         )
          for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
      ]
 
      with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
          video_tensors = runner.inference(
-             noises=noises,
-             conditions=conditions,
-             dit_offload=False,
-             **text_embeds_dict,
          )
 
-     samples = [
-         (
-             rearrange(video[:, None], "c t h w -> t c h w")
-             if video.ndim == 3
-             else rearrange(video, "c t h w -> t c h w")
-         )
-         for video in video_tensors
-     ]
      del video_tensors
-
      return samples
 
-
- def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
      runner = configure_runner(1)
 
      def _extract_text_embeds():
-         # Text encoder forward.
          positive_prompts_embeds = []
-         for texts_pos in tqdm(original_videos_local):
-             text_pos_embeds = torch.load('pos_emb.pt')
-             text_neg_embeds = torch.load('neg_emb.pt')
-
-             positive_prompts_embeds.append(
-                 {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
-             )
          gc.collect()
          torch.cuda.empty_cache()
          return positive_prompts_embeds
@@ -198,239 +167,56 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
          videos = videos[:, :121]
          t = videos.size(1)
          if t <= 4 * sp_size:
-             print(f"Cut input video size: {videos.size()}")
-             padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1)
-             padding = torch.cat(padding, dim=1)
-             videos = torch.cat([videos, padding], dim=1)
              return videos
          if (t - 1) % (4 * sp_size) == 0:
              return videos
          else:
-             padding = [videos[:, -1].unsqueeze(1)] * (
-                 4 * sp_size - ((t - 1) % (4 * sp_size))
-             )
-             padding = torch.cat(padding, dim=1)
              videos = torch.cat([videos, padding], dim=1)
              assert (videos.size(1) - 1) % (4 * sp_size) == 0
              return videos
 
-     # classifier-free guidance
      runner.config.diffusion.cfg.scale = cfg_scale
      runner.config.diffusion.cfg.rescale = cfg_rescale
-     # sampling steps
      runner.config.diffusion.timesteps.sampling.steps = sample_steps
      runner.configure_diffusion()
 
-     # set random seed
-     seed = seed % (2**32)  # avoid over range
      set_seed(seed, same_across_ranks=True)
-     os.makedirs('output/', exist_ok=True)
 
-     # get test prompts
      original_videos = [os.path.basename(video_path)]
-
-     # divide the prompts into different groups
-     original_videos_group = original_videos
-     # store prompt mapping
-     original_videos_local = original_videos_group
-     original_videos_local = partition_by_size(original_videos_local, batch_size)
-
-     # pre-extract the text embeddings
      positive_prompts_embeds = _extract_text_embeds()
 
-     video_transform = Compose(
-         [
-             NaResize(
-                 resolution=(
-                     res_h * res_w
-                 )
-                 ** 0.5,
-                 mode="area",
-                 # Upsample image, model only trained for high res.
-                 downsample_only=False,
-             ),
-             Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
-             DivisibleCrop((16, 16)),
-             Normalize(0.5, 0.5),
-             Rearrange("t c h w -> c t h w"),
-         ]
-     )
 
-     # generation loop
      for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
-         # read condition latents
          cond_latents = []
-         for video in videos:
              media_type, _ = mimetypes.guess_type(video_path)
              is_image = media_type and media_type.startswith("image")
              is_video = media_type and media_type.startswith("video")
              if is_video:
-                 video = (
-                     read_video(
-                         video_path, output_format="TCHW"
-                     )[0]
-                     / 255.0
-                 )
                  if video.size(0) > 121:
                      video = video[:121]
                  print(f"Read video size: {video.size()}")
-                 output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
-             else:
                  img = Image.open(video_path).convert("RGB")
-                 img_tensor = T.ToTensor()(img).unsqueeze(0)  # (1, C, H, W)
-                 video = img_tensor.permute(0, 1, 2, 3)  # (T=1, C, H, W)
-                 print(f"Read Image size: {video.size()}")
-                 output_dir = 'output/' + str(uuid.uuid4()) + '.png'
-             cond_latents.append(video_transform(video.to(torch.device("cuda"))))
-
-         ori_lengths = [video.size(1) for video in cond_latents]
-         input_videos = cond_latents
-         if is_video:
-             cond_latents = [cut_videos(video, sp_size) for video in cond_latents]
-
-         print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}")
-         cond_latents = runner.vae_encode(cond_latents)
-
-         for i, emb in enumerate(text_embeds["texts_pos"]):
-             text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
-         for i, emb in enumerate(text_embeds["texts_neg"]):
-             text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
-
-         samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
-         del cond_latents
-
-         # dump samples to the output directory
-         for path, input, sample, ori_length in zip(
-             videos, input_videos, samples, ori_lengths
-         ):
-             if ori_length < sample.shape[0]:
-                 sample = sample[:ori_length]
-             # color fix
-             input = (
-                 rearrange(input[:, None], "c t h w -> t c h w")
-                 if input.ndim == 3
-                 else rearrange(input, "c t h w -> t c h w")
-             )
-             if use_colorfix:
-                 sample = wavelet_reconstruction(
-                     sample.to("cpu"), input[: sample.size(0)].to("cpu")
-                 )
-             else:
-                 sample = sample.to("cpu")
-             sample = (
-                 rearrange(sample[:, None], "t c h w -> t h w c")
-                 if sample.ndim == 3
-                 else rearrange(sample, "t c h w -> t h w c")
-             )
-             sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
-             sample = sample.to(torch.uint8).numpy()
-
-             if is_image:
-                 mediapy.write_image(output_dir, sample[0])
-             else:
-                 mediapy.write_video(
-                     output_dir, sample, fps=fps_out
-                 )
-
-     gc.collect()
-     torch.cuda.empty_cache()
-     if is_image:
-         return output_dir, None, output_dir
-     else:
-         return None, output_dir, output_dir
-
-
- with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
-     # Top logo and title
-     gr.HTML("""
-     <div style='text-align:center; margin-bottom: 10px;'>
-         <img src='assets/seedvr_logo.png' style='height:40px;' alt='SeedVR logo'/>
-     </div>
-     <p><b>Official Gradio demo</b> for
-     <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
-     <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
-     🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
-     </p>
-     """)
-
-     # Interface
-     with gr.Row():
-         input_video = gr.File(label="Upload image or video", type="filepath")
-         seed = gr.Number(label="Seeds", value=666)
-         fps = gr.Number(label="fps", value=24)
-
-     with gr.Row():
-         output_video = gr.Video(label="Output_Video")
-         output_image = gr.Image(label="Output_Image")
-         download_link = gr.File(label="Download the output")
-
-     run_button = gr.Button("Run")
-     run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_image, output_video, download_link])
-
-     # Examples
-     gr.Examples(
-         examples=[
-             ["01.mp4", 4, 24],
-             ["02.mp4", 4, 24],
-             ["03.mp4", 4, 24],
-         ],
-         inputs=[input_video, seed, fps]
-     )
-
-     # Article/Footer
-     gr.HTML("""
-     <hr>
-     <p>If you find SeedVR helpful, please ⭐ the
-     <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:</p>
-
-     <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank">
-         <img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars">
-     </a>
-
-     <h4>Notice</h4>
-     <p>This demo supports up to <b>720p and 121 frames for videos or 2k images</b>.
-     For other use cases (image restoration beyond 2K, video resolutions beyond 720p, etc), check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>
-
-     <h4>Limitations</h4>
-     <p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>
-
-     <h4>Citation</h4>
-     <pre style="font-size: 12px;">
-     @article{wang2025seedvr2,
-         title={SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training},
-         author={Wang, Jianyi and Lin, Shanchuan and Lin, Zhijie and Ren, Yuxi and Wei, Meng and Yue, Zongsheng and Zhou, Shangchen and Chen, Hao and Zhao, Yang and Yang, Ceyuan and Xiao, Xuefeng and Loy, Chen Change and Jiang, Lu},
-         booktitle={arXiv preprint arXiv:2506.05301},
-         year={2025}
-     }
-
-     @inproceedings{wang2025seedvr,
-         title={SeedVR: Seeding Infinity in Diffusion Transformer Towards Generic Video Restoration},
-         author={Wang, Jianyi and Lin, Zhijie and Wei, Meng and Zhao, Yang and Yang, Ceyuan and Loy, Chen Change and Jiang, Lu},
-         booktitle={CVPR},
-         year={2025}
-     }
-     </pre>
-
-     <h4>License</h4>
-     <p>Licensed under the
-     <a href="http://www.apache.org/licenses/LICENSE-2.0" target="_blank">Apache 2.0 License</a>.</p>
-
-     <h4>Contact</h4>
-     <p>Email: <b>[email protected]</b></p>
-
-     <p>
-     <a href="https://twitter.com/Iceclearwjy">
-         <img src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow">
-     </a>
-     <a href="https://github.com/IceClear">
-         <img src="https://img.shields.io/github/followers/IceClear?style=social" alt="GitHub Follow">
-     </a>
-     </p>
-
-     <p style="text-align:center;">
-         <img src="https://visitor-badge.laobi.icu/badge?page_id=ByteDance-Seed/SeedVR" alt="visitors">
-     </p>
-     """)
-
- demo.queue()
- demo.launch()
 
@@ -14,73 +14,65 @@
  import spaces
  import subprocess
  import os
+ import sys
 
+ # --- Setup: Clone repository and add it to Python Path ---
+ # This section ensures all necessary code and model files are available.
+
+ # 1. Clone the repository with all its files
  subprocess.run("git lfs install", shell=True, check=True)
  if not os.path.exists("SeedVR2-3B"):
+     print("Cloning SeedVR2-3B repository...")
      subprocess.run("git clone https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B", shell=True, check=True)
 
+ # 2. Add the cloned repository's directory to Python's module search path
+ repo_dir = "SeedVR2-3B"
+ # This allows us to import modules like 'data', 'common', etc., from the cloned repo.
+ sys.path.insert(0, os.path.abspath(repo_dir))
+ print(f"Repository directory '{os.path.abspath(repo_dir)}' added to Python path.")
+
+ # --- Main Application Code ---
+ # All file paths will now be relative to the cloned repository directory.
 
  import torch
  import mediapy
  from einops import rearrange
  from omegaconf import OmegaConf
  import datetime
  from tqdm import tqdm
  import gc
  from PIL import Image
  import gradio as gr
  from pathlib import Path
  import shlex
  import uuid
  import mimetypes
  import torchvision.transforms as T
+ from torchvision.transforms import Compose, Lambda, Normalize
+ from torchvision.io.video import read_video
 
+ # Imports from the cloned repository
+ from data.image.transforms.divisible_crop import DivisibleCrop
+ from data.image.transforms.na_resize import NaResize
+ from data.video.transforms.rearrange import Rearrange
+ from common.config import load_config
+ from common.distributed import init_torch
+ from common.distributed.advanced import init_sequence_parallel
+ from common.seed import set_seed
+ from common.partition import partition_by_size
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
+ from common.distributed.ops import sync_data
+
+ # Check for color_fix utility
+ color_fix_path = os.path.join(repo_dir, "projects/video_diffusion_sr/color_fix.py")
+ if os.path.exists(color_fix_path):
+     from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
+     use_colorfix = True
+ else:
+     use_colorfix = False
+     print('Note!!!!!! Color fix is not available!')
+
+ # --- Environment and Dependencies Setup ---
  os.environ["MASTER_ADDR"] = "127.0.0.1"
  os.environ["MASTER_PORT"] = "12355"
  os.environ["RANK"] = str(0)
 
@@ -92,33 +84,34 @@ subprocess.run(
      shell=True,
  )
 
+ apex_wheel_path = os.path.join(repo_dir, "apex-0.1-cp310-cp310-linux_x86_64.whl")
+ if os.path.exists(apex_wheel_path):
+     subprocess.run(shlex.split(f"pip install {apex_wheel_path}"))
+     print("✅ Apex setup completed.")
 
+ # --- Core Functions ---
 
  def configure_sequence_parallel(sp_size):
      if sp_size > 1:
          init_sequence_parallel(sp_size)
 
  def configure_runner(sp_size):
+     config_path = os.path.join(repo_dir, 'configs_3b', 'main.yaml')
+     checkpoint_path = os.path.join(repo_dir, 'ckpts', 'seedvr2_ema_3b.pth')
+
      config = load_config(config_path)
      runner = VideoDiffusionInfer(config)
      OmegaConf.set_readonly(runner.config, False)
 
      init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
      configure_sequence_parallel(sp_size)
+     runner.configure_dit_model(device="cuda", checkpoint=checkpoint_path)
      runner.configure_vae_model()
+
      if hasattr(runner.vae, "set_memory_limit"):
          runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
      return runner
 
  def generation_step(runner, text_embeds_dict, cond_latents):
      def _move_to_cuda(x):
          return [i.to(torch.device("cuda")) for i in x]
 
@@ -127,68 +120,44 @@ def generation_step(runner, text_embeds_dict, cond_latents):
      aug_noises = [torch.randn_like(latent) for latent in cond_latents]
      print(f"Generating with noise shape: {noises[0].size()}.")
      noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
+     noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
      cond_noise_scale = 0.1
 
      def _add_noise(x, aug_noise):
+         t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
          shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
          t = runner.timestep_transform(t, shape)
+         print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {t}.")
          x = runner.schedule.forward(x, aug_noise, t)
          return x
 
      conditions = [
+         runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
          for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
      ]
 
      with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
          video_tensors = runner.inference(
+             noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict
          )
 
+     samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
      del video_tensors
      return samples
 
+ @spaces.GPU
+ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
+     if video_path is None:
+         return None, None, None
+
      runner = configure_runner(1)
 
      def _extract_text_embeds():
          positive_prompts_embeds = []
+         for _ in original_videos_local:
+             text_pos_embeds = torch.load(os.path.join(repo_dir, 'pos_emb.pt'))
+             text_neg_embeds = torch.load(os.path.join(repo_dir, 'neg_emb.pt'))
+             positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
          gc.collect()
          torch.cuda.empty_cache()
          return positive_prompts_embeds
 
@@ -198,239 +167,56 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
          videos = videos[:, :121]
          t = videos.size(1)
          if t <= 4 * sp_size:
+             padding_needed = 4 * sp_size - t + 1
+             if padding_needed > 0:
+                 padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
+                 videos = torch.cat([videos, padding], dim=1)
              return videos
          if (t - 1) % (4 * sp_size) == 0:
              return videos
          else:
+             padding_needed = 4 * sp_size - ((t - 1) % (4 * sp_size))
+             padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
              videos = torch.cat([videos, padding], dim=1)
              assert (videos.size(1) - 1) % (4 * sp_size) == 0
              return videos
 
      runner.config.diffusion.cfg.scale = cfg_scale
      runner.config.diffusion.cfg.rescale = cfg_rescale
      runner.config.diffusion.timesteps.sampling.steps = sample_steps
      runner.configure_diffusion()
 
+     seed = int(seed) % (2**32)
      set_seed(seed, same_across_ranks=True)
+     output_base_dir = "output"
+     os.makedirs(output_base_dir, exist_ok=True)
 
      original_videos = [os.path.basename(video_path)]
+     original_videos_local = partition_by_size(original_videos, batch_size)
      positive_prompts_embeds = _extract_text_embeds()
 
+     video_transform = Compose([
+         NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
+         Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
+         DivisibleCrop((16, 16)),
+         Normalize(0.5, 0.5),
+         Rearrange("t c h w -> c t h w"),
+     ])
 
      for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
          cond_latents = []
+         for _ in videos:
              media_type, _ = mimetypes.guess_type(video_path)
              is_image = media_type and media_type.startswith("image")
             is_video = media_type and media_type.startswith("video")
+
              if is_video:
+                 video, _, _ = read_video(video_path, output_format="TCHW")
+                 video = video / 255.0
                  if video.size(0) > 121:
                      video = video[:121]
                  print(f"Read video size: {video.size()}")
+                 output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
+             elif is_image:
                  img = Image.open(video_path).convert("RGB")
+                 img_tensor = T.ToTensor()(img).uns