Fabrice-TIERCELIN committed on
Commit
5a3a052
·
verified ·
1 Parent(s): 73f31a8

Upload __init__.py

Browse files
Files changed (1) hide show
  1. hyvideo/__init__.py +62 -0
hyvideo/__init__.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6
+ from ..constants import VAE_PATH, PRECISION_TO_TYPE
7
+
8
def load_vae(vae_type: str="884-16c-hy",
             vae_precision: str=None,
             sample_size: tuple=None,
             vae_path: str=None,
             logger=None,
             device=None
             ):
    """Load the 3D VAE model and prepare it for inference.

    The model is built from the config found at ``vae_path``, its weights are
    read from ``<vae_path>/pytorch_model.pt``, gradients are disabled and the
    model is switched to eval mode before being returned.

    Args:
        vae_type (str): the type of the 3D VAE model. Used as the key into
            ``VAE_PATH`` when ``vae_path`` is not given, and in the log
            message. Defaults to "884-16c-hy".
        vae_precision (str, optional): key into ``PRECISION_TO_TYPE`` selecting
            the dtype the VAE is cast to. No cast is done when None.
            Defaults to None.
        sample_size (tuple, optional): the tiling size; forwarded to
            ``from_config`` only when truthy. Defaults to None.
        vae_path (str, optional): directory holding the VAE config and
            ``pytorch_model.pt``. Defaults to None (looked up in ``VAE_PATH``).
        logger (optional): object with an ``info`` method; nothing is logged
            when None. Defaults to None.
        device (optional): device the VAE is moved to after loading. The model
            is not moved when None. Defaults to None.

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio,
        time_compression_ratio)`` where the two ratios are read from
        ``vae.config``.

    Raises:
        AssertionError: if ``<vae_path>/pytorch_model.pt`` does not exist.
    """
    if vae_path is None:
        vae_path = VAE_PATH[vae_type]

    if logger is not None:
        logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    config = AutoencoderKLCausal3D.load_config(vae_path)
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    vae_ckpt = Path(vae_path) / "pytorch_model.pt"
    assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"

    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here (consider weights_only=True if the
    # installed torch version supports it and the checkpoint allows it).
    ckpt = torch.load(vae_ckpt, map_location=vae.device)

    # Some checkpoints wrap the weights in a top-level "state_dict" entry.
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # When the weights are stored under a "vae." namespace, keep only those
    # entries and strip the *leading* prefix. Slicing (instead of
    # str.replace) leaves any later "vae." inside the key untouched.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    # The VAE is returned frozen: no parameter requires gradients.
    vae.requires_grad_(False)

    if logger is not None:
        logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio