Fabrice-TIERCELIN committed on
Commit
5ee9dd4
·
verified ·
1 Parent(s): 1b84443

Upload 4 files

Browse files
Files changed (1) hide show
  1. hyvideo/vae/__init__.py +62 -62
hyvideo/vae/__init__.py CHANGED
@@ -1,62 +1,62 @@
1
- from pathlib import Path
2
-
3
- import torch
4
-
5
- from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6
- from ..constants import VAE_PATH, PRECISION_TO_TYPE
7
-
8
def load_vae(vae_type: str = "884-16c-hy",
             vae_precision: str = None,
             sample_size: tuple = None,
             vae_path: str = None,
             logger=None,
             device=None,
             ):
    """Load the 3D VAE model and its checkpoint weights.

    Args:
        vae_type (str): the type of the 3D VAE model. Used as the key into
            ``VAE_PATH`` when ``vae_path`` is not given. Defaults to "884-16c-hy".
        vae_precision (str, optional): key into ``PRECISION_TO_TYPE`` selecting
            the dtype the VAE is cast to. Defaults to None (no cast).
        sample_size (tuple, optional): the tiling size. When truthy it is
            forwarded to ``from_config`` as ``sample_size``. Defaults to None.
        vae_path (str, optional): the path to vae. It is passed to
            ``load_config`` and must contain ``pytorch_model.pt``.
            Defaults to None, meaning ``VAE_PATH[vae_type]``.
        logger (optional): object exposing ``info``; when given, progress
            messages are logged. Defaults to None.
        device (optional): device to move the vae to. Defaults to None
            (the model is left where ``from_config`` created it).

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio,
        time_compression_ratio)``, where the two ratios are read from
        ``vae.config``.

    Raises:
        FileNotFoundError: if ``pytorch_model.pt`` does not exist under
            ``vae_path``.
    """
    if vae_path is None:
        vae_path = VAE_PATH[vae_type]

    if logger is not None:
        logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
    config = AutoencoderKLCausal3D.load_config(vae_path)
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    vae_ckpt = Path(vae_path) / "pytorch_model.pt"
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not vae_ckpt.exists():
        raise FileNotFoundError(f"VAE checkpoint not found: {vae_ckpt}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider `weights_only=True` if the torch version and
    # the checkpoint contents allow it).
    ckpt = torch.load(vae_ckpt, map_location=vae.device)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # Some checkpoints nest the VAE weights under a "vae." prefix. Keep only
    # those entries and strip the *leading* prefix; slicing (rather than
    # str.replace) leaves any later "vae." inside a key untouched.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    # The VAE is returned frozen: no gradients, and eval mode set below.
    vae.requires_grad_(False)

    if logger is not None:
        logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6
+ from ..constants import VAE_PATH, PRECISION_TO_TYPE
7
+
8
def load_vae(vae_type: str = "884-16c-hy",
             vae_precision: str = None,
             sample_size: tuple = None,
             vae_path: str = None,
             logger=None,
             device=None,
             ):
    """Load the 3D VAE model and its checkpoint weights.

    Args:
        vae_type (str): the type of the 3D VAE model. Used as the key into
            ``VAE_PATH`` when ``vae_path`` is not given. Defaults to "884-16c-hy".
        vae_precision (str, optional): key into ``PRECISION_TO_TYPE`` selecting
            the dtype the VAE is cast to. Defaults to None (no cast).
        sample_size (tuple, optional): the tiling size. When truthy it is
            forwarded to ``from_config`` as ``sample_size``. Defaults to None.
        vae_path (str, optional): the path to vae. It is passed to
            ``load_config`` and must contain ``pytorch_model.pt``.
            Defaults to None, meaning ``VAE_PATH[vae_type]``.
        logger (optional): object exposing ``info``; when given, progress
            messages are logged. Defaults to None.
        device (optional): device to move the vae to. Defaults to None
            (the model is left where ``from_config`` created it).

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio,
        time_compression_ratio)``, where the two ratios are read from
        ``vae.config``.

    Raises:
        FileNotFoundError: if ``pytorch_model.pt`` does not exist under
            ``vae_path``.
    """
    if vae_path is None:
        vae_path = VAE_PATH[vae_type]

    if logger is not None:
        logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
    config = AutoencoderKLCausal3D.load_config(vae_path)
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    vae_ckpt = Path(vae_path) / "pytorch_model.pt"
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not vae_ckpt.exists():
        raise FileNotFoundError(f"VAE checkpoint not found: {vae_ckpt}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider `weights_only=True` if the torch version and
    # the checkpoint contents allow it).
    ckpt = torch.load(vae_ckpt, map_location=vae.device)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # Some checkpoints nest the VAE weights under a "vae." prefix. Keep only
    # those entries and strip the *leading* prefix; slicing (rather than
    # str.replace) leaves any later "vae." inside a key untouched.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    # The VAE is returned frozen: no gradients, and eval mode set below.
    vae.requires_grad_(False)

    if logger is not None:
        logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio