Spaces:
Runtime error
Runtime error
def create_deepspeed_config(args):
    """Build the DeepSpeed configuration dictionary from training arguments.

    Args:
        args: Namespace-like object exposing ``global_batch_size``,
            ``gradient_accumulation_steps``, ``lr``, ``weight_decay``,
            ``beta1``, ``beta2``, ``mixed_precision``, ``clip_grad`` and
            ``zero_stage``.

    Returns:
        dict: The DeepSpeed config. ``gradient_clipping`` is present only
        when ``args.clip_grad`` is not ``None``; ``zero_optimization`` is
        present only when ``args.zero_stage`` is one of 0, 1, 2 or 3.
    """
    ds_config = {
        "steps_per_print": 1000,
        "train_batch_size": args.global_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        # "train_micro_batch_size_per_gpu" is intentionally omitted: it is
        # determined by (train_batch_size, gradient_accumulation_steps).
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "bias_correction": True,
                "betas": [
                    args.beta1,
                    args.beta2,
                ],
                # DeepSpeed reads ``adam_w_mode`` from the optimizer's
                # "params" section; as a sibling of "params" it is ignored.
                "adam_w_mode": True,
            },
        },
        # Exactly one of fp16 / bf16 is enabled, chosen by mixed_precision;
        # any other value leaves both disabled.
        "fp16": {
            "enabled": args.mixed_precision == 'fp16',
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "bf16": {
            "enabled": args.mixed_precision == 'bf16',
        },
        # "flops_profiler": {
        #     "enabled": True,
        #     "profile_step": -1,
        #     "module_depth": -1,
        #     "top_modules": 1,
        #     "detailed": True,
        # },
        "zero_allow_untested_optimizer": True,
    }

    if args.clip_grad is not None:
        ds_config["gradient_clipping"] = args.clip_grad

    # Options added on top of the settings shared by every ZeRO stage.
    stage_specific = {
        0: {},
        1: {
            "reduce_bucket_size": 5e8,
        },
        2: {
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
        3: {
            "reduce_bucket_size": 5e8,
            "stage3_prefetch_bucket_size": 5e8,
            "stage3_param_persistence_threshold": 1e6,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_gather_16bit_weights_on_model_save": True,
        },
    }

    # Stages outside 0-3 add no "zero_optimization" section at all.
    if args.zero_stage in stage_specific:
        ds_config["zero_optimization"] = {
            "stage": args.zero_stage,
            "contiguous_gradients": True,
            "overlap_comm": True,
            **stage_specific[args.zero_stage],
        }

    return ds_config