Spaces: Running on Zero

IceClear committed · commit 1fd3071 · 1 parent: c9102eb

update

Files changed:
- app.py (+5 -5)
- common/distributed/basic.py (+3 -3)
- projects/video_diffusion_sr/infer.py (+28 -28)
app.py CHANGED

@@ -139,9 +139,9 @@ torch.hub.download_url_to_file(
     'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
     '03.mp4')

-
-
-
+def configure_sequence_parallel(sp_size):
+    if sp_size > 1:
+        init_sequence_parallel(sp_size)

 @spaces.GPU(duration=120)
 def configure_runner(sp_size):
@@ -150,8 +150,8 @@ def configure_runner(sp_size):
     runner = VideoDiffusionInfer(config)
     OmegaConf.set_readonly(runner.config, False)

-
-
+    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
+    configure_sequence_parallel(sp_size)
     runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
     runner.configure_vae_model()
     # Set memory limit.
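Note on the app.py change: on a ZeroGPU Space, a CUDA device is only attached while a `@spaces.GPU`-decorated function is running, which is why `init_torch(...)` and the sequence-parallel setup move inside `configure_runner` rather than running at import time. A minimal sketch of that constraint (the function body here is illustrative, not the Space's code):

```python
import torch
import spaces  # Hugging Face ZeroGPU helper package

@spaces.GPU(duration=120)  # a GPU is attached only for the duration of this call
def gpu_work():
    # CUDA setup must happen inside the decorated function: at module import
    # time a ZeroGPU Space has no GPU to initialize against.
    x = torch.randn(4, 4, device="cuda")
    return (x @ x.T).sum().item()
```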
common/distributed/basic.py CHANGED

@@ -66,11 +66,11 @@ def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
     torch.backends.cudnn.benchmark = cudnn_benchmark
-    torch.cuda.set_device(get_local_rank())
+    torch.cuda.set_device(0)
     dist.init_process_group(
         backend="nccl",
-        rank=get_global_rank(),
-        world_size=get_world_size(),
+        rank=0,
+        world_size=1,
         timeout=timeout,
     )
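Note on the basic.py change: pinning `torch.cuda.set_device(0)` and `rank=0, world_size=1` collapses the NCCL process group to a single process, so code written for distributed inference still runs on the one GPU a ZeroGPU Space provides. A standalone sketch of that single-process init (the rendezvous env vars are assumptions; the Space may configure them elsewhere):

```python
import datetime
import os
import torch.distributed as dist

# With world_size=1 there are no peers to rendezvous with, but the default
# env:// store still requires an address and port.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if not dist.is_initialized():
    dist.init_process_group(
        backend="nccl",  # NCCL needs a CUDA device, which ZeroGPU attaches
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=3600),
    )
```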
projects/video_diffusion_sr/infer.py CHANGED

@@ -26,14 +26,14 @@ from common.diffusion import (
     create_sampling_timesteps_from_config,
     create_schedule_from_config,
 )
-
-
-
-
-
-
-
-
+from common.distributed import (
+    get_device,
+    get_global_rank,
+)
+
+from common.distributed.meta_init_utils import (
+    meta_non_persistent_buffer_init_fn,
+)
 # from common.fs import download

 from models.dit_v2 import na
@@ -68,20 +68,20 @@ class VideoDiffusionInfer():
             return cond
         raise NotImplementedError

-
-
+    @log_on_entry
+    @log_runtime
     def configure_dit_model(self, device="cpu", checkpoint=None):
         # Load dit checkpoint.
         # For fast init & resume,
         # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
         # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
-
-
-
-
+        if self.config.dit.get("init_with_meta_device", False):
+            init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
+        else:
+            init_device = "cpu"

         # Create dit model.
-        with torch.device(
+        with torch.device(init_device):
             self.dit = create_object(self.config.dit.model)
         self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)

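The comment above ("all ranks init DiT on meta device, then load_state_dict with assign=True") is the standard PyTorch meta-device loading pattern. A self-contained sketch of just that pattern (toy module, PyTorch >= 2.1; not the Space's model code):

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters get shapes and dtypes but
# no storage, so even a very large model "instantiates" instantly.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# assign=True swaps the meta parameters for the checkpoint tensors instead of
# copying into them (copying into storage-less meta tensors would fail).
state = {"weight": torch.randn(4096, 4096), "bias": torch.zeros(4096)}
model.load_state_dict(state, strict=True, assign=True)
print(model.weight.device)  # cpu: the parameters now come from `state`
```

Non-persistent buffers are not part of a `state_dict`, so they stay on the meta device after such a load; that is presumably what `meta_non_persistent_buffer_init_fn` re-materializes in the next hunk.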
@@ -90,27 +90,27 @@ class VideoDiffusionInfer():
         loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
         print(f"Loading pretrained ckpt from {checkpoint}")
         print(f"Loading info: {loading_info}")
-
+        self.dit = meta_non_persistent_buffer_init_fn(self.dit)

-
-
+        if device in [get_device(), "cuda"]:
+            self.dit.to(get_device())

         # Print model size.
         num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
         print(f"DiT trainable parameters: {num_params:,}")

-
-
+    @log_on_entry
+    @log_runtime
     def configure_vae_model(self):
         # Create vae model.
         dtype = getattr(torch, self.config.vae.dtype)
         self.vae = create_object(self.config.vae.model)
         self.vae.requires_grad_(False).eval()
-        self.vae.to(device=
+        self.vae.to(device=get_device(), dtype=dtype)

         # Load vae checkpoint.
         state = torch.load(
-            self.config.vae.checkpoint, map_location=
+            self.config.vae.checkpoint, map_location=get_device(), mmap=True
         )
         self.vae.load_state_dict(state)

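The `mmap=True` added to `torch.load` memory-maps the checkpoint file instead of reading it fully into host RAM up front, which matters on a memory-constrained Space; `map_location` then controls where the tensors materialize. A tiny sketch (hypothetical file name, PyTorch >= 2.1):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
torch.save(model.state_dict(), "vae_demo.pth")  # zipfile format, required for mmap

# Tensor data is lazily paged in from disk rather than loaded eagerly.
state = torch.load("vae_demo.pth", map_location="cpu", mmap=True)
model.load_state_dict(state)
```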
@@ -123,12 +123,12 @@ class VideoDiffusionInfer():
     def configure_diffusion(self):
         self.schedule = create_schedule_from_config(
             config=self.config.diffusion.schedule,
-            device=
+            device=get_device(),
         )
         self.sampling_timesteps = create_sampling_timesteps_from_config(
             config=self.config.diffusion.timesteps.sampling,
             schedule=self.schedule,
-            device=
+            device=get_device(),
         )
         self.sampler = create_sampler_from_config(
             config=self.config.diffusion.sampler,
@@ -143,7 +143,7 @@ class VideoDiffusionInfer():
         use_sample = self.config.vae.get("use_sample", True)
         latents = []
         if len(samples) > 0:
-            device =
+            device = get_device()
             dtype = getattr(torch, self.config.vae.dtype)
             scale = self.config.vae.scaling_factor
             shift = self.config.vae.get("shifting_factor", 0.0)
@@ -186,7 +186,7 @@ class VideoDiffusionInfer():
     def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
         samples = []
         if len(latents) > 0:
-            device =
+            device = get_device()
             dtype = getattr(torch, self.config.vae.dtype)
             scale = self.config.vae.scaling_factor
             shift = self.config.vae.get("shifting_factor", 0.0)
@@ -340,9 +340,9 @@ class VideoDiffusionInfer():
         self.dit.to("cpu")

         # Vae decode.
-        self.vae.to(
+        self.vae.to(get_device())
         samples = self.vae_decode(latents)

         if dit_offload:
-            self.dit.to(
+            self.dit.to(get_device())
         return samples
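The last hunk is the offload dance: the DiT is moved to CPU before the VAE goes to the GPU, so only one large model is resident at a time on the single ZeroGPU device. An illustrative sketch with toy stand-ins (not the Space's models; requires a CUDA device):

```python
import torch
import torch.nn as nn

dit = nn.Linear(1024, 1024)  # stand-in for the DiT
vae = nn.Linear(1024, 1024)  # stand-in for the VAE

latents = torch.randn(1, 1024, device="cuda")

dit.to("cuda")
latents = dit(latents)   # "denoise" while the DiT is resident
dit.to("cpu")            # offload the DiT before decoding

vae.to("cuda")
samples = vae(latents)   # decode while only the VAE is on the GPU
```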