Update app.py

app.py
CHANGED
@@ -16,23 +16,28 @@ import subprocess
 import os
 import sys
 
-# --- Setup: Clone repository and
-#
+# --- Setup: Clone repository, Change Directory, and Update Python Path ---
+# This is the definitive approach to fixing all of the path problems.
 
-# 1. Clone
-subprocess.run("git lfs install", shell=True, check=True)
+# 1. Clone the repository
 repo_dir_name = "SeedVR2-3B"
 if not os.path.exists(repo_dir_name):
-    print(f"
+    print(f"Cloning the {repo_dir_name} repository...")
     subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
 
-# 2.
-#
+# 2. Change the current working directory to the repository root.
+# This fixes relative file access (e.g. loading config.yaml).
 os.chdir(repo_dir_name)
-print(f"
+print(f"Working directory changed to: {os.getcwd()}")
 
-#
-#
+# 3. Explicitly add the new working directory to Python's module search path.
+# This fixes module imports (e.g. `from data...`).
+sys.path.insert(0, os.path.abspath('.'))
+print(f"Current directory added to sys.path: {os.path.abspath('.')}")
+
+
+# --- Main Application Code ---
+# All imports and file loads should now work correctly.
 
 import torch
 import mediapy
@@ -51,7 +56,7 @@ import torchvision.transforms as T
 from torchvision.transforms import Compose, Lambda, Normalize
 from torchvision.io.video import read_video
 
-#
+# Imports from the repository (these will now work)
 from data.image.transforms.divisible_crop import DivisibleCrop
 from data.image.transforms.na_resize import NaResize
 from data.video.transforms.rearrange import Rearrange
@@ -63,21 +68,21 @@ from common.partition import partition_by_size
 from projects.video_diffusion_sr.infer import VideoDiffusionInfer
 from common.distributed.ops import sync_data
 
-#
+# Check for the color_fix utility (using a relative path)
 if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
     from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
     use_colorfix = True
 else:
     use_colorfix = False
-    print('
+    print('Warning: color fix is not available!')
 
-# ---
+# --- Environment and Dependency Setup ---
 os.environ["MASTER_ADDR"] = "127.0.0.1"
 os.environ["MASTER_PORT"] = "12355"
 os.environ["RANK"] = str(0)
 os.environ["WORLD_SIZE"] = str(1)
 
-# Use sys.executable
+# Use sys.executable to make sure we are using the correct pip
 python_executable = sys.executable
 subprocess.run(
     [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
@@ -88,20 +93,20 @@ subprocess.run(
 apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
 if os.path.exists(apex_wheel_path):
     subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
-    print("✅ Apex
+    print("✅ Apex setup complete.")
 
-# ---
+# --- Core Functions ---
 
 def configure_sequence_parallel(sp_size):
     if sp_size > 1:
         init_sequence_parallel(sp_size)
 
 def configure_runner(sp_size):
-    #
+    # Paths are now simple and relative to the repository root
     config_path = 'configs_3b/main.yaml'
     checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
 
-    config = load_config(config_path) #
+    config = load_config(config_path)  # This will now work correctly
     runner = VideoDiffusionInfer(config)
     OmegaConf.set_readonly(runner.config, False)
 
@@ -120,7 +125,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
 
     noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
-    print(f"
+    print(f"Generating with noise shape: {noises[0].size()}.")
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
    cond_noise_scale = 0.1
@@ -129,7 +134,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
         t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
         shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
         t = runner.timestep_transform(t, shape)
-        print(f"Timestep
+        print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {t}.")
         x = runner.schedule.forward(x, aug_noise, t)
         return x
 
@@ -157,7 +162,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
     def _extract_text_embeds():
         positive_prompts_embeds = []
         for _ in original_videos_local:
-            #
+            # Paths are now simple
             text_pos_embeds = torch.load('pos_emb.pt')
             text_neg_embeds = torch.load('neg_emb.pt')
             positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
@@ -218,16 +223,16 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
             video = video / 255.0
             if video.size(0) > 121:
                 video = video[:121]
-            print(f"
+            print(f"Read video size: {video.size()}")
             output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
         elif is_image:
             img = Image.open(video_path).convert("RGB")
             img_tensor = T.ToTensor()(img).unsqueeze(0)
             video = img_tensor
-            print(f"
+            print(f"Read image size: {video.size()}")
             output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
         else:
-            raise ValueError("
+            raise ValueError("Unsupported file type")
 
         cond_latents.append(video_transform(video.to(torch.device("cuda"))))
 
@@ -236,7 +241,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
     if is_video:
         cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
 
-    print(f"
+    print(f"Encoding videos: {[v.size() for v in cond_latents]}")
     cond_latents = runner.vae_encode(cond_latents)
 
     for i, emb in enumerate(text_embeds["texts_pos"]):
@@ -273,41 +278,43 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
     else:
         return None, output_dir, output_dir
 
-# --- Gradio
+# --- Gradio UI ---
 
-with gr.Blocks(title="SeedVR2:
-    # Use
+with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
+    # Use an absolute path for the logo file, for safety with Gradio
     logo_path = os.path.abspath("assets/seedvr_logo.png")
     gr.HTML(f"""
-
-
-
+        <div style='text-align:center; margin-bottom: 10px;'>
+            <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
+        </div>
+        <p><b>Official Gradio demo</b> for <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
+        🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.</p>
     """)
 
     with gr.Row():
-        input_file = gr.File(label="
+        input_file = gr.File(label="Upload image or video", type="filepath")
         with gr.Column():
             seed = gr.Number(label="Seed", value=666)
-            fps = gr.Number(label="
+            fps = gr.Number(label="Output FPS (for video)", value=24)
 
-    run_button = gr.Button("
+    run_button = gr.Button("Run")
 
     with gr.Row():
-        output_image = gr.Image(label="
-        output_video = gr.Video(label="
+        output_image = gr.Image(label="Output Image")
+        output_video = gr.Video(label="Output Video")
 
-    download_link = gr.File(label="
+    download_link = gr.File(label="Download the result")
 
     run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
 
     gr.HTML("""
         <hr>
-        <p>
+        <p>If you find SeedVR useful, please ⭐ the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:
         <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
-        <h4>
-        <p>
-        <h4>
-        <p>
+        <h4>Notice</h4>
+        <p>This demo supports up to <b>720p and 121 frames for videos, or 2k images</b>. For other use cases, check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>.</p>
+        <h4>Limitations</h4>
+        <p>It may fail on heavy degradations or AIGC clips with little motion, causing over-sharpening or inadequate restoration.</p>
     """)
 
 demo.queue().launch(share=True)
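Note: the path bootstrap at the top of the diff combines two distinct fixes that are easy to conflate. Below is a minimal standalone sketch of why both calls are needed; it is not part of the commit, and the repo path is illustrative.

import os
import sys

# Assumption for illustration: the repo has already been cloned next to this script.
repo_root = os.path.abspath("SeedVR2-3B")

# os.chdir() only fixes *file* access: relative paths like
# open("configs_3b/main.yaml") now resolve against the repo root.
os.chdir(repo_root)

# Imports are unaffected by the working directory: Python resolves
# `from data... import ...` against sys.path, so the repo root must
# be added there explicitly.
sys.path.insert(0, repo_root)

print(os.getcwd())   # .../SeedVR2-3B
print(sys.path[0])   # .../SeedVR2-3B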
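Similarly, the MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE variables are set so that torch.distributed can initialize via its default env:// rendezvous even though the Space runs a single process. A minimal sketch of that mechanism follows; it is not part of the commit, and the gloo backend is chosen only so the check runs without a GPU.

import os
import torch.distributed as dist

# Single-process "cluster": rank 0 of world size 1, rendezvous on localhost.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

# init_process_group() with the default init_method="env://" reads the
# four variables above to form the process group.
dist.init_process_group(backend="gloo")
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()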