Aduc-sdr commited on
Commit
00c0ed2
·
verified ·
1 Parent(s): c5f2555

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -47
app.py CHANGED
@@ -11,23 +11,10 @@
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
-
15
-
16
- import torch
17
- import torch.distributed as dist
18
  import os
19
- import gc
20
- import logging
21
  import sys
22
- import subprocess
23
- from pathlib import Path
24
- from urllib.parse import urlparse
25
- from torch.hub import download_url_to_file
26
- import gradio as gr
27
- import mediapy
28
- from einops import rearrange
29
- import shutil
30
- from omegaconf import OmegaConf
31
 
32
  # --- ETAPA 1: Preparação do Ambiente ---
33
  # Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
@@ -35,7 +22,6 @@ from omegaconf import OmegaConf
35
  repo_dir_name = "SeedVR2-3B"
36
  if not os.path.exists(repo_dir_name):
37
  print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
38
- # Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
39
  subprocess.run(f"git clone --depth 1 https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
40
 
41
  # --- ETAPA 2: Configuração dos Caminhos ---
@@ -47,14 +33,38 @@ sys.path.insert(0, os.path.abspath('.'))
47
  print(f"Diretório atual adicionado ao sys.path para importações.")
48
 
49
  # --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
50
- # Agora que estamos no diretório correto, podemos prosseguir.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
52
  import torch
53
  from pathlib import Path
54
  from urllib.parse import urlparse
55
  from torch.hub import download_url_to_file, get_dir
56
 
57
- # Função de download do original
58
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
59
  if model_dir is None:
60
  hub_dir = get_dir()
@@ -62,15 +72,13 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
62
  os.makedirs(model_dir, exist_ok=True)
63
  parts = urlparse(url)
64
  filename = os.path.basename(parts.path)
65
- if file_name is not None:
66
- filename = file_name
67
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
68
  if not os.path.exists(cached_file):
69
  print(f'Baixando: "{url}" para {cached_file}\n')
70
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
71
  return cached_file
72
 
73
- # URLs dos modelos
74
  pretrain_model_url = {
75
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
76
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
@@ -78,9 +86,7 @@ pretrain_model_url = {
78
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
79
  }
80
 
81
- # Criar diretório de checkpoints e baixar modelos
82
- ckpt_dir = Path('./ckpts')
83
- ckpt_dir.mkdir(exist_ok=True)
84
  for key, url in pretrain_model_url.items():
85
  filename = os.path.basename(url)
86
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
@@ -88,29 +94,10 @@ for key, url in pretrain_model_url.items():
88
  if not os.path.exists(target_path):
89
  load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
90
 
91
- # Baixar vídeos de exemplo
92
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
93
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
94
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
95
-
96
- # --- REFINAMENTO: Compilar dependências do zero para a GPU L40S (Ada Lovelace) ---
97
- python_executable = sys.executable
98
-
99
- print("Instalando flash-attn compilando do zero...")
100
- # Força a reinstalação a partir do zero para garantir que seja compilado para a GPU atual
101
- subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", "flash-attn"], check=True)
102
-
103
- print("Clonando e compilando o Apex do zero...")
104
- if not os.path.exists("apex"):
105
- subprocess.run("git clone https://github.com/NVIDIA/apex", shell=True, check=True)
106
-
107
- # Instala o Apex a partir da fonte clonada, o que força a compilação para a GPU L40S
108
- # As flags --cpp_ext e --cuda_ext são essenciais para a compilação
109
- subprocess.run(
110
- [python_executable, "-m", "pip", "install", "-v", "--disable-pip-version-check", "--no-cache-dir", "--global-option=--cpp_ext", "--global-option=--cuda_ext", "./apex"],
111
- check=True
112
- )
113
- print("✅ Configuração do Apex concluída.")
114
 
115
  # --- ETAPA 4: Execução do Código Principal da Aplicação ---
116
  import mediapy
@@ -142,7 +129,6 @@ os.environ["MASTER_ADDR"] = "127.0.0.1"
142
  os.environ["MASTER_PORT"] = "12355"
143
  os.environ["RANK"] = str(0)
144
  os.environ["WORLD_SIZE"] = str(1)
145
- # Adiciona uma variável de ambiente que pode ajudar o PyTorch a debugar erros de CUDA
146
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
147
 
148
  if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
@@ -186,7 +172,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
186
  video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
187
  return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
188
 
189
- @spaces.GPU
190
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
191
  if video_path is None: return None, None, None
192
  runner = configure_runner(1)
@@ -228,7 +214,9 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
228
 
229
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
230
  gr.HTML(f"""
231
-
 
 
232
  <p><b>Demonstração oficial do Gradio</b> para
233
  <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
234
  <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
 
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
+ import spaces
15
+ import subprocess
 
 
16
  import os
 
 
17
  import sys
 
 
 
 
 
 
 
 
 
18
 
19
  # --- ETAPA 1: Preparação do Ambiente ---
20
  # Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
 
22
  repo_dir_name = "SeedVR2-3B"
23
  if not os.path.exists(repo_dir_name):
24
  print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
 
25
  subprocess.run(f"git clone --depth 1 https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
26
 
27
  # --- ETAPA 2: Configuração dos Caminhos ---
 
33
  print(f"Diretório atual adicionado ao sys.path para importações.")
34
 
35
  # --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
36
+ # ORDEM CORRETA: Instalar requisitos, DEPOIS compilar apex/flash-attn, DEPOIS baixar modelos.
37
+
38
+ python_executable = sys.executable
39
+
40
+ # **CORREÇÃO CRÍTICA: Instalar requisitos PRIMEIRO para ter o PyTorch disponível**
41
+ print("Instalando dependências a partir do requirements.txt (isso inclui o PyTorch)...")
42
+ subprocess.run([python_executable, "-m", "pip", "install", "-r", "requirements.txt"], check=True)
43
+ print("✅ Dependências básicas instaladas.")
44
+
45
+
46
+ # **Compilar dependências otimizadas para a GPU L40S (Ada Lovelace)**
47
+ print("Instalando flash-attn compilando do zero...")
48
+ subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", "flash-attn"], check=True)
49
+
50
+ print("Clonando e compilando o Apex do zero (isso pode demorar um pouco)...")
51
+ if not os.path.exists("apex"):
52
+ subprocess.run("git clone https://github.com/NVIDIA/apex", shell=True, check=True)
53
+
54
+ # Instala o Apex a partir da fonte clonada, o que força a compilação para a GPU L40S
55
+ subprocess.run(
56
+ [python_executable, "-m", "pip", "install", "-v", "--disable-pip-version-check", "--no-cache-dir", "--global-option=--cpp_ext", "--global-option=--cuda_ext", "./apex"],
57
+ check=True
58
+ )
59
+ print("✅ Configuração do Apex concluída.")
60
 
61
+
62
+ # **Download dos modelos e dados**
63
  import torch
64
  from pathlib import Path
65
  from urllib.parse import urlparse
66
  from torch.hub import download_url_to_file, get_dir
67
 
 
68
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
69
  if model_dir is None:
70
  hub_dir = get_dir()
 
72
  os.makedirs(model_dir, exist_ok=True)
73
  parts = urlparse(url)
74
  filename = os.path.basename(parts.path)
75
+ if file_name is not None: filename = file_name
 
76
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
77
  if not os.path.exists(cached_file):
78
  print(f'Baixando: "{url}" para {cached_file}\n')
79
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
80
  return cached_file
81
 
 
82
  pretrain_model_url = {
83
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
84
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
 
86
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
87
  }
88
 
89
+ ckpt_dir = Path('./ckpts'); ckpt_dir.mkdir(exist_ok=True)
 
 
90
  for key, url in pretrain_model_url.items():
91
  filename = os.path.basename(url)
92
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
 
94
  if not os.path.exists(target_path):
95
  load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
96
 
 
97
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
98
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
99
  torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
100
+ print("✅ Modelos e dados de exemplo baixados.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # --- ETAPA 4: Execução do Código Principal da Aplicação ---
103
  import mediapy
 
129
  os.environ["MASTER_PORT"] = "12355"
130
  os.environ["RANK"] = str(0)
131
  os.environ["WORLD_SIZE"] = str(1)
 
132
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
133
 
134
  if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
 
172
  video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
173
  return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
174
 
175
+
176
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
177
  if video_path is None: return None, None, None
178
  runner = configure_runner(1)
 
214
 
215
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
216
  gr.HTML(f"""
217
+ <div style='text-align:center; margin-bottom: 10px;'>
218
+ <img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
219
+ </div>
220
  <p><b>Demonstração oficial do Gradio</b> para
221
  <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
222
  <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>