Mariam-Elz commited on
Commit
cb33ff6
·
verified ·
1 Parent(s): 733ec33

Upload imagedream/model_zoo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. imagedream/model_zoo.py +64 -64
imagedream/model_zoo.py CHANGED
@@ -1,64 +1,64 @@
1
- """ Utiliy functions to load pre-trained models more easily """
2
- import os
3
- import pkg_resources
4
- from omegaconf import OmegaConf
5
-
6
- import torch
7
- from huggingface_hub import hf_hub_download
8
-
9
- from imagedream.ldm.util import instantiate_from_config
10
-
11
-
12
- PRETRAINED_MODELS = {
13
- "sd-v2.1-base-4view-ipmv": {
14
- "config": "sd_v2_base_ipmv.yaml",
15
- "repo_id": "Peng-Wang/ImageDream",
16
- "filename": "sd-v2.1-base-4view-ipmv.pt",
17
- },
18
- "sd-v2.1-base-4view-ipmv-local": {
19
- "config": "sd_v2_base_ipmv_local.yaml",
20
- "repo_id": "Peng-Wang/ImageDream",
21
- "filename": "sd-v2.1-base-4view-ipmv-local.pt",
22
- },
23
- }
24
-
25
-
26
- def get_config_file(config_path):
27
- cfg_file = pkg_resources.resource_filename(
28
- "imagedream", os.path.join("configs", config_path)
29
- )
30
- if not os.path.exists(cfg_file):
31
- raise RuntimeError(f"Config {config_path} not available!")
32
- return cfg_file
33
-
34
-
35
- def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
36
- if (config_path is not None) and (ckpt_path is not None):
37
- config = OmegaConf.load(config_path)
38
- model = instantiate_from_config(config.model)
39
- model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
40
- return model
41
-
42
- if not model_name in PRETRAINED_MODELS:
43
- raise RuntimeError(
44
- f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
45
- + "\n- ".join(PRETRAINED_MODELS.keys())
46
- )
47
- model_info = PRETRAINED_MODELS[model_name]
48
-
49
- # Instiantiate the model
50
- print(f"Loading model from config: {model_info['config']}")
51
- config_file = get_config_file(model_info["config"])
52
- config = OmegaConf.load(config_file)
53
- model = instantiate_from_config(config.model)
54
-
55
- # Load pre-trained checkpoint from huggingface
56
- if not ckpt_path:
57
- ckpt_path = hf_hub_download(
58
- repo_id=model_info["repo_id"],
59
- filename=model_info["filename"],
60
- cache_dir=cache_dir,
61
- )
62
- print(f"Loading model from cache file: {ckpt_path}")
63
- model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
64
- return model
 
1
+ """ Utiliy functions to load pre-trained models more easily """
2
+ import os
3
+ import pkg_resources
4
+ from omegaconf import OmegaConf
5
+
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from imagedream.ldm.util import instantiate_from_config
10
+
11
+
12
+ PRETRAINED_MODELS = {
13
+ "sd-v2.1-base-4view-ipmv": {
14
+ "config": "sd_v2_base_ipmv.yaml",
15
+ "repo_id": "Peng-Wang/ImageDream",
16
+ "filename": "sd-v2.1-base-4view-ipmv.pt",
17
+ },
18
+ "sd-v2.1-base-4view-ipmv-local": {
19
+ "config": "sd_v2_base_ipmv_local.yaml",
20
+ "repo_id": "Peng-Wang/ImageDream",
21
+ "filename": "sd-v2.1-base-4view-ipmv-local.pt",
22
+ },
23
+ }
24
+
25
+
26
+ def get_config_file(config_path):
27
+ cfg_file = pkg_resources.resource_filename(
28
+ "imagedream", os.path.join("configs", config_path)
29
+ )
30
+ if not os.path.exists(cfg_file):
31
+ raise RuntimeError(f"Config {config_path} not available!")
32
+ return cfg_file
33
+
34
+
35
+ def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
36
+ if (config_path is not None) and (ckpt_path is not None):
37
+ config = OmegaConf.load(config_path)
38
+ model = instantiate_from_config(config.model)
39
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
40
+ return model
41
+
42
+ if not model_name in PRETRAINED_MODELS:
43
+ raise RuntimeError(
44
+ f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
45
+ + "\n- ".join(PRETRAINED_MODELS.keys())
46
+ )
47
+ model_info = PRETRAINED_MODELS[model_name]
48
+
49
+ # Instiantiate the model
50
+ print(f"Loading model from config: {model_info['config']}")
51
+ config_file = get_config_file(model_info["config"])
52
+ config = OmegaConf.load(config_file)
53
+ model = instantiate_from_config(config.model)
54
+
55
+ # Load pre-trained checkpoint from huggingface
56
+ if not ckpt_path:
57
+ ckpt_path = hf_hub_download(
58
+ repo_id=model_info["repo_id"],
59
+ filename=model_info["filename"],
60
+ cache_dir=cache_dir,
61
+ )
62
+ print(f"Loading model from cache file: {ckpt_path}")
63
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
64
+ return model