Aduc-sdr committed on
Commit
ea7dfbd
·
verified ·
1 Parent(s): 1d6758a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -43
app.py CHANGED
@@ -16,23 +16,28 @@ import subprocess
16
  import os
17
  import sys
18
 
19
- # --- Setup: Clone repository and Change Working Directory ---
20
- # This is the most robust way to ensure all relative paths work correctly.
21
 
22
- # 1. Clone the repository with all its files
23
- subprocess.run("git lfs install", shell=True, check=True)
24
  repo_dir_name = "SeedVR2-3B"
25
  if not os.path.exists(repo_dir_name):
26
- print(f"Cloning {repo_dir_name} repository...")
27
  subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
28
 
29
- # 2. Change the current working directory to the repository's root
30
- # CORREÇÃO PRINCIPAL: Isso resolve todos os problemas de caminho relativo.
31
  os.chdir(repo_dir_name)
32
- print(f"Changed working directory to: {os.getcwd()}")
33
 
34
- # --- Main Application Code ---
35
- # Now that we are inside the repo, all imports and file loads will work naturally.
 
 
 
 
 
 
36
 
37
  import torch
38
  import mediapy
@@ -51,7 +56,7 @@ import torchvision.transforms as T
51
  from torchvision.transforms import Compose, Lambda, Normalize
52
  from torchvision.io.video import read_video
53
 
54
- # Imports from the repository (will now work directly)
55
  from data.image.transforms.divisible_crop import DivisibleCrop
56
  from data.image.transforms.na_resize import NaResize
57
  from data.video.transforms.rearrange import Rearrange
@@ -63,21 +68,21 @@ from common.partition import partition_by_size
63
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
64
  from common.distributed.ops import sync_data
65
 
66
- # Check for color_fix utility (using relative path)
67
  if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
68
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
69
  use_colorfix = True
70
  else:
71
  use_colorfix = False
72
- print('Note!!!!!! Color fix is not available!')
73
 
74
- # --- Environment and Dependencies Setup ---
75
  os.environ["MASTER_ADDR"] = "127.0.0.1"
76
  os.environ["MASTER_PORT"] = "12355"
77
  os.environ["RANK"] = str(0)
78
  os.environ["WORLD_SIZE"] = str(1)
79
 
80
- # Use sys.executable to ensure we use the correct pip
81
  python_executable = sys.executable
