Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -11,23 +11,10 @@
|
|
11 |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# // See the License for the specific language governing permissions and
|
13 |
# // limitations under the License.
|
14 |
-
|
15 |
-
|
16 |
-
import torch
|
17 |
-
import torch.distributed as dist
|
18 |
import os
|
19 |
-
import gc
|
20 |
-
import logging
|
21 |
import sys
|
22 |
-
import subprocess
|
23 |
-
from pathlib import Path
|
24 |
-
from urllib.parse import urlparse
|
25 |
-
from torch.hub import download_url_to_file
|
26 |
-
import gradio as gr
|
27 |
-
import mediapy
|
28 |
-
from einops import rearrange
|
29 |
-
import shutil
|
30 |
-
from omegaconf import OmegaConf
|
31 |
|
32 |
# --- ETAPA 1: Preparação do Ambiente ---
|
33 |
# Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
|
@@ -35,7 +22,6 @@ from omegaconf import OmegaConf
|
|
35 |
repo_dir_name = "SeedVR2-3B"
|
36 |
if not os.path.exists(repo_dir_name):
|
37 |
print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
|
38 |
-
# Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
|
39 |
subprocess.run(f"git clone --depth 1 https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
|
40 |
|
41 |
# --- ETAPA 2: Configuração dos Caminhos ---
|
@@ -47,14 +33,38 @@ sys.path.insert(0, os.path.abspath('.'))
|
|
47 |
print(f"Diretório atual adicionado ao sys.path para importações.")
|
48 |
|
49 |
# --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
|
50 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
|
|
|
|
52 |
import torch
|
53 |
from pathlib import Path
|
54 |
from urllib.parse import urlparse
|
55 |
from torch.hub import download_url_to_file, get_dir
|
56 |
|
57 |
-
# Função de download do original
|
58 |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
59 |
if model_dir is None:
|
60 |
hub_dir = get_dir()
|
@@ -62,15 +72,13 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
|
62 |
os.makedirs(model_dir, exist_ok=True)
|
63 |
parts = urlparse(url)
|
64 |
filename = os.path.basename(parts.path)
|
65 |
-
if file_name is not None:
|
66 |
-
filename = file_name
|
67 |
cached_file = os.path.abspath(os.path.join(model_dir, filename))
|
68 |
if not os.path.exists(cached_file):
|
69 |
print(f'Baixando: "{url}" para {cached_file}\n')
|
70 |
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
71 |
return cached_file
|
72 |
|
73 |
-
# URLs dos modelos
|
74 |
pretrain_model_url = {
|
75 |
'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
|
76 |
'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
|
@@ -78,9 +86,7 @@ pretrain_model_url = {
|
|
78 |
'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
|
79 |
}
|
80 |
|
81 |
-
|
82 |
-
ckpt_dir = Path('./ckpts')
|
83 |
-
ckpt_dir.mkdir(exist_ok=True)
|
84 |
for key, url in pretrain_model_url.items():
|
85 |
filename = os.path.basename(url)
|
86 |
model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
|
@@ -88,29 +94,10 @@ for key, url in pretrain_model_url.items():
|
|
88 |
if not os.path.exists(target_path):
|
89 |
load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
|
90 |
|
91 |
-
# Baixar vídeos de exemplo
|
92 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
|
93 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
|
94 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
|
95 |
-
|
96 |
-
# --- REFINAMENTO: Compilar dependências do zero para a GPU L40S (Ada Lovelace) ---
|
97 |
-
python_executable = sys.executable
|
98 |
-
|
99 |
-
print("Instalando flash-attn compilando do zero...")
|
100 |
-
# Força a reinstalação a partir do zero para garantir que seja compilado para a GPU atual
|
101 |
-
subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", "flash-attn"], check=True)
|
102 |
-
|
103 |
-
print("Clonando e compilando o Apex do zero...")
|
104 |
-
if not os.path.exists("apex"):
|
105 |
-
subprocess.run("git clone https://github.com/NVIDIA/apex", shell=True, check=True)
|
106 |
-
|
107 |
-
# Instala o Apex a partir da fonte clonada, o que força a compilação para a GPU L40S
|
108 |
-
# As flags --cpp_ext e --cuda_ext são essenciais para a compilação
|
109 |
-
subprocess.run(
|
110 |
-
[python_executable, "-m", "pip", "install", "-v", "--disable-pip-version-check", "--no-cache-dir", "--global-option=--cpp_ext", "--global-option=--cuda_ext", "./apex"],
|
111 |
-
check=True
|
112 |
-
)
|
113 |
-
print("✅ Configuração do Apex concluída.")
|
114 |
|
115 |
# --- ETAPA 4: Execução do Código Principal da Aplicação ---
|
116 |
import mediapy
|
@@ -142,7 +129,6 @@ os.environ["MASTER_ADDR"] = "127.0.0.1"
|
|
142 |
os.environ["MASTER_PORT"] = "12355"
|
143 |
os.environ["RANK"] = str(0)
|
144 |
os.environ["WORLD_SIZE"] = str(1)
|
145 |
-
# Adiciona uma variável de ambiente que pode ajudar o PyTorch a debugar erros de CUDA
|
146 |
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
147 |
|
148 |
if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
|
@@ -186,7 +172,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
|
|
186 |
video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
|
187 |
return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
|
188 |
|
189 |
-
|
190 |
def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
|
191 |
if video_path is None: return None, None, None
|
192 |
runner = configure_runner(1)
|
@@ -228,7 +214,9 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
228 |
|
229 |
with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
|
230 |
gr.HTML(f"""
|
231 |
-
|
|
|
|
|
232 |
<p><b>Demonstração oficial do Gradio</b> para
|
233 |
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
|
234 |
<b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|
|
|
11 |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# // See the License for the specific language governing permissions and
|
13 |
# // limitations under the License.
|
14 |
+
import spaces
|
15 |
+
import subprocess
|
|
|
|
|
16 |
import os
|
|
|
|
|
17 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# --- ETAPA 1: Preparação do Ambiente ---
|
20 |
# Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
|
|
|
22 |
repo_dir_name = "SeedVR2-3B"
|
23 |
if not os.path.exists(repo_dir_name):
|
24 |
print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
|
|
|
25 |
subprocess.run(f"git clone --depth 1 https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
|
26 |
|
27 |
# --- ETAPA 2: Configuração dos Caminhos ---
|
|
|
33 |
print(f"Diretório atual adicionado ao sys.path para importações.")
|
34 |
|
35 |
# --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
|
36 |
+
# ORDEM CORRETA: Instalar requisitos, DEPOIS compilar apex/flash-attn, DEPOIS baixar modelos.
|
37 |
+
|
38 |
+
python_executable = sys.executable
|
39 |
+
|
40 |
+
# **CORREÇÃO CRÍTICA: Instalar requisitos PRIMEIRO para ter o PyTorch disponível**
|
41 |
+
print("Instalando dependências a partir do requirements.txt (isso inclui o PyTorch)...")
|
42 |
+
subprocess.run([python_executable, "-m", "pip", "install", "-r", "requirements.txt"], check=True)
|
43 |
+
print("✅ Dependências básicas instaladas.")
|
44 |
+
|
45 |
+
|
46 |
+
# **Compilar dependências otimizadas para a GPU L40S (Ada Lovelace)**
|
47 |
+
print("Instalando flash-attn compilando do zero...")
|
48 |
+
subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", "flash-attn"], check=True)
|
49 |
+
|
50 |
+
print("Clonando e compilando o Apex do zero (isso pode demorar um pouco)...")
|
51 |
+
if not os.path.exists("apex"):
|
52 |
+
subprocess.run("git clone https://github.com/NVIDIA/apex", shell=True, check=True)
|
53 |
+
|
54 |
+
# Instala o Apex a partir da fonte clonada, o que força a compilação para a GPU L40S
|
55 |
+
subprocess.run(
|
56 |
+
[python_executable, "-m", "pip", "install", "-v", "--disable-pip-version-check", "--no-cache-dir", "--global-option=--cpp_ext", "--global-option=--cuda_ext", "./apex"],
|
57 |
+
check=True
|
58 |
+
)
|
59 |
+
print("✅ Configuração do Apex concluída.")
|
60 |
|
61 |
+
|
62 |
+
# **Download dos modelos e dados**
|
63 |
import torch
|
64 |
from pathlib import Path
|
65 |
from urllib.parse import urlparse
|
66 |
from torch.hub import download_url_to_file, get_dir
|
67 |
|
|
|
68 |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
69 |
if model_dir is None:
|
70 |
hub_dir = get_dir()
|
|
|
72 |
os.makedirs(model_dir, exist_ok=True)
|
73 |
parts = urlparse(url)
|
74 |
filename = os.path.basename(parts.path)
|
75 |
+
if file_name is not None: filename = file_name
|
|
|
76 |
cached_file = os.path.abspath(os.path.join(model_dir, filename))
|
77 |
if not os.path.exists(cached_file):
|
78 |
print(f'Baixando: "{url}" para {cached_file}\n')
|
79 |
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
80 |
return cached_file
|
81 |
|
|
|
82 |
pretrain_model_url = {
|
83 |
'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
|
84 |
'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
|
|
|
86 |
'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
|
87 |
}
|
88 |
|
89 |
+
ckpt_dir = Path('./ckpts'); ckpt_dir.mkdir(exist_ok=True)
|
|
|
|
|
90 |
for key, url in pretrain_model_url.items():
|
91 |
filename = os.path.basename(url)
|
92 |
model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
|
|
|
94 |
if not os.path.exists(target_path):
|
95 |
load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
|
96 |
|
|
|
97 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
|
98 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
|
99 |
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
|
100 |
+
print("✅ Modelos e dados de exemplo baixados.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
# --- ETAPA 4: Execução do Código Principal da Aplicação ---
|
103 |
import mediapy
|
|
|
129 |
os.environ["MASTER_PORT"] = "12355"
|
130 |
os.environ["RANK"] = str(0)
|
131 |
os.environ["WORLD_SIZE"] = str(1)
|
|
|
132 |
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
133 |
|
134 |
if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
|
|
|
172 |
video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
|
173 |
return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
|
174 |
|
175 |
+
|
176 |
def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
|
177 |
if video_path is None: return None, None, None
|
178 |
runner = configure_runner(1)
|
|
|
214 |
|
215 |
with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
|
216 |
gr.HTML(f"""
|
217 |
+
<div style='text-align:center; margin-bottom: 10px;'>
|
218 |
+
<img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
|
219 |
+
</div>
|
220 |
<p><b>Demonstração oficial do Gradio</b> para
|
221 |
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
|
222 |
<b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|