Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# //
|
3 |
# // Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
# // you may not use this file except in compliance with the License.
|
5 |
-
# // You may obtain a copy of the License at
|
6 |
# //
|
7 |
# // http://www.apache.org/licenses/LICENSE-2.0
|
8 |
# //
|
@@ -11,22 +11,12 @@
|
|
11 |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# // See the License for the specific language governing permissions and
|
13 |
# // limitations under the License.
|
14 |
-
import
|
|
|
15 |
import os
|
16 |
-
import gc
|
17 |
-
import logging
|
18 |
import sys
|
19 |
-
import subprocess
|
20 |
-
from pathlib import Path
|
21 |
-
from urllib.parse import urlparse
|
22 |
-
from torch.hub import download_url_to_file
|
23 |
-
import gradio as gr
|
24 |
-
import mediapy
|
25 |
-
from einops import rearrange
|
26 |
-
import shutil
|
27 |
-
from omegaconf import OmegaConf
|
28 |
|
29 |
-
# --- ETAPA 1: Clonar o Repositório
|
30 |
repo_name = "SeedVR"
|
31 |
if not os.path.exists(repo_name):
|
32 |
print(f"Clonando o repositório {repo_name} do GitHub...")
|
@@ -36,14 +26,22 @@ if not os.path.exists(repo_name):
|
|
36 |
os.chdir(repo_name)
|
37 |
print(f"Diretório de trabalho alterado para: {os.getcwd()}")
|
38 |
|
39 |
-
# Adicionar o diretório ao path do Python para que as importações funcionem
|
40 |
sys.path.insert(0, os.path.abspath('.'))
|
41 |
print(f"Diretório atual adicionado ao sys.path.")
|
42 |
|
43 |
-
# --- ETAPA 3: Instalar Dependências
|
44 |
python_executable = sys.executable
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
print("Instalando flash-attn...")
|
49 |
subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
|
@@ -52,7 +50,6 @@ from pathlib import Path
|
|
52 |
from urllib.parse import urlparse
|
53 |
from torch.hub import download_url_to_file, get_dir
|
54 |
|
55 |
-
# Função auxiliar para downloads
|
56 |
def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
|
57 |
os.makedirs(model_dir, exist_ok=True)
|
58 |
if not file_name:
|
@@ -64,7 +61,6 @@ def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
|
|
64 |
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
65 |
return cached_file
|
66 |
|
67 |
-
# Baixar e instalar Apex pré-compilado (crucial para o ambiente do Spaces)
|
68 |
apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
|
69 |
apex_wheel_path = load_file_from_url(url=apex_url)
|
70 |
print("Instalando Apex a partir do wheel baixado...")
|
@@ -73,6 +69,8 @@ print("✅ Configuração do Apex concluída.")
|
|
73 |
|
74 |
# --- ETAPA 4: Baixar os Modelos Pré-treinados ---
|
75 |
print("Baixando modelos pré-treinados...")
|
|
|
|
|
76 |
pretrain_model_url = {
|
77 |
'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
|
78 |
'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
|
@@ -85,8 +83,8 @@ for key, url in pretrain_model_url.items():
|
|
85 |
model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
|
86 |
load_file_from_url(url=url, model_dir=model_dir)
|
87 |
|
|
|
88 |
# --- ETAPA 5: Executar a Aplicação Principal ---
|
89 |
-
import torch
|
90 |
import mediapy
|
91 |
from einops import rearrange
|
92 |
from omegaconf import OmegaConf
|
@@ -112,16 +110,20 @@ from common.partition import partition_by_size
|
|
112 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
113 |
from common.distributed.ops import sync_data
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
116 |
os.environ["MASTER_PORT"] = "12355"
|
117 |
os.environ["RANK"] = str(0)
|
118 |
os.environ["WORLD_SIZE"] = str(1)
|
119 |
|
120 |
-
|
121 |
-
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
|
122 |
-
use_colorfix = True
|
123 |
-
else:
|
124 |
-
use_colorfix = False
|
125 |
|
126 |
def configure_runner():
|
127 |
config = load_config('configs_3b/main.yaml')
|
@@ -136,10 +138,9 @@ def configure_runner():
|
|
136 |
|
137 |
def generation_step(runner, text_embeds_dict, cond_latents):
|
138 |
def _move_to_cuda(x): return [i.to("cuda") for i in x]
|
139 |
-
noises = [torch.randn_like(
|
140 |
-
aug_noises = [torch.randn_like(latent) for latent in cond_latents]
|
141 |
noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
|
142 |
-
noises, aug_noises, cond_latents =
|
143 |
def _add_noise(x, aug_noise):
|
144 |
t = torch.tensor([100.0], device="cuda")
|
145 |
shape = torch.tensor(x.shape[1:], device="cuda")[None]
|
@@ -158,9 +159,10 @@ def generation_loop(video_path, seed=666, fps_out=24):
|
|
158 |
runner.configure_diffusion()
|
159 |
set_seed(int(seed))
|
160 |
os.makedirs("output", exist_ok=True)
|
161 |
-
|
162 |
media_type, _ = mimetypes.guess_type(video_path)
|
163 |
is_video = media_type and media_type.startswith("video")
|
|
|
164 |
if is_video:
|
165 |
video, _, _ = read_video(video_path, output_format="TCHW")
|
166 |
video = video[:121] / 255.0
|
@@ -168,12 +170,14 @@ def generation_loop(video_path, seed=666, fps_out=24):
|
|
168 |
else:
|
169 |
video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
|
170 |
output_path = os.path.join("output", f"{uuid.uuid4()}.png")
|
171 |
-
|
|
|
172 |
ori_length = cond_latents[0].size(2)
|
173 |
cond_latents = runner.vae_encode(cond_latents)
|
174 |
samples = generation_step(runner, text_embeds, cond_latents)
|
175 |
sample = samples[0][:ori_length].cpu()
|
176 |
sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
|
|
|
177 |
if is_video:
|
178 |
mediapy.write_video(output_path, sample, fps=fps_out)
|
179 |
return None, output_path, output_path
|
@@ -182,7 +186,16 @@ def generation_loop(video_path, seed=666, fps_out=24):
|
|
182 |
return output_path, None, output_path
|
183 |
|
184 |
with gr.Blocks(title="SeedVR") as demo:
|
185 |
-
gr.HTML(f"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
with gr.Row():
|
187 |
input_file = gr.File(label="Carregar Imagem ou Vídeo")
|
188 |
with gr.Column():
|
@@ -193,5 +206,11 @@ with gr.Blocks(title="SeedVR") as demo:
|
|
193 |
output_video = gr.Video(label="Vídeo de Saída")
|
194 |
download_link = gr.File(label="Baixar Resultado")
|
195 |
run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
demo.queue().launch(share=True)
|
|
|
2 |
# //
|
3 |
# // Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
# // you may not use this file except in compliance with the License.
|
5 |
+
# // You may obtain a copy of the License at
|
6 |
# //
|
7 |
# // http://www.apache.org/licenses/LICENSE-2.0
|
8 |
# //
|
|
|
11 |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# // See the License for the specific language governing permissions and
|
13 |
# // limitations under the License.
|
14 |
+
import spaces
|
15 |
+
import subprocess
|
16 |
import os
|
|
|
|
|
17 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# --- ETAPA 1: Clonar o Repositório do GitHub ---
|
20 |
repo_name = "SeedVR"
|
21 |
if not os.path.exists(repo_name):
|
22 |
print(f"Clonando o repositório {repo_name} do GitHub...")
|
|
|
26 |
os.chdir(repo_name)
|
27 |
print(f"Diretório de trabalho alterado para: {os.getcwd()}")
|
28 |
|
|
|
29 |
sys.path.insert(0, os.path.abspath('.'))
|
30 |
print(f"Diretório atual adicionado ao sys.path.")
|
31 |
|
32 |
+
# --- ETAPA 3: Instalar Dependências Corretamente ---
|
33 |
python_executable = sys.executable
|
34 |
+
|
35 |
+
# CORREÇÃO CRÍTICA: Filtrar requirements.txt para evitar conflitos com torch/torchvision
|
36 |
+
print("Filtrando requirements.txt para evitar conflitos de versão...")
|
37 |
+
with open("requirements.txt", "r") as f_in, open("filtered_requirements.txt", "w") as f_out:
|
38 |
+
for line in f_in:
|
39 |
+
# Ignora as linhas que podem causar conflitos
|
40 |
+
if not line.strip().startswith(('torch', 'torchvision')):
|
41 |
+
f_out.write(line)
|
42 |
+
|
43 |
+
print("Instalando dependências filtradas...")
|
44 |
+
subprocess.run([python_executable, "-m", "pip", "install", "-r", "filtered_requirements.txt"], check=True)
|
45 |
|
46 |
print("Instalando flash-attn...")
|
47 |
subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
|
|
|
50 |
from urllib.parse import urlparse
|
51 |
from torch.hub import download_url_to_file, get_dir
|
52 |
|
|
|
53 |
def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
|
54 |
os.makedirs(model_dir, exist_ok=True)
|
55 |
if not file_name:
|
|
|
61 |
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
62 |
return cached_file
|
63 |
|
|
|
64 |
apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
|
65 |
apex_wheel_path = load_file_from_url(url=apex_url)
|
66 |
print("Instalando Apex a partir do wheel baixado...")
|
|
|
69 |
|
70 |
# --- ETAPA 4: Baixar os Modelos Pré-treinados ---
|
71 |
print("Baixando modelos pré-treinados...")
|
72 |
+
import torch
|
73 |
+
|
74 |
pretrain_model_url = {
|
75 |
'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
|
76 |
'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
|
|
|
83 |
model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
|
84 |
load_file_from_url(url=url, model_dir=model_dir)
|
85 |
|
86 |
+
|
87 |
# --- ETAPA 5: Executar a Aplicação Principal ---
|
|
|
88 |
import mediapy
|
89 |
from einops import rearrange
|
90 |
from omegaconf import OmegaConf
|
|
|
110 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
111 |
from common.distributed.ops import sync_data
|
112 |
|
113 |
+
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
|
114 |
+
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
|
115 |
+
torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
|
116 |
+
print("✅ Setup completo. Iniciando a aplicação...")
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
122 |
os.environ["MASTER_PORT"] = "12355"
|
123 |
os.environ["RANK"] = str(0)
|
124 |
os.environ["WORLD_SIZE"] = str(1)
|
125 |
|
126 |
+
use_colorfix = os.path.exists("projects/video_diffusion_sr/color_fix.py")
|
|
|
|
|
|
|
|
|
127 |
|
128 |
def configure_runner():
|
129 |
config = load_config('configs_3b/main.yaml')
|
|
|
138 |
|
139 |
def generation_step(runner, text_embeds_dict, cond_latents):
|
140 |
def _move_to_cuda(x): return [i.to("cuda") for i in x]
|
141 |
+
noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
|
|
|
142 |
noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
|
143 |
+
noises, aug_noises, cond_latents = map(_move_to_cuda, (noises, aug_noises, cond_latents))
|
144 |
def _add_noise(x, aug_noise):
|
145 |
t = torch.tensor([100.0], device="cuda")
|
146 |
shape = torch.tensor(x.shape[1:], device="cuda")[None]
|
|
|
159 |
runner.configure_diffusion()
|
160 |
set_seed(int(seed))
|
161 |
os.makedirs("output", exist_ok=True)
|
162 |
+
transform = Compose([NaResize(1024), DivisibleCrop(16), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w")])
|
163 |
media_type, _ = mimetypes.guess_type(video_path)
|
164 |
is_video = media_type and media_type.startswith("video")
|
165 |
+
|
166 |
if is_video:
|
167 |
video, _, _ = read_video(video_path, output_format="TCHW")
|
168 |
video = video[:121] / 255.0
|
|
|
170 |
else:
|
171 |
video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
|
172 |
output_path = os.path.join("output", f"{uuid.uuid4()}.png")
|
173 |
+
|
174 |
+
cond_latents = [transform(video.to("cuda"))]
|
175 |
ori_length = cond_latents[0].size(2)
|
176 |
cond_latents = runner.vae_encode(cond_latents)
|
177 |
samples = generation_step(runner, text_embeds, cond_latents)
|
178 |
sample = samples[0][:ori_length].cpu()
|
179 |
sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
|
180 |
+
|
181 |
if is_video:
|
182 |
mediapy.write_video(output_path, sample, fps=fps_out)
|
183 |
return None, output_path, output_path
|
|
|
186 |
return output_path, None, output_path
|
187 |
|
188 |
with gr.Blocks(title="SeedVR") as demo:
|
189 |
+
gr.HTML(f"""
|
190 |
+
<div style='text-align:center; margin-bottom: 10px;'>
|
191 |
+
<img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
|
192 |
+
</div>
|
193 |
+
<p><b>Demonstração oficial do Gradio</b> para
|
194 |
+
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
|
195 |
+
<b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|
196 |
+
🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
|
197 |
+
</p>
|
198 |
+
""")
|
199 |
with gr.Row():
|
200 |
input_file = gr.File(label="Carregar Imagem ou Vídeo")
|
201 |
with gr.Column():
|
|
|
206 |
output_video = gr.Video(label="Vídeo de Saída")
|
207 |
download_link = gr.File(label="Baixar Resultado")
|
208 |
run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
|
209 |
+
gr.Examples(examples=[["01.mp4", 42, 24], ["02.mp4", 42, 24], ["03.mp4", 42, 24]], inputs=[input_file, seed, fps])
|
210 |
+
gr.HTML("""
|
211 |
+
<hr>
|
212 |
+
<p>Se você achou o SeedVR útil, por favor ⭐ o
|
213 |
+
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
|
214 |
+
""")
|
215 |
|
216 |
demo.queue().launch(share=True)
|