82
  subprocess.run(
83
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
@@ -88,20 +93,20 @@ subprocess.run(
88
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
89
  if os.path.exists(apex_wheel_path):
90
  subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
91
- print("✅ Apex setup completed.")
92
 
93
- # --- Core Functions ---
94
 
95
  def configure_sequence_parallel(sp_size):
96
  if sp_size > 1:
97
  init_sequence_parallel(sp_size)
98
 
99
  def configure_runner(sp_size):
100
- # Paths are now simple and relative to the repo root
101
  config_path = 'configs_3b/main.yaml'
102
  checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
103
 
104
- config = load_config(config_path) # This will now work correctly
105
  runner = VideoDiffusionInfer(config)
106
  OmegaConf.set_readonly(runner.config, False)
107
 
@@ -120,7 +125,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
120
 
121
  noises = [torch.randn_like(latent) for latent in cond_latents]
122
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
123
- print(f"Generating with noise shape: {noises[0].size()}.")
124
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
125
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
126
  cond_noise_scale = 0.1
@@ -129,7 +134,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
129
  t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
130
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
131
  t = runner.timestep_transform(t, shape)
132
- print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {t}.")
133
  x = runner.schedule.forward(x, aug_noise, t)
134
  return x
135
 
@@ -157,7 +162,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
157
  def _extract_text_embeds():
158
  positive_prompts_embeds = []
159
  for _ in original_videos_local:
160
- # Paths are now simple
161
  text_pos_embeds = torch.load('pos_emb.pt')
162
  text_neg_embeds = torch.load('neg_emb.pt')
163
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
@@ -218,16 +223,16 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
218
  video = video / 255.0
219
  if video.size(0) > 121:
220
  video = video[:121]
221
- print(f"Read video size: {video.size()}")
222
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
223
  elif is_image:
224
  img = Image.open(video_path).convert("RGB")
225
  img_tensor = T.ToTensor()(img).unsqueeze(0)
226
  video = img_tensor
227
- print(f"Read Image size: {video.size()}")
228
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
229
  else:
230
- raise ValueError("Unsupported file type")
231
 
232
  cond_latents.append(video_transform(video.to(torch.device("cuda"))))
233
 
@@ -236,7 +241,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
236
  if is_video:
237
  cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
238
 
239
- print(f"Encoding videos: {[v.size() for v in cond_latents]}")
240
  cond_latents = runner.vae_encode(cond_latents)
241
 
242
  for i, emb in enumerate(text_embeds["texts_pos"]):
@@ -273,41 +278,43 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
273
  else:
274
  return None, output_dir, output_dir
275
 
276
- # --- Gradio UI ---
277
 
278
- with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
279
- # Use an absolute path for the Gradio file source to be safe
280
  logo_path = os.path.abspath("assets/seedvr_logo.png")
281
  gr.HTML(f"""
282
-
283
- <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
284
- 🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
 
 
285
  """)
286
 
287
  with gr.Row():
288
- input_file = gr.File(label="Upload image or video", type="filepath")
289
  with gr.Column():
290
  seed = gr.Number(label="Seed", value=666)
291
- fps = gr.Number(label="Output FPS (for video)", value=24)
292
 
293
- run_button = gr.Button("Run")
294
 
295
  with gr.Row():
296
- output_image = gr.Image(label="Output Image")
297
- output_video = gr.Video(label="Output Video")
298
 
299
- download_link = gr.File(label="Download the output")
300
 
301
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
302
 
303
  gr.HTML("""
304
  <hr>
305
- <p>If you find SeedVR helpful, please ⭐ the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:
306
  <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
307
- <h4>Notice</h4>
308
- <p>This demo supports up to <b>720p and 121 frames for videos or 2k images</b>. For other use cases, check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>
309
- <h4>Limitations</h4>
310
- <p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>
311
  """)
312
 
313
  demo.queue().launch(share=True)
 
16
  import os
17
  import sys
18
 
19
+ # --- Setup: Clone repository, Change Directory, and Update Python Path ---
20
+ # Esta é a abordagem definitiva para corrigir todos os problemas de caminho.
21
 
22
+ # 1. Clone o repositório
 
23
  repo_dir_name = "SeedVR2-3B"
24
  if not os.path.exists(repo_dir_name):
25
+ print(f"Clonando o repositório {repo_dir_name}...")
26
  subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
27
 
28
+ # 2. Mude o diretório de trabalho atual para a raiz do repositório.
29
+ # Isso corrige o acesso a arquivos relativos (ex: carregar config.yaml).
30
  os.chdir(repo_dir_name)
31
+ print(f"Diretório de trabalho alterado para: {os.getcwd()}")
32
 
33
+ # 3. Adicione explicitamente o novo diretório de trabalho ao caminho do sistema do Python.
34
+ # Isso corrige as importações de módulos (ex: `from data...`).
35
+ sys.path.insert(0, os.path.abspath('.'))
36
+ print(f"Diretório atual adicionado ao sys.path: {os.path.abspath('.')}")
37
+
38
+
39
+ # --- Código Principal da Aplicação ---
40
+ # Agora, todas as importações e cargas de arquivos devem funcionar corretamente.
41
 
42
  import torch
43
  import mediapy
 
56
  from torchvision.transforms import Compose, Lambda, Normalize
57
  from torchvision.io.video import read_video
58
 
59
+ # Importações do repositório (agora funcionarão)
60
  from data.image.transforms.divisible_crop import DivisibleCrop
61
  from data.image.transforms.na_resize import NaResize
62
  from data.video.transforms.rearrange import Rearrange
 
68
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
69
  from common.distributed.ops import sync_data
70
 
71
+ # Verifica o utilitário color_fix (usando caminho relativo)
72
  if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
73
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
74
  use_colorfix = True
75
  else:
76
  use_colorfix = False
77
+ print('Atenção!!!!!! A correção de cor não está disponível!')
78
 
79
+ # --- Configuração de Ambiente e Dependências ---
80
  os.environ["MASTER_ADDR"] = "127.0.0.1"
81
  os.environ["MASTER_PORT"] = "12355"
82
  os.environ["RANK"] = str(0)
83
  os.environ["WORLD_SIZE"] = str(1)
84
 
85
+ # Use sys.executable para garantir que estamos usando o pip correto
86
  python_executable = sys.executable
87
  subprocess.run(
88
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
 
93
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
94
  if os.path.exists(apex_wheel_path):
95
  subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
96
+ print("✅ Configuração do Apex concluída.")
97
 
98
+ # --- Funções Principais ---
99
 
100
  def configure_sequence_parallel(sp_size):
101
  if sp_size > 1:
102
  init_sequence_parallel(sp_size)
103
 
104
  def configure_runner(sp_size):
105
+ # Os caminhos agora são simples e relativos à raiz do repositório
106
  config_path = 'configs_3b/main.yaml'
107
  checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
108
 
109
+ config = load_config(config_path) # Isto agora funcionará corretamente
110
  runner = VideoDiffusionInfer(config)
111
  OmegaConf.set_readonly(runner.config, False)
112
 
 
125
 
126
  noises = [torch.randn_like(latent) for latent in cond_latents]
127
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
128
+ print(f"Gerando com o formato de ruído: {noises[0].size()}.")
129
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
130
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
131
  cond_noise_scale = 0.1
 
134
  t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
135
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
136
  t = runner.timestep_transform(t, shape)
137
+ print(f"Deslocamento de Timestep de {1000.0 * cond_noise_scale} para {t}.")
138
  x = runner.schedule.forward(x, aug_noise, t)
139
  return x
140
 
 
162
  def _extract_text_embeds():
163
  positive_prompts_embeds = []
164
  for _ in original_videos_local:
165
+ # Os caminhos agora são simples
166
  text_pos_embeds = torch.load('pos_emb.pt')
167
  text_neg_embeds = torch.load('neg_emb.pt')
168
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
 
223
  video = video / 255.0
224
  if video.size(0) > 121:
225
  video = video[:121]
226
+ print(f"Tamanho do vídeo lido: {video.size()}")
227
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
228
  elif is_image:
229
  img = Image.open(video_path).convert("RGB")
230
  img_tensor = T.ToTensor()(img).unsqueeze(0)
231
  video = img_tensor
232
+ print(f"Tamanho da imagem lida: {video.size()}")
233
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
234
  else:
235
+ raise ValueError("Tipo de arquivo não suportado")
236
 
237
  cond_latents.append(video_transform(video.to(torch.device("cuda"))))
238
 
 
241
  if is_video:
242
  cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
243
 
244
+ print(f"Codificando vídeos: {[v.size() for v in cond_latents]}")
245
  cond_latents = runner.vae_encode(cond_latents)
246
 
247
  for i, emb in enumerate(text_embeds["texts_pos"]):
 
278
  else:
279
  return None, output_dir, output_dir
280
 
281
+ # --- UI do Gradio ---
282
 
283
+ with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
284
+ # Use um caminho absoluto para o arquivo de logo do Gradio para segurança
285
  logo_path = os.path.abspath("assets/seedvr_logo.png")
286
  gr.HTML(f"""
287
+ <div style='text-align:center; margin-bottom: 10px;'>
288
+ <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
289
+ </div>
290
+ <p><b>Demonstração oficial do Gradio</b> para <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
291
+ 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.</p>
292
  """)
293
 
294
  with gr.Row():
295
+ input_file = gr.File(label="Carregar imagem ou vídeo", type="filepath")
296
  with gr.Column():
297
  seed = gr.Number(label="Seed", value=666)
298
+ fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
299
 
300
+ run_button = gr.Button("Executar")
301
 
302
  with gr.Row():
303
+ output_image = gr.Image(label="Imagem de Saída")
304
+ output_video = gr.Video(label="Vídeo de Saída")
305
 
306
+ download_link = gr.File(label="Baixar o resultado")
307
 
308
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
309
 
310
  gr.HTML("""
311
  <hr>
312
+ <p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:
313
  <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
314
+ <h4>Aviso</h4>
315
+ <p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>. Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
316
+ <h4>Limitações</h4>
317
+ <p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
318
  """)
319
 
320
  demo.queue().launch(share=True)