Aduc-sdr commited on
Commit
1d6758a
·
verified ·
1 Parent(s): 206f0da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -29
app.py CHANGED
@@ -16,23 +16,23 @@ import subprocess
16
  import os
17
  import sys
18
 
19
- # --- Setup: Clone repository and add it to Python Path ---
20
- # This section ensures all necessary code and model files are available.
21
 
22
  # 1. Clone the repository with all its files
23
  subprocess.run("git lfs install", shell=True, check=True)
24
- if not os.path.exists("SeedVR2-3B"):
25
- print("Cloning SeedVR2-3B repository...")
26
- subprocess.run("git clone https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B", shell=True, check=True)
 
27
 
28
- # 2. Add the cloned repository's directory to Python's module search path
29
- repo_dir = "SeedVR2-3B"
30
- # This allows us to import modules like 'data', 'common', etc., from the cloned repo.
31
- sys.path.insert(0, os.path.abspath(repo_dir))
32
- print(f"Repository directory '{os.path.abspath(repo_dir)}' added to Python path.")
33
 
34
  # --- Main Application Code ---
35
- # All file paths will now be relative to the cloned repository directory.
36
 
37
  import torch
38
  import mediapy
@@ -51,7 +51,7 @@ import torchvision.transforms as T
51
  from torchvision.transforms import Compose, Lambda, Normalize
52
  from torchvision.io.video import read_video
53
 
54
- # Imports from the cloned repository
55
  from data.image.transforms.divisible_crop import DivisibleCrop
56
  from data.image.transforms.na_resize import NaResize
57
  from data.video.transforms.rearrange import Rearrange
@@ -63,9 +63,8 @@ from common.partition import partition_by_size
63
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
64
  from common.distributed.ops import sync_data
65
 
66
- # Check for color_fix utility
67
- color_fix_path = os.path.join(repo_dir, "projects/video_diffusion_sr/color_fix.py")
68
- if os.path.exists(color_fix_path):
69
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
70
  use_colorfix = True
71
  else:
@@ -78,7 +77,7 @@ os.environ["MASTER_PORT"] = "12355"
78
  os.environ["RANK"] = str(0)
79
  os.environ["WORLD_SIZE"] = str(1)
80
 
81
- # CORREÇÃO: Usar sys.executable para chamar o pip corretamente
82
  python_executable = sys.executable
83
  subprocess.run(
84
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
@@ -86,9 +85,8 @@ subprocess.run(
86
  check=True
87
  )
88
 
89
- apex_wheel_path = os.path.join(repo_dir, "apex-0.1-cp310-cp310-linux_x86_64.whl")
90
  if os.path.exists(apex_wheel_path):
91
- # CORREÇÃO: Usar sys.executable aqui também
92
  subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
93
  print("✅ Apex setup completed.")
94
 
@@ -99,10 +97,11 @@ def configure_sequence_parallel(sp_size):
99
  init_sequence_parallel(sp_size)
100
 
101
  def configure_runner(sp_size):
102
- config_path = os.path.join(repo_dir, 'configs_3b', 'main.yaml')
103
- checkpoint_path = os.path.join(repo_dir, 'ckpts', 'seedvr2_ema_3b.pth')
 
104
 
105
- config = load_config(config_path)
106
  runner = VideoDiffusionInfer(config)
107
  OmegaConf.set_readonly(runner.config, False)
108
 
@@ -158,8 +157,9 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
158
  def _extract_text_embeds():
159
  positive_prompts_embeds = []
160
  for _ in original_videos_local:
161
- text_pos_embeds = torch.load(os.path.join(repo_dir, 'pos_emb.pt'))
162
- text_neg_embeds = torch.load(os.path.join(repo_dir, 'neg_emb.pt'))
 
163
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
164
  gc.collect()
165
  torch.cuda.empty_cache()
@@ -276,13 +276,12 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
276
  # --- Gradio UI ---
277
 
278
  with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
279
- logo_path = os.path.join(repo_dir, "assets/seedvr_logo.png")
 
280
  gr.HTML(f"""
281
- <div style='text-align:center; margin-bottom: 10px;'>
282
- <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
283
- </div>
284
- <p><b>Official Gradio demo</b> for <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
285
- 🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.</p>
286
  """)
287
 
288
  with gr.Row():
 
16
  import os
17
  import sys
18
 
19
+ # --- Setup: Clone repository and Change Working Directory ---
20
+ # This is the most robust way to ensure all relative paths work correctly.
21
 
22
  # 1. Clone the repository with all its files
23
  subprocess.run("git lfs install", shell=True, check=True)
24
+ repo_dir_name = "SeedVR2-3B"
25
+ if not os.path.exists(repo_dir_name):
26
+ print(f"Cloning {repo_dir_name} repository...")
27
+ subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
28
 
29
+ # 2. Change the current working directory to the repository's root
30
+ # CORREÇÃO PRINCIPAL: Isso resolve todos os problemas de caminho relativo.
31
+ os.chdir(repo_dir_name)
32
+ print(f"Changed working directory to: {os.getcwd()}")
 
33
 
34
  # --- Main Application Code ---
35
+ # Now that we are inside the repo, all imports and file loads will work naturally.
36
 
37
  import torch
38
  import mediapy
 
51
  from torchvision.transforms import Compose, Lambda, Normalize
52
  from torchvision.io.video import read_video
53
 
54
+ # Imports from the repository (will now work directly)
55
  from data.image.transforms.divisible_crop import DivisibleCrop
56
  from data.image.transforms.na_resize import NaResize
57
  from data.video.transforms.rearrange import Rearrange
 
63
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
64
  from common.distributed.ops import sync_data
65
 
66
+ # Check for color_fix utility (using relative path)
67
+ if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
 
68
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
69
  use_colorfix = True
70
  else:
 
77
  os.environ["RANK"] = str(0)
78
  os.environ["WORLD_SIZE"] = str(1)
79
 
80
+ # Use sys.executable to ensure we use the correct pip
81
  python_executable = sys.executable
82
  subprocess.run(
83
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
 
85
  check=True
86
  )
87
 
88
+ apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
89
  if os.path.exists(apex_wheel_path):
 
90
  subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
91
  print("✅ Apex setup completed.")
92
 
 
97
  init_sequence_parallel(sp_size)
98
 
99
  def configure_runner(sp_size):
100
+ # Paths are now simple and relative to the repo root
101
+ config_path = 'configs_3b/main.yaml'
102
+ checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
103
 
104
+ config = load_config(config_path) # This will now work correctly
105
  runner = VideoDiffusionInfer(config)
106
  OmegaConf.set_readonly(runner.config, False)
107
 
 
157
  def _extract_text_embeds():
158
  positive_prompts_embeds = []
159
  for _ in original_videos_local:
160
+ # Paths are now simple
161
+ text_pos_embeds = torch.load('pos_emb.pt')
162
+ text_neg_embeds = torch.load('neg_emb.pt')
163
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
164
  gc.collect()
165
  torch.cuda.empty_cache()
 
276
  # --- Gradio UI ---
277
 
278
  with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
279
+ # Use an absolute path for the Gradio file source to be safe
280
+ logo_path = os.path.abspath("assets/seedvr_logo.png")
281
  gr.HTML(f"""
282
+
283
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
284
+ 🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
 
 
285
  """)
286
 
287
  with gr.Row():