# SeedVR2-3B / app.py
# Last commit by aducsdr: "Update app.py" (f4c2790, verified) — 8.05 kB
import os
import sys
import subprocess
import importlib.util
# --- STEP 0: Make sure flash-attn is installed ---
# flash_attn is not imported directly in this file; presumably the SeedVR modules
# imported further down need it. Install it at startup when it is missing.
package_name = 'flash_attn'
spec = importlib.util.find_spec(package_name)
if spec is None:
    print(f"Instalando o pacote que faltava: {package_name}. Isso pode levar um minuto...")
    # Use the interpreter running this script so pip installs into the same environment.
    python_executable = sys.executable
    subprocess.run(
        [
            python_executable, "-m", "pip", "install",
            "flash_attn==2.5.9.post1",
            "--no-build-isolation",
        ],
        check=True,  # raise CalledProcessError rather than continue without the package
    )
    # The find_spec() call above may have cached a directory listing that predates
    # the install. Clear the import-system finder caches so a later
    # `import flash_attn` can see the freshly installed package.
    importlib.invalidate_caches()
    print(f"✅ {package_name} instalado com sucesso.")
else:
    print(f"✅ Pacote {package_name} já está instalado.")
# From here on the environment is expected to be ready.
# ---------------------------------------------------------------------
import spaces
from pathlib import Path
from urllib.parse import urlparse
import torch
from torch.hub import download_url_to_file
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
import datetime
import gc
from PIL import Image
import gradio as gr
import uuid
import mimetypes
import torchvision.transforms as T
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
# --- STEP 1: Clone the SeedVR repository and switch into it ---
repo_name = "SeedVR"
# Only clone and chdir when we are not already inside the checkout:
# - Comparing the directory's basename (instead of str.endswith) avoids matching
#   unrelated directories such as ".../MySeedVR".
# - Doing this check first avoids cloning a nested SeedVR/SeedVR when the script
#   is started from inside the checkout.
if os.path.basename(os.getcwd()) != repo_name:
    if not os.path.exists(repo_name):
        print(f"Clonando o repositório {repo_name} do GitHub...")
        # Pass an argument list without a shell, so nothing goes through /bin/sh.
        subprocess.run(
            ["git", "clone", f"https://github.com/ByteDance-Seed/{repo_name}.git"],
            check=True,
        )
    os.chdir(repo_name)
# Put the repository root first on sys.path so its top-level packages
# (data, common, projects) can be imported.
sys.path.insert(0, os.path.abspath('.'))
# SeedVR project imports. These modules live in the cloned repository, so they can
# only be imported after the chdir / sys.path setup above; keep them below it.
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from common.config import load_config
from common.distributed import init_torch
from common.seed import set_seed
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.distributed.ops import sync_data
# NOTE(review): the message below mentions a Conda environment, but nothing in this
# file checks for one; reaching this line only means the imports above succeeded.
print("Ambiente Conda carregado e verificado. Iniciando a aplicação...")
# --- STEP 2: Download the pretrained models ---
# Announce the download phase. The helper and loop below do the work, and they
# skip files that already exist on disk.
print("Baixando modelos pré-treinados...")
def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
    """Download ``url`` into ``model_dir`` unless the file is already there.

    Args:
        url: Address of the file to fetch.
        model_dir: Directory to store the file in; created if missing.
        progress: Forwarded to ``torch.hub.download_url_to_file`` (progress bar).
        file_name: Name to save under. Defaults to the last component of the
            URL path (query string and fragment excluded).

    Returns:
        Path of the local file, ``os.path.join(model_dir, file_name)``.

    Raises:
        ValueError: If ``file_name`` is not given and cannot be derived from
            ``url`` (e.g. the URL path ends with ``/``).
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        file_name = os.path.basename(urlparse(url).path)
        if not file_name:
            # Without this check the "cached file" would be model_dir itself, which
            # always exists, so a directory path would be returned silently.
            raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    cached_file = os.path.join(model_dir, file_name)
    # Existence is the only cache check: a file already present is reused as-is.
    if not os.path.exists(cached_file):
        print(f'Baixando: "{url}" para {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
# Hugging Face checkpoints for SeedVR2-3B, keyed by component.
pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
}
Path('./ckpts').mkdir(exist_ok=True)
# Model weights (VAE, DiT) go into ./ckpts. The text embeddings stay in the working
# directory, which is where generation_loop loads pos_emb.pt / neg_emb.pt from.
for key, url in pretrain_model_url.items():
    if key in ('vae', 'dit'):
        model_dir = './ckpts'
    else:
        model_dir = '.'
    load_file_from_url(url=url, model_dir=model_dir)
# --- STEP 3: Run the main application ---
# Single-process "distributed" settings: rank 0 of a world of size 1 on localhost.
# Presumably consumed by init_torch() when configure_runner() sets up torch.
os.environ.update(
    MASTER_ADDR="127.0.0.1",
    MASTER_PORT="12355",
    RANK="0",
    WORLD_SIZE="1",
)
def configure_runner():
    """Build a SeedVR2-3B inference runner with its DiT and VAE configured.

    Steps:
    1. Read ``configs_3b/main.yaml`` (relative to the working directory) and make
       the runner's config writable.
    2. Call ``init_torch`` with cuDNN benchmarking off and a one-hour timeout.
    3. Load the DiT on CUDA from ``ckpts/seedvr2_ema_3b.pth``.
    4. Configure the VAE and, if it exposes ``set_memory_limit``, apply the
       ``vae.memory_limit`` settings from the config.

    Returns:
        The configured ``VideoDiffusionInfer`` instance.
    """
    infer = VideoDiffusionInfer(load_config('configs_3b/main.yaml'))
    OmegaConf.set_readonly(infer.config, False)

    init_torch(
        cudnn_benchmark=False,
        timeout=datetime.timedelta(hours=1),
    )

    infer.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
    infer.configure_vae_model()

    vae = infer.vae
    if hasattr(vae, "set_memory_limit"):
        vae.set_memory_limit(**infer.config.vae.memory_limit)
    return infer
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one SR diffusion pass over a list of encoded condition latents.

    Args:
        runner: Configured ``VideoDiffusionInfer`` (see ``configure_runner``).
        text_embeds_dict: Keyword arguments forwarded to ``runner.inference``
            (the caller passes ``texts_pos`` / ``texts_neg`` embeddings).
        cond_latents: List of latent tensors, e.g. from ``runner.vae_encode``.

    Returns:
        The tensors produced by ``runner.inference``, each rearranged from
        ``c t h w`` to ``t c h w``.
    """

    def to_cuda(tensors):
        return [tensor.to("cuda") for tensor in tensors]

    def blur_latent(latent, aug_noise):
        # Noise the condition latent at a fixed timestep of 100. The timestep is
        # first passed through the runner's timestep transform together with the
        # latent's shape minus its first dimension.
        timestep = torch.tensor([100.0], device="cuda")
        latent_shape = torch.tensor(latent.shape[1:], device="cuda")[None]
        timestep = runner.timestep_transform(timestep, latent_shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    # Draw every diffusion noise first, then every augmentation noise. This RNG
    # call order is kept on purpose so seeded runs stay reproducible.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]

    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = (
        to_cuda(group) for group in (noises, aug_noises, cond_latents)
    )

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        conditions.append(
            runner.get_condition(noise, task="sr", latent_blur=blur_latent(latent, aug_noise))
        )

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)
    return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
@spaces.GPU
def generation_loop(video_path, seed=666, fps_out=24):
    """Restore an uploaded image or video with SeedVR2-3B.

    Args:
        video_path: Path of the uploaded file (image or video); ``None`` when
            nothing was uploaded.
        seed: Random seed; converted with ``int()`` before use.
        fps_out: Frame rate passed to ``mediapy.write_video`` for video output.

    Returns:
        ``(image_path, video_path, download_path)`` for the three Gradio outputs.
        Exactly one of the first two is set, depending on the input type, and the
        third repeats it. Returns ``(None, None, None)`` when no input was given.
    """
    if video_path is None:
        return None, None, None

    # The runner (DiT + VAE) is rebuilt on every call.
    runner = configure_runner()
    text_embeds = {
        "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
        "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")],
    }
    runner.configure_diffusion()
    set_seed(int(seed))
    os.makedirs("output", exist_ok=True)

    # Input pipeline:
    # - Presumably resize towards a 1280x720-equivalent area.
    # - Clamp to [0, 1].
    # - Crop (presumably to multiples of 16).
    # - Normalize(0.5, 0.5) maps [0, 1] to [-1, 1].
    # - Move time after channels: (t, c, h, w) -> (c, t, h, w).
    res_h, res_w = 1280, 720
    transform = Compose([
        NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    media_type, _ = mimetypes.guess_type(video_path)
    is_video = media_type and media_type.startswith("video")
    if is_video:
        # read_video returns uint8 frames as (T, C, H, W). Keep at most the first
        # 121 frames and scale to [0, 1].
        video, _, _ = read_video(video_path, output_format="TCHW")
        video = video[:121] / 255.0
        output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
    else:
        # Anything not recognised as video is treated as a single image, i.e. a
        # one-frame clip of shape (1, C, H, W). The context manager closes the
        # file handle that Image.open keeps open for lazy loading.
        with Image.open(video_path) as img:
            video = T.ToTensor()(img.convert("RGB")).unsqueeze(0)
        output_path = os.path.join("output", f"{uuid.uuid4()}.png")

    cond_latents = [transform(video.to("cuda"))]
    # After the transform the tensor is (c, t, h, w), so the frame count is dim 1.
    # Dim 2 would be the height, which would make the trim below a no-op.
    ori_length = cond_latents[0].size(1)
    cond_latents = runner.vae_encode(cond_latents)
    samples = generation_step(runner, text_embeds, cond_latents)
    # generation_step returns (t, c, h, w) tensors; drop any frames beyond the input length.
    sample = samples[0][:ori_length].cpu()
    # Convert [-1, 1] floats to uint8 (t, h, w, c) for mediapy.
    sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()

    if is_video:
        mediapy.write_video(output_path, sample, fps=fps_out)
        return None, output_path, output_path
    mediapy.write_image(output_path, sample[0])
    return output_path, None, output_path
# Gradio UI: one upload, seed/FPS controls, and image/video/download outputs.
with gr.Blocks(title="SeedVR") as demo:
    # target='_blank' opens the repository link in a new tab
    # (the previous '-blank' opened a window literally named "-blank").
    gr.HTML("""
<p><b>Demonstração oficial do Gradio</b> para
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
<b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
</p>
""")
    with gr.Row():
        input_file = gr.File(label="Carregar Imagem ou Vídeo")
        with gr.Column():
            # NOTE: the UI default seed (42) differs from generation_loop's own default (666);
            # the UI value is the one used, since it is always passed in.
            seed = gr.Number(label="Seed", value=42)
            fps = gr.Number(label="FPS de Saída", value=24)
            run_button = gr.Button("Executar")
    output_image = gr.Image(label="Imagem de Saída")
    output_video = gr.Video(label="Vídeo de Saída")
    download_link = gr.File(label="Baixar Resultado")
    # generation_loop returns (image_path, video_path, download_path) in this order.
    run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])

demo.queue().launch(share=True)