Julien Blanchon committed · Commit 6fb693e · 1 Parent(s): ef1acb1
tim/models/utils/text_encoders.py CHANGED
@@ -25,8 +25,8 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
     else:
         raise NotImplementedError
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in text_encoder.parameters():
-        param.requires_grad = False
+    # for param in text_encoder.parameters():
+    #     param.requires_grad = False
 
     text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
 
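Note that with the freeze loop commented out, load_text_encoder now returns an eval-mode module whose parameters are no longer explicitly frozen. A minimal caller-side sketch of how gradient tracking can still be avoided at inference time, assuming the returned encoder is an ordinary torch.nn.Module; encode_prompts and tokens are illustrative names, not identifiers from this repository.

import torch

# Hypothetical caller-side sketch (not code from this commit): with the explicit
# freeze removed, wrapping inference in torch.no_grad() keeps the text encoder
# out of the autograd graph regardless of the requires_grad flags.
def encode_prompts(text_encoder, tokens):
    with torch.no_grad():
        return text_encoder(tokens)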
 
tim/models/vae/__init__.py CHANGED
@@ -8,8 +8,8 @@ def get_dc_ae(vae_dir, dtype, device):
     dc_ae = AutoencoderDC.from_pretrained(vae_dir).to(dtype=dtype, device=device)
     dc_ae.eval()
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in dc_ae.parameters():
-        param.requires_grad = False
+    # for param in dc_ae.parameters():
+    #     param.requires_grad = False
     return dc_ae
 
 
@@ -37,8 +37,8 @@ def get_sd_vae(vae_dir, dtype, device):
     sd_vae = AutoencoderKL.from_pretrained(vae_dir).to(dtype=dtype, device=device)
     sd_vae.eval()
     # Set requires_grad to False for all parameters to avoid functorch issues
-    for param in sd_vae.parameters():
-        param.requires_grad = False
+    # for param in sd_vae.parameters():
+    #     param.requires_grad = False
     return sd_vae
 
 
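As with the text encoder, both VAE loaders now return eval-mode modules without explicitly frozen parameters. A small usage sketch assuming a diffusers AutoencoderKL checkpoint; the checkpoint id, dtype, device, and tensor shape are illustrative values, not ones taken from this repository.

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint and settings, not necessarily what this repo passes in.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(
    dtype=torch.float16, device="cuda"
)
vae.eval()

images = torch.randn(2, 3, 256, 256, device="cuda", dtype=torch.float16)
with torch.no_grad():  # no gradients are recorded for the VAE at inference
    latents = vae.encode(images).latent_dist.sample()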
 
tim/utils/misc_utils.py CHANGED
@@ -283,9 +283,9 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     print("unexpected keys:")
     print(u)
 
-    if freeze:
-        for param in model.parameters():
-            param.requires_grad = False
+    # if freeze:
+    #     for param in model.parameters():
+    #         param.requires_grad = False
 
     model.eval()
     return model
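With the freeze branch commented out, the freeze=True default of load_model_from_config no longer has any effect inside the function. A hedged sketch of how a caller could restore the old behaviour in one call; the helper name freeze_ is hypothetical and not part of the repository.

import torch

def freeze_(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper, equivalent to the commented-out per-parameter loop:
    # requires_grad_(False) flips the flag on every parameter of the module.
    model.requires_grad_(False)
    return model.eval()

A caller could then write, for example, model = freeze_(load_model_from_config(config, ckpt)).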
tim/utils/train_utils.py CHANGED
@@ -1,6 +1,5 @@
 import torch
 from collections import OrderedDict
-from copy import deepcopy
 from diffusers.utils import logging
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -9,8 +8,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def get_fsdp_plugin(fsdp_cfg, mixed_precision):
     import functools
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
-        BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision,
-        StateDictType, FullStateDictConfig, FullOptimStateDictConfig,
+        BackwardPrefetch,
+        CPUOffload,
+        ShardingStrategy,
+        MixedPrecision,
+        StateDictType,
+        FullStateDictConfig,
+        FullOptimStateDictConfig,
     )
     from accelerate.utils import FullyShardedDataParallelPlugin
     from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
@@ -20,39 +24,41 @@ def get_fsdp_plugin(fsdp_cfg, mixed_precision):
     elif mixed_precision == "bf16":
         dtype = torch.bfloat16
     else:
-        dtype = torch.float32
+        dtype = torch.float32
     fsdp_plugin = FullyShardedDataParallelPlugin(
-        sharding_strategy = {
-            'FULL_SHARD': ShardingStrategy.FULL_SHARD,
-            'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
-            'NO_SHARD': ShardingStrategy.NO_SHARD,
-            'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
-            'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
+        sharding_strategy={
+            "FULL_SHARD": ShardingStrategy.FULL_SHARD,
+            "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
+            "NO_SHARD": ShardingStrategy.NO_SHARD,
+            "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,
+            "HYBRID_SHARD_ZERO2": ShardingStrategy._HYBRID_SHARD_ZERO2,
         }[fsdp_cfg.sharding_strategy],
-        backward_prefetch = {
-            'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
-            'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
+        backward_prefetch={
+            "BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE,
+            "BACKWARD_POST": BackwardPrefetch.BACKWARD_POST,
         }[fsdp_cfg.backward_prefetch],
-        mixed_precision_policy = MixedPrecision(
+        mixed_precision_policy=MixedPrecision(
             param_dtype=dtype,
             reduce_dtype=dtype,
         ),
-        auto_wrap_policy = functools.partial(
+        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params
         ),
-        cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload),
-        state_dict_type = {
-            'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT,
-            'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT,
-            'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT
+        cpu_offload=CPUOffload(offload_params=fsdp_cfg.cpu_offload),
+        state_dict_type={
+            "FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
+            "LOCAL_STATE_DICT": StateDictType.LOCAL_STATE_DICT,
+            "SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
         }[fsdp_cfg.state_dict_type],
-        state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
-        optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
-        limit_all_gathers = fsdp_cfg.limit_all_gathers,
-        use_orig_params = fsdp_cfg.use_orig_params,
-        sync_module_states = fsdp_cfg.sync_module_states,
-        forward_prefetch = fsdp_cfg.forward_prefetch,
-        activation_checkpointing = fsdp_cfg.activation_checkpointing,
+        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        optim_state_dict_config=FullOptimStateDictConfig(
+            offload_to_cpu=True, rank0_only=True
+        ),
+        limit_all_gathers=fsdp_cfg.limit_all_gathers,
+        use_orig_params=fsdp_cfg.use_orig_params,
+        sync_module_states=fsdp_cfg.sync_module_states,
+        forward_prefetch=fsdp_cfg.forward_prefetch,
+        activation_checkpointing=fsdp_cfg.activation_checkpointing,
     )
     return fsdp_plugin
 
@@ -60,21 +66,21 @@ def get_fsdp_plugin(fsdp_cfg, mixed_precision):
 def freeze_model(model, trainable_modules={}, verbose=False):
     logger.info("Start freeze")
     for name, param in model.named_parameters():
-        param.requires_grad = False
+        # param.requires_grad = False
         if verbose:
-            logger.info("freeze moduel: "+str(name))
+            logger.info("freeze moduel: " + str(name))
         for trainable_module_name in trainable_modules:
             if trainable_module_name in name:
-                param.requires_grad = True
+                # param.requires_grad = True
                 if verbose:
-                    logger.info("unfreeze moduel: "+str(name))
+                    logger.info("unfreeze moduel: " + str(name))
                 break
     logger.info("End freeze")
-    params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]
-    params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]
-    logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M")
-    logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M")
-    return
+    # params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]
+    # params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]
+    # logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M")
+    # logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M")
+    return
 
 
 @torch.no_grad()
@@ -82,18 +88,17 @@ def update_ema(ema_model, model, decay=0.9999):
     """
     Step the EMA model towards the current model.
     """
-    if hasattr(model, 'module'):
+    if hasattr(model, "module"):
         model = model.module
-    if hasattr(ema_model, 'module'):
+    if hasattr(ema_model, "module"):
         ema_model = ema_model.module
     ema_params = OrderedDict(ema_model.named_parameters())
     model_params = OrderedDict(model.named_parameters())
-
+
     for name, param in model_params.items():
         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
 
 
-
 def log_validation(model):
-    pass
+    pass
 
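get_fsdp_plugin reads every option from attributes of fsdp_cfg. A sketch of a config object carrying the fields the function accesses; the field names come from the diff above, while the values shown are illustrative defaults, not the repository's actual configuration.

from types import SimpleNamespace

# Field names taken from the fsdp_cfg accesses in get_fsdp_plugin; values are
# illustrative placeholders only.
fsdp_cfg = SimpleNamespace(
    sharding_strategy="FULL_SHARD",      # or SHARD_GRAD_OP / NO_SHARD / HYBRID_SHARD / HYBRID_SHARD_ZERO2
    backward_prefetch="BACKWARD_PRE",    # or BACKWARD_POST
    min_num_params=1_000_000,            # threshold for size_based_auto_wrap_policy
    cpu_offload=False,
    state_dict_type="FULL_STATE_DICT",   # or LOCAL_STATE_DICT / SHARDED_STATE_DICT
    limit_all_gathers=True,
    use_orig_params=True,
    sync_module_states=True,
    forward_prefetch=False,
    activation_checkpointing=False,
)

# plugin = get_fsdp_plugin(fsdp_cfg, mixed_precision="bf16")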
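After this change freeze_model no longer flips requires_grad and no longer reports parameter counts; it effectively only logs. If the removed summary is still wanted, it can be recomputed outside the function; a small sketch under that assumption (count_parameters is a hypothetical helper, not part of the repository).

import torch

def count_parameters(model: torch.nn.Module):
    # Mirrors the logging that the commit comments out in freeze_model:
    # trainable vs. frozen parameter counts, reported in millions.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable / 1e6, frozen / 1e6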
 
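Since "from copy import deepcopy" was removed from this module, the EMA copy now has to be created by whoever calls update_ema. A minimal usage sketch under that assumption; the tiny Linear module is only for illustration, and the import path assumes the tim package is importable.

import copy
import torch
from tim.utils.train_utils import update_ema  # assumes the package is on the path

model = torch.nn.Linear(8, 8)            # stand-in for the real network
ema_model = copy.deepcopy(model).eval()  # the caller makes the copy itself now
ema_model.requires_grad_(False)

# ...after each optimizer step in the training loop:
update_ema(ema_model, model, decay=0.9999)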