# SeedVR2-3B / app.py
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
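# Gradio Space entry point for SeedVR2-3B: clones the Space repository, installs the
# runtime dependencies, and exposes one-step image/video restoration through a web UI.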
import spaces
import subprocess
import os
import sys
# --- Setup: Clone repository and Change Working Directory ---
# This is the most robust way to ensure all relative paths work correctly.
# 1. Clone the repository with all its files
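# git lfs install sets up Git LFS so that LFS-tracked files (model checkpoints, embeddings)
# are fetched as real files during the clone below.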
subprocess.run("git lfs install", shell=True, check=True)
repo_dir_name = "SeedVR2-3B"
if not os.path.exists(repo_dir_name):
print(f"Cloning {repo_dir_name} repository...")
subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
# 2. Change the current working directory to the repository's root
# MAIN FIX: changing the working directory resolves all relative-path problems.
os.chdir(repo_dir_name)
print(f"Changed working directory to: {os.getcwd()}")
# --- Main Application Code ---
# Now that we are inside the repo, all imports and file loads will work naturally.
import torch
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
import datetime
from tqdm import tqdm
import gc
from PIL import Image
import gradio as gr
from pathlib import Path
import shlex
import uuid
import mimetypes
import torchvision.transforms as T
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
# Imports from the repository (will now work directly)
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from common.config import load_config
from common.distributed import init_torch
from common.distributed.advanced import init_sequence_parallel
from common.seed import set_seed
from common.partition import partition_by_size
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.distributed.ops import sync_data
# Check for color_fix utility (using relative path)
if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
use_colorfix = True
else:
use_colorfix = False
print('Note!!!!!! Color fix is not available!')
# --- Environment and Dependencies Setup ---
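# torch.distributed's env:// initialization reads these variables; a single-process
# "cluster" (rank 0, world size 1) is enough for this demo.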
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)
# Use sys.executable to ensure we use the correct pip
python_executable = sys.executable
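# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the flash-attn CUDA kernels
# during install, avoiding a long build on the Space.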
subprocess.run(
    [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=True,
)
apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
if os.path.exists(apex_wheel_path):
    subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
    print("✅ Apex setup completed.")
# --- Core Functions ---
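# configure_runner() builds the diffusion runner (DiT + VAE), generation_step() denoises a
# batch of conditioned latents, and generation_loop() handles file I/O, preprocessing, and
# writing the restored output.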
def configure_sequence_parallel(sp_size):
    if sp_size > 1:
        init_sequence_parallel(sp_size)
def configure_runner(sp_size):
    # Paths are now simple and relative to the repo root
    config_path = 'configs_3b/main.yaml'
    checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
    config = load_config(config_path)  # This will now work correctly
    runner = VideoDiffusionInfer(config)
    OmegaConf.set_readonly(runner.config, False)
    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    configure_sequence_parallel(sp_size)
    runner.configure_dit_model(device="cuda", checkpoint=checkpoint_path)
    runner.configure_vae_model()
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
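# generation_step() conditions the "sr" task on the latent of the degraded input, lightly
# noised (cond_noise_scale = 0.1) before the denoising pass (sample_steps defaults to 1).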
def generation_step(runner, text_embeds_dict, cond_latents):
    def _move_to_cuda(x):
        return [i.to(torch.device("cuda")) for i in x]

    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    print(f"Generating with noise shape: {noises[0].size()}.")
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
    cond_noise_scale = 0.1

    def _add_noise(x, aug_noise):
        t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
        shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
        t = runner.timestep_transform(t, shape)
        print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {t}.")
        x = runner.schedule.forward(x, aug_noise, t)
        return x

    conditions = [
        runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
        for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
    ]
    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict
        )
    samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
    del video_tensors
    return samples
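# @spaces.GPU below requests a ZeroGPU slot for the duration of each call when this demo
# runs on Hugging Face Spaces.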
@spaces.GPU
def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0,
                    sample_steps=1, res_h=1280, res_w=720, sp_size=1):
    if video_path is None:
        return None, None, None

    runner = configure_runner(1)

    def _extract_text_embeds():
        positive_prompts_embeds = []
        for _ in original_videos_local:
            # Paths are now simple
            text_pos_embeds = torch.load('pos_emb.pt')
            text_neg_embeds = torch.load('neg_emb.pt')
            positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
        gc.collect()
        torch.cuda.empty_cache()
        return positive_prompts_embeds

    def cut_videos(videos, sp_size):
        if videos.size(1) > 121:
            videos = videos[:, :121]
        t = videos.size(1)
        if t <= 4 * sp_size:
            padding_needed = 4 * sp_size - t + 1
            if padding_needed > 0:
                padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
                videos = torch.cat([videos, padding], dim=1)
            return videos
        if (t - 1) % (4 * sp_size) == 0:
            return videos
        else:
            padding_needed = 4 * sp_size - ((t - 1) % (4 * sp_size))
            padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
            videos = torch.cat([videos, padding], dim=1)
            assert (videos.size(1) - 1) % (4 * sp_size) == 0
            return videos
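    # cut_videos example: with sp_size=1, a 50-frame clip is padded to 53 frames so that
    # (53 - 1) % (4 * sp_size) == 0, the frame-count form the model expects.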
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = cfg_rescale
    runner.config.diffusion.timesteps.sampling.steps = sample_steps
    runner.configure_diffusion()

    seed = int(seed) % (2**32)
    set_seed(seed, same_across_ranks=True)

    output_base_dir = "output"
    os.makedirs(output_base_dir, exist_ok=True)

    original_videos = [os.path.basename(video_path)]
    original_videos_local = partition_by_size(original_videos, batch_size)
    positive_prompts_embeds = _extract_text_embeds()

    video_transform = Compose([
        NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])
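    # Main loop: read each input, build conditioning latents, restore, then decode and
    # write the result under ./output.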
    for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
        cond_latents = []
        for _ in videos:
            media_type, _ = mimetypes.guess_type(video_path)
            is_image = media_type and media_type.startswith("image")
            is_video = media_type and media_type.startswith("video")
            if is_video:
                video, _, _ = read_video(video_path, output_format="TCHW")
                video = video / 255.0
                if video.size(0) > 121:
                    video = video[:121]
                print(f"Read video size: {video.size()}")
                output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
            elif is_image:
                img = Image.open(video_path).convert("RGB")
                img_tensor = T.ToTensor()(img).unsqueeze(0)
                video = img_tensor
                print(f"Read Image size: {video.size()}")
                output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
            else:
                raise ValueError("Unsupported file type")
            cond_latents.append(video_transform(video.to(torch.device("cuda"))))

        ori_lengths = [v.size(1) for v in cond_latents]
        input_videos = cond_latents
        if is_video:
            cond_latents = [cut_videos(v, sp_size) for v in cond_latents]

        print(f"Encoding videos: {[v.size() for v in cond_latents]}")
        cond_latents = runner.vae_encode(cond_latents)

        for i, emb in enumerate(text_embeds["texts_pos"]):
            text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
        for i, emb in enumerate(text_embeds["texts_neg"]):
            text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
        samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
        del cond_latents

        for _, input_tensor, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
            if ori_length < sample.shape[0]:
                sample = sample[:ori_length]
            input_tensor = rearrange(input_tensor, "c t h w -> t c h w")
            if use_colorfix:
                sample = wavelet_reconstruction(sample.to("cpu"), input_tensor[:sample.size(0)].to("cpu"))
            else:
                sample = sample.to("cpu")
            sample = rearrange(sample, "t c h w -> t h w c")
            sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            sample = sample.to(torch.uint8).numpy()
            if is_image:
                mediapy.write_image(output_dir, sample[0])
            else:
                mediapy.write_video(output_dir, sample, fps=fps_out)

    gc.collect()
    torch.cuda.empty_cache()

    if is_image:
        return output_dir, None, output_dir
    else:
        return None, output_dir, output_dir
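# Hypothetical direct call (bypassing the UI), returning (image_path, video_path, download_path):
#   generation_loop("input.mp4", seed=42, fps_out=24)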
# --- Gradio UI ---
with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
    # Use an absolute path for the Gradio file source to be safe
    logo_path = os.path.abspath("assets/seedvr_logo.png")
    gr.HTML(f"""
    <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
    🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
    """)
    with gr.Row():
        input_file = gr.File(label="Upload image or video", type="filepath")
        with gr.Column():
            seed = gr.Number(label="Seed", value=666)
            fps = gr.Number(label="Output FPS (for video)", value=24)
            run_button = gr.Button("Run")
    with gr.Row():
        output_image = gr.Image(label="Output Image")
        output_video = gr.Video(label="Output Video")
        download_link = gr.File(label="Download the output")
    run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
    gr.HTML("""
    <hr>
    <p>If you find SeedVR helpful, please ⭐ the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:
    <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
    <h4>Notice</h4>
    <p>This demo supports up to <b>720p and 121 frames for videos or 2k images</b>. For other use cases, check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>
    <h4>Limitations</h4>
    <p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>
    """)
demo.queue().launch(share=True)