import os
import sys
import subprocess
import importlib.util

# --- STEP 0: Final installation of flash-attn ---
# Check whether flash_attn is already installed. If not, install it.
package_name = 'flash_attn'
spec = importlib.util.find_spec(package_name)

if spec is None:
    print(f"Installing the missing package: {package_name}. This may take a minute...")
    # Use the current environment's Python executable to install the package
    python_executable = sys.executable
    subprocess.run(
        [
            python_executable, "-m", "pip", "install",
            "flash_attn==2.5.9.post1",
            "--no-build-isolation"
        ],
        check=True
    )
    print(f"✅ {package_name} installed successfully.")
else:
    print(f"✅ Package {package_name} is already installed.")

# From this point on, the environment is fully ready.
# ---------------------------------------------------------------------

import spaces
from pathlib import Path
from urllib.parse import urlparse

import torch
from torch.hub import download_url_to_file
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
import datetime
import gc
from PIL import Image
import gradio as gr
import uuid
import mimetypes

import torchvision.transforms as T
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video

# --- STEP 1: Clone the repository and change into its directory ---
repo_name = "SeedVR"
if not os.path.exists(repo_name):
    print(f"Cloning the {repo_name} repository from GitHub...")
    subprocess.run(f"git clone https://github.com/ByteDance-Seed/{repo_name}.git", shell=True, check=True)

# Make sure we are in the right directory
if not os.getcwd().endswith(repo_name):
    os.chdir(repo_name)
sys.path.insert(0, os.path.abspath('.'))

# SeedVR project imports (only possible after the chdir above)
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from common.config import load_config
from common.distributed import init_torch
from common.seed import set_seed
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.distributed.ops import sync_data

print("Conda environment loaded and verified. Starting the application...")

# --- STEP 2: Download the pretrained models ---
print("Downloading pretrained models...")

def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, file_name)
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file

pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
}

Path('./ckpts').mkdir(exist_ok=True)
for key, url in pretrain_model_url.items():
    model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
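    # The DiT and VAE checkpoints go to ./ckpts (where configure_runner expects them);
    # the positive/negative text embeddings stay in the repo root, where generation_loop loads them.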
    load_file_from_url(url=url, model_dir=model_dir)

# --- STEP 3: Run the main application ---
# Single-process "distributed" setup required by init_torch
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

def configure_runner():
    config = load_config('configs_3b/main.yaml')
    runner = VideoDiffusionInfer(config)
    OmegaConf.set_readonly(runner.config, False)
    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner

def generation_step(runner, text_embeds_dict, cond_latents):
    def _move_to_cuda(x):
        return [i.to("cuda") for i in x]

    noises = [torch.randn_like(l) for l in cond_latents]
    aug_noises = [torch.randn_like(l) for l in cond_latents]
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = map(_move_to_cuda, (noises, aug_noises, cond_latents))

    def _add_noise(x, aug_noise):
        # Lightly noise the conditioning latent at a fixed timestep before building the SR condition
        t = torch.tensor([100.0], device="cuda")
        shape = torch.tensor(x.shape[1:], device="cuda")[None]
        t = runner.timestep_transform(t, shape)
        return runner.schedule.forward(x, aug_noise, t)

    conditions = [
        runner.get_condition(n, task="sr", latent_blur=_add_noise(l, an))
        for n, an, l in zip(noises, aug_noises, cond_latents)
    ]

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)

    return [rearrange(v, "c t h w -> t c h w") for v in video_tensors]

@spaces.GPU
def generation_loop(video_path, seed=666, fps_out=24):
    if video_path is None:
        return None, None, None

    runner = configure_runner()
    text_embeds = {
        "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
        "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
    }
    runner.configure_diffusion()
    set_seed(int(seed))
    os.makedirs("output", exist_ok=True)

    res_h, res_w = 1280, 720
    transform = Compose([
        NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    media_type, _ = mimetypes.guess_type(video_path)
    is_video = media_type and media_type.startswith("video")

    if is_video:
        video, _, _ = read_video(video_path, output_format="TCHW")
        video = video[:121] / 255.0  # cap at 121 frames and scale to [0, 1]
        output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
    else:
        video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
        output_path = os.path.join("output", f"{uuid.uuid4()}.png")

    cond_latents = [transform(video.to("cuda"))]
    ori_length = cond_latents[0].size(2)
    cond_latents = runner.vae_encode(cond_latents)

    samples = generation_step(runner, text_embeds, cond_latents)
    sample = samples[0][:ori_length].cpu()
    # Clamp to [-1, 1] and map to uint8 frames in [0, 255]
    sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()

    if is_video:
        mediapy.write_video(output_path, sample, fps=fps_out)
        return None, output_path, output_path
    else:
        mediapy.write_image(output_path, sample[0])
        return output_path, None, output_path

with gr.Blocks(title="SeedVR") as demo:
    gr.HTML(f"""
Official Gradio demo for
SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training.
🔥 SeedVR2 is a one-step image and video restoration algorithm for real-world and AIGC content.