Spaces:
Running
on
Zero
Running
on
Zero
IceClear
commited on
Commit
·
bfeb62a
1
Parent(s):
303cd3c
update
Browse files
projects/video_diffusion_sr/infer.py
CHANGED
|
@@ -71,7 +71,7 @@ class VideoDiffusionInfer():
|
|
| 71 |
|
| 72 |
@log_on_entry
|
| 73 |
@log_runtime
|
| 74 |
-
def configure_dit_model(self, device="
|
| 75 |
# Load dit checkpoint.
|
| 76 |
# For fast init & resume,
|
| 77 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
|
@@ -83,7 +83,7 @@ class VideoDiffusionInfer():
|
|
| 83 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
| 84 |
|
| 85 |
if checkpoint:
|
| 86 |
-
state = torch.load(checkpoint, map_location=
|
| 87 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
| 88 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
| 89 |
print(f"Loading info: {loading_info}")
|
|
|
|
| 71 |
|
| 72 |
@log_on_entry
|
| 73 |
@log_runtime
|
| 74 |
+
def configure_dit_model(self, device="cuda", checkpoint=None):
|
| 75 |
# Load dit checkpoint.
|
| 76 |
# For fast init & resume,
|
| 77 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
|
|
|
| 83 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
| 84 |
|
| 85 |
if checkpoint:
|
| 86 |
+
state = torch.load(checkpoint, map_location=self.device, mmap=True)
|
| 87 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
| 88 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
| 89 |
print(f"Loading info: {loading_info}")
|