Julien Blanchon committed
Commit 6fb693e · 1 Parent(s): ef1acb1

Update

Files changed:
- tim/models/utils/text_encoders.py +2 -2
- tim/models/vae/__init__.py +4 -4
- tim/utils/misc_utils.py +3 -3
- tim/utils/train_utils.py +46 -41
tim/models/utils/text_encoders.py

@@ -25,8 +25,8 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
     else:
         raise NotImplementedError
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in text_encoder.parameters():
-        param.requires_grad = False
+    # for param in text_encoder.parameters():
+    #     param.requires_grad = False
 
     text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
 
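With the loop commented out, the encoder's parameters keep requires_grad=True after loading; only eval() and the dtype/device cast remain. As a generic PyTorch sketch (not code from this repo, using a hypothetical stand-in module), the same effect can be obtained module-wide or at call time:

import torch
import torch.nn as nn

# Stand-in for the loaded text encoder (hypothetical; any nn.Module behaves the same).
text_encoder = nn.Sequential(nn.Embedding(1000, 32), nn.Linear(32, 32)).eval()

# Module-level equivalent of the commented-out per-parameter loop:
text_encoder.requires_grad_(False)
assert all(not p.requires_grad for p in text_encoder.parameters())

# Alternatively, keep the encode pass out of autograd at call time:
with torch.no_grad():
    emb = text_encoder(torch.randint(0, 1000, (1, 8)))
assert not emb.requires_grad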
tim/models/vae/__init__.py

@@ -8,8 +8,8 @@ def get_dc_ae(vae_dir, dtype, device):
     dc_ae = AutoencoderDC.from_pretrained(vae_dir).to(dtype=dtype, device=device)
     dc_ae.eval()
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in dc_ae.parameters():
-        param.requires_grad = False
+    # for param in dc_ae.parameters():
+    #     param.requires_grad = False
     return dc_ae
 
 

@@ -37,8 +37,8 @@ def get_sd_vae(vae_dir, dtype, device):
     sd_vae = AutoencoderKL.from_pretrained(vae_dir).to(dtype=dtype, device=device)
     sd_vae.eval()
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in sd_vae.parameters():
-        param.requires_grad = False
+    # for param in sd_vae.parameters():
+    #     param.requires_grad = False
     return sd_vae
 
 
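For context, a usage sketch of the SD-VAE loader above, assuming the standard diffusers AutoencoderKL API; the checkpoint path and the input tensor are placeholders, not values from this repo:

import torch
from tim.models.vae import get_sd_vae

vae = get_sd_vae("path/to/sd-vae", dtype=torch.float16, device="cuda")  # placeholder dir
image = torch.randn(1, 3, 256, 256, device="cuda", dtype=torch.float16)  # dummy batch

# The loader sets eval(); gradients are typically excluded at call time
# rather than via per-parameter flags:
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
    recon = vae.decode(latents).sample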
tim/utils/misc_utils.py

@@ -283,9 +283,9 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
         print("unexpected keys:")
         print(u)
 
-    if freeze:
-        for param in model.parameters():
-            param.requires_grad = False
+    # if freeze:
+    #     for param in model.parameters():
+    #         param.requires_grad = False
 
     model.eval()
     return model
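A usage sketch of the loader above, assuming the usual OmegaConf-style config object; the paths are placeholders. With the freeze block commented out, freeze=True no longer disables gradients, so a caller that relied on it would need to freeze explicitly:

from omegaconf import OmegaConf
from tim.utils.misc_utils import load_model_from_config

config = OmegaConf.load("configs/model.yaml")  # placeholder config path
model = load_model_from_config(config, "model.ckpt", verbose=True, freeze=True)  # placeholder ckpt

# freeze=True is now a no-op for gradients; freeze explicitly if required:
model.requires_grad_(False)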
tim/utils/train_utils.py

@@ -1,6 +1,5 @@
 import torch
 from collections import OrderedDict
-from copy import deepcopy
 from diffusers.utils import logging
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -9,8 +8,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def get_fsdp_plugin(fsdp_cfg, mixed_precision):
     import functools
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
-        BackwardPrefetch,
+        BackwardPrefetch,
+        CPUOffload,
+        ShardingStrategy,
+        MixedPrecision,
+        StateDictType,
+        FullStateDictConfig,
+        FullOptimStateDictConfig,
     )
     from accelerate.utils import FullyShardedDataParallelPlugin
     from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

@@ -20,39 +24,41 @@ def get_fsdp_plugin(fsdp_cfg, mixed_precision):
     elif mixed_precision == "bf16":
         dtype = torch.bfloat16
     else:
-        dtype = torch.float32
+        dtype = torch.float32
     fsdp_plugin = FullyShardedDataParallelPlugin(
+        sharding_strategy={
+            "FULL_SHARD": ShardingStrategy.FULL_SHARD,
+            "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
+            "NO_SHARD": ShardingStrategy.NO_SHARD,
+            "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,
+            "HYBRID_SHARD_ZERO2": ShardingStrategy._HYBRID_SHARD_ZERO2,
         }[fsdp_cfg.sharding_strategy],
+        backward_prefetch={
+            "BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE,
+            "BACKWARD_POST": BackwardPrefetch.BACKWARD_POST,
         }[fsdp_cfg.backward_prefetch],
+        mixed_precision_policy=MixedPrecision(
             param_dtype=dtype,
             reduce_dtype=dtype,
         ),
+        auto_wrap_policy=functools.partial(
             size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params
         ),
+        cpu_offload=CPUOffload(offload_params=fsdp_cfg.cpu_offload),
+        state_dict_type={
+            "FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
+            "LOCAL_STATE_DICT": StateDictType.LOCAL_STATE_DICT,
+            "SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
         }[fsdp_cfg.state_dict_type],
+        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        optim_state_dict_config=FullOptimStateDictConfig(
+            offload_to_cpu=True, rank0_only=True
+        ),
+        limit_all_gathers=fsdp_cfg.limit_all_gathers,
+        use_orig_params=fsdp_cfg.use_orig_params,
+        sync_module_states=fsdp_cfg.sync_module_states,
+        forward_prefetch=fsdp_cfg.forward_prefetch,
+        activation_checkpointing=fsdp_cfg.activation_checkpointing,
     )
     return fsdp_plugin
 
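get_fsdp_plugin reads every field it needs from fsdp_cfg. A minimal sketch of a matching config, assuming a plain namespace is acceptable; the concrete values are illustrative, not taken from the repo:

from types import SimpleNamespace

from tim.utils.train_utils import get_fsdp_plugin

# Hypothetical config; the field names mirror the attributes read above.
fsdp_cfg = SimpleNamespace(
    sharding_strategy="FULL_SHARD",
    backward_prefetch="BACKWARD_PRE",
    min_num_params=1_000_000,
    cpu_offload=False,
    state_dict_type="FULL_STATE_DICT",
    limit_all_gathers=True,
    use_orig_params=True,
    sync_module_states=True,
    forward_prefetch=False,
    activation_checkpointing=False,
)

fsdp_plugin = get_fsdp_plugin(fsdp_cfg, mixed_precision="bf16")
# The plugin is then typically handed to accelerate.Accelerator(fsdp_plugin=fsdp_plugin).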
@@ -60,21 +66,21 @@ def get_fsdp_plugin(fsdp_cfg, mixed_precision):
 def freeze_model(model, trainable_modules={}, verbose=False):
     logger.info("Start freeze")
     for name, param in model.named_parameters():
-        param.requires_grad = False
+        # param.requires_grad = False
         if verbose:
-            logger.info("freeze moduel: "+str(name))
+            logger.info("freeze moduel: " + str(name))
         for trainable_module_name in trainable_modules:
             if trainable_module_name in name:
-                param.requires_grad = True
+                # param.requires_grad = True
                 if verbose:
-                    logger.info("unfreeze moduel: "+str(name))
+                    logger.info("unfreeze moduel: " + str(name))
                 break
     logger.info("End freeze")
-    params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]
-    params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]
-    logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M")
-    logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M")
-    return
+    # params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]
+    # params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]
+    # logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M")
+    # logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M")
+    return
 
 
 @torch.no_grad()
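A usage sketch of freeze_model, with hypothetical module-name substrings; note that with the requires_grad assignments commented out, the function now only logs which parameters match trainable_modules rather than toggling gradients:

import torch.nn as nn

from tim.utils.train_utils import freeze_model

# Hypothetical model and name filter for illustration.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
freeze_model(model, trainable_modules={"1."}, verbose=True)

# If actual freezing is required, it now has to be done explicitly, e.g.:
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in {"1."})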
@@ -82,18 +88,17 @@ def update_ema(ema_model, model, decay=0.9999):
     """
     Step the EMA model towards the current model.
     """
+    if hasattr(model, "module"):
         model = model.module
+    if hasattr(ema_model, "module"):
         ema_model = ema_model.module
     ema_params = OrderedDict(ema_model.named_parameters())
     model_params = OrderedDict(model.named_parameters())
+
     for name, param in model_params.items():
         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
 
 
 def log_validation(model):
-    pass
+    pass
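A usage sketch of update_ema inside a training loop; the deepcopy-based EMA initialization, the stand-in model, and the dummy loop are illustrative assumptions, not code from this commit:

from copy import deepcopy

import torch
import torch.nn as nn

from tim.utils.train_utils import update_ema

model = nn.Linear(16, 16)            # stand-in for the trained network
ema_model = deepcopy(model).eval()   # EMA weights start as a copy of the model
ema_model.requires_grad_(False)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(10):                  # dummy training steps
    loss = model(torch.randn(4, 16)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    update_ema(ema_model, model, decay=0.9999)  # EMA tracks the online weights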