Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -16,23 +16,28 @@ import subprocess
|
|
16 |
import os
|
17 |
import sys
|
18 |
|
19 |
-
# --- Setup: Clone repository and
|
20 |
-
#
|
21 |
|
22 |
-
# 1. Clone
|
23 |
-
subprocess.run("git lfs install", shell=True, check=True)
|
24 |
repo_dir_name = "SeedVR2-3B"
|
25 |
if not os.path.exists(repo_dir_name):
|
26 |
-
print(f"
|
27 |
subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
|
28 |
|
29 |
-
# 2.
|
30 |
-
#
|
31 |
os.chdir(repo_dir_name)
|
32 |
-
print(f"
|
33 |
|
34 |
-
#
|
35 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
import torch
|
38 |
import mediapy
|
@@ -51,7 +56,7 @@ import torchvision.transforms as T
|
|
51 |
from torchvision.transforms import Compose, Lambda, Normalize
|
52 |
from torchvision.io.video import read_video
|
53 |
|
54 |
-
#
|
55 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
56 |
from data.image.transforms.na_resize import NaResize
|
57 |
from data.video.transforms.rearrange import Rearrange
|
@@ -63,21 +68,21 @@ from common.partition import partition_by_size
|
|
63 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
64 |
from common.distributed.ops import sync_data
|
65 |
|
66 |
-
#
|
67 |
if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
|
68 |
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
|
69 |
use_colorfix = True
|
70 |
else:
|
71 |
use_colorfix = False
|
72 |
-
print('
|
73 |
|
74 |
-
# ---
|
75 |
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
76 |
os.environ["MASTER_PORT"] = "12355"
|
77 |
os.environ["RANK"] = str(0)
|
78 |
os.environ["WORLD_SIZE"] = str(1)
|
79 |
|
80 |
-
# Use sys.executable
|
81 |
python_executable = sys.executable
|
82 |
subprocess.run(
|
83 |
[python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
|
@@ -88,20 +93,20 @@ subprocess.run(
|
|
88 |
apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
|
89 |
if os.path.exists(apex_wheel_path):
|
90 |
subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
|
91 |
-
print("✅ Apex
|
92 |
|
93 |
-
# ---
|
94 |
|
95 |
def configure_sequence_parallel(sp_size):
|
96 |
if sp_size > 1:
|
97 |
init_sequence_parallel(sp_size)
|
98 |
|
99 |
def configure_runner(sp_size):
|
100 |
-
#
|
101 |
config_path = 'configs_3b/main.yaml'
|
102 |
checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
|
103 |
|
104 |
-
config = load_config(config_path) #
|
105 |
runner = VideoDiffusionInfer(config)
|
106 |
OmegaConf.set_readonly(runner.config, False)
|
107 |
|
@@ -120,7 +125,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
|
|
120 |
|
121 |
noises = [torch.randn_like(latent) for latent in cond_latents]
|
122 |
aug_noises = [torch.randn_like(latent) for latent in cond_latents]
|
123 |
-
print(f"
|
124 |
noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
|
125 |
noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
|
126 |
cond_noise_scale = 0.1
|
@@ -129,7 +134,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
|
|
129 |
t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
|
130 |
shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
|
131 |
t = runner.timestep_transform(t, shape)
|
132 |
-
print(f"Timestep
|
133 |
x = runner.schedule.forward(x, aug_noise, t)
|
134 |
return x
|
135 |
|
@@ -157,7 +162,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
157 |
def _extract_text_embeds():
|
158 |
positive_prompts_embeds = []
|
159 |
for _ in original_videos_local:
|
160 |
-
#
|
161 |
text_pos_embeds = torch.load('pos_emb.pt')
|
162 |
text_neg_embeds = torch.load('neg_emb.pt')
|
163 |
positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
@@ -218,16 +223,16 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
218 |
video = video / 255.0
|
219 |
if video.size(0) > 121:
|
220 |
video = video[:121]
|
221 |
-
print(f"
|
222 |
output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
|
223 |
elif is_image:
|
224 |
img = Image.open(video_path).convert("RGB")
|
225 |
img_tensor = T.ToTensor()(img).unsqueeze(0)
|
226 |
video = img_tensor
|
227 |
-
print(f"
|
228 |
output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
|
229 |
else:
|
230 |
-
raise ValueError("
|
231 |
|
232 |
cond_latents.append(video_transform(video.to(torch.device("cuda"))))
|
233 |
|
@@ -236,7 +241,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
236 |
if is_video:
|
237 |
cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
|
238 |
|
239 |
-
print(f"
|
240 |
cond_latents = runner.vae_encode(cond_latents)
|
241 |
|
242 |
for i, emb in enumerate(text_embeds["texts_pos"]):
|
@@ -273,41 +278,43 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
273 |
else:
|
274 |
return None, output_dir, output_dir
|
275 |
|
276 |
-
# --- Gradio
|
277 |
|
278 |
-
with gr.Blocks(title="SeedVR2:
|
279 |
-
# Use
|
280 |
logo_path = os.path.abspath("assets/seedvr_logo.png")
|
281 |
gr.HTML(f"""
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
|
285 |
""")
|
286 |
|
287 |
with gr.Row():
|
288 |
-
input_file = gr.File(label="
|
289 |
with gr.Column():
|
290 |
seed = gr.Number(label="Seed", value=666)
|
291 |
-
fps = gr.Number(label="
|
292 |
|
293 |
-
run_button = gr.Button("
|
294 |
|
295 |
with gr.Row():
|
296 |
-
output_image = gr.Image(label="
|
297 |
-
output_video = gr.Video(label="
|
298 |
|
299 |
-
download_link = gr.File(label="
|
300 |
|
301 |
run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
|
302 |
|
303 |
gr.HTML("""
|
304 |
<hr>
|
305 |
-
<p>
|
306 |
<a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
|
307 |
-
<h4>
|
308 |
-
<p>
|
309 |
-
<h4>
|
310 |
-
<p>
|
311 |
""")
|
312 |
|
313 |
demo.queue().launch(share=True)
|
|
|
16 |
import os
|
17 |
import sys
|
18 |
|
19 |
+
# --- Setup: Clone repository, Change Directory, and Update Python Path ---
# Doing all three steps up front is meant to fix every path-related problem.

# 1. Clone the repository (skipped when a checkout is already present).
repo_dir_name = "SeedVR2-3B"
if not os.path.exists(repo_dir_name):
    print(f"Clonando o repositório {repo_dir_name}...")
    # Pass the command as an argument list: a single git invocation needs no
    # shell, and this avoids building a shell string by interpolation.
    subprocess.run(
        ["git", "clone", f"https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}"],
        check=True,
    )

# 2. Change the current working directory to the repository root.
#    This fixes access to relative files (e.g. loading config.yaml).
os.chdir(repo_dir_name)
print(f"Diretório de trabalho alterado para: {os.getcwd()}")

# 3. Explicitly add the new working directory to Python's module search path.
#    This fixes module imports (e.g. `from data...`).
sys.path.insert(0, os.path.abspath('.'))
print(f"Diretório atual adicionado ao sys.path: {os.path.abspath('.')}")
|
37 |
+
|
38 |
+
|
39 |
+
# --- Main Application Code ---
|
40 |
+
# From here on, all imports and file loads should work correctly.
|
41 |
|
42 |
import torch
|
43 |
import mediapy
|
|
|
56 |
from torchvision.transforms import Compose, Lambda, Normalize
|
57 |
from torchvision.io.video import read_video
|
58 |
|
59 |
+
# Repository imports (these are expected to resolve now that the repo root is on sys.path)
|
60 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
61 |
from data.image.transforms.na_resize import NaResize
|
62 |
from data.video.transforms.rearrange import Rearrange
|
|
|
68 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
69 |
from common.distributed.ops import sync_data
|
70 |
|
71 |
+
# Check for the optional color_fix utility (via its relative path).
use_colorfix = os.path.exists("projects/video_diffusion_sr/color_fix.py")
if use_colorfix:
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
else:
    print('Atenção!!!!!! A correção de cor não está disponível!')

# --- Environment and Dependency Configuration ---
# Local address/port, rank 0 and world size 1, all stored as strings.
os.environ.update(
    {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "12355",
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
)

# Use sys.executable so that pip is run by this same interpreter.
python_executable = sys.executable
|
87 |
subprocess.run(
|
88 |
[python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
|
|
|
93 |
apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
|
94 |
if os.path.exists(apex_wheel_path):
|
95 |
subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
|
96 |
+
print("✅ Configuração do Apex concluída.")
|
97 |
|
98 |
+
# --- Main Functions ---
|
99 |
|
100 |
def configure_sequence_parallel(sp_size):
    """Set up sequence parallelism when ``sp_size`` is greater than 1.

    Calls ``init_sequence_parallel(sp_size)`` only in that case; for any
    other value the function does nothing and returns ``None``.
    """
    if not sp_size > 1:
        return
    init_sequence_parallel(sp_size)
|
103 |
|
104 |
def configure_runner(sp_size):
|
105 |
+
# Paths are now simple and relative to the repository root
|
106 |
config_path = 'configs_3b/main.yaml'
|
107 |
checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
|
108 |
|
109 |
+
config = load_config(config_path) # Isto agora funcionará corretamente
|
110 |
runner = VideoDiffusionInfer(config)
|
111 |
OmegaConf.set_readonly(runner.config, False)
|
112 |
|
|
|
125 |
|
126 |
noises = [torch.randn_like(latent) for latent in cond_latents]
|
127 |
aug_noises = [torch.randn_like(latent) for latent in cond_latents]
|
128 |
+
print(f"Gerando com o formato de ruído: {noises[0].size()}.")
|
129 |
noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
|
130 |
noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
|
131 |
cond_noise_scale = 0.1
|
|
|
134 |
t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
|
135 |
shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
|
136 |
t = runner.timestep_transform(t, shape)
|
137 |
+
print(f"Deslocamento de Timestep de {1000.0 * cond_noise_scale} para {t}.")
|
138 |
x = runner.schedule.forward(x, aug_noise, t)
|
139 |
return x
|
140 |
|
|
|
162 |
def _extract_text_embeds():
|
163 |
positive_prompts_embeds = []
|
164 |
for _ in original_videos_local:
|
165 |
+
# Paths are now simple relative paths
|
166 |
text_pos_embeds = torch.load('pos_emb.pt')
|
167 |
text_neg_embeds = torch.load('neg_emb.pt')
|
168 |
positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
|
|
223 |
video = video / 255.0
|
224 |
if video.size(0) > 121:
|
225 |
video = video[:121]
|
226 |
+
print(f"Tamanho do vídeo lido: {video.size()}")
|
227 |
output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
|
228 |
elif is_image:
|
229 |
img = Image.open(video_path).convert("RGB")
|
230 |
img_tensor = T.ToTensor()(img).unsqueeze(0)
|
231 |
video = img_tensor
|
232 |
+
print(f"Tamanho da imagem lida: {video.size()}")
|
233 |
output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
|
234 |
else:
|
235 |
+
raise ValueError("Tipo de arquivo não suportado")
|
236 |
|
237 |
cond_latents.append(video_transform(video.to(torch.device("cuda"))))
|
238 |
|
|
|
241 |
if is_video:
|
242 |
cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
|
243 |
|
244 |
+
print(f"Codificando vídeos: {[v.size() for v in cond_latents]}")
|
245 |
cond_latents = runner.vae_encode(cond_latents)
|
246 |
|
247 |
for i, emb in enumerate(text_embeds["texts_pos"]):
|
|
|
278 |
else:
|
279 |
return None, output_dir, output_dir
|
280 |
|
281 |
+
# --- Gradio UI ---
|
282 |
|
283 |
+
with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
|
284 |
+
# Use an absolute path for the Gradio logo file, to be safe
|
285 |
logo_path = os.path.abspath("assets/seedvr_logo.png")
|
286 |
gr.HTML(f"""
|
287 |
+
<div style='text-align:center; margin-bottom: 10px;'>
|
288 |
+
<img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
|
289 |
+
</div>
|
290 |
+
<p><b>Demonstração oficial do Gradio</b> para <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|
291 |
+
🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.</p>
|
292 |
""")
|
293 |
|
294 |
with gr.Row():
|
295 |
+
input_file = gr.File(label="Carregar imagem ou vídeo", type="filepath")
|
296 |
with gr.Column():
|
297 |
seed = gr.Number(label="Seed", value=666)
|
298 |
+
fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
|
299 |
|
300 |
+
run_button = gr.Button("Executar")
|
301 |
|
302 |
with gr.Row():
|
303 |
+
output_image = gr.Image(label="Imagem de Saída")
|
304 |
+
output_video = gr.Video(label="Vídeo de Saída")
|
305 |
|
306 |
+
download_link = gr.File(label="Baixar o resultado")
|
307 |
|
308 |
run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
|
309 |
|
310 |
gr.HTML("""
|
311 |
<hr>
|
312 |
+
<p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:
|
313 |
<a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
|
314 |
+
<h4>Aviso</h4>
|
315 |
+
<p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>. Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
|
316 |
+
<h4>Limitações</h4>
|
317 |
+
<p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
|
318 |
""")
|
319 |
|
320 |
demo.queue().launch(share=True)
|