import json import os import argparse import yaml args = None def set_global_args(local_args): global args args = local_args def parse_hp_string(hp_string): result = {} for pair in hp_string.split(','): if not pair: continue key, value = pair.split('=') try: # 自动转换为 int / float / str ori_value = value value = float(value) if '.' not in str(ori_value): value = int(value) except ValueError: pass if value in ['true', 'True']: value = True if value in ['false', 'False']: value = False if '.' in key: keys = key.split('.') keys = keys current = result for key in keys[:-1]: if key not in current or not isinstance(current[key], dict): current[key] = {} current = current[key] current[keys[-1]] = value else: result[key.strip()] = value return result def parse_args(): global args parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.") # 定义 argparse 参数 parser.add_argument("--exp_path", type=str, help="Path to save the model.") parser.add_argument("--input_file", type=str, help="Path to inference txt.") parser.add_argument("--debug", action='store_true', default=None) parser.add_argument("--infer", action='store_true') parser.add_argument("-hp", "--hparams", type=str, default="") args = parser.parse_args() # 读取 YAML 配置(如果提供了 --config 参数) if args.config: with open(args.config, "r") as f: yaml_config = yaml.safe_load(f) # 遍历 YAML 配置,将其添加到 args(如果 argparse 里没有定义) for key, value in yaml_config.items(): if not hasattr(args, key): # argparse 没有的参数 setattr(args, key, value) elif getattr(args, key) is None: # argparse 有但值为空 setattr(args, key, value) args.rank = int(os.getenv("RANK", "0")) args.world_size = int(os.getenv("WORLD_SIZE", "1")) args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun args.device = 'cuda' debug = args.debug if not os.path.exists(args.exp_path): args.exp_path = f'checkpoints/{args.exp_path}' if hasattr(args, 'reload_cfg') and args.reload_cfg: # 重新加载配置文件 conf_path = os.path.join(args.exp_path, "config.json") if os.path.exists(conf_path): print('| Reloading config from:', conf_path) args = reload(args, conf_path) if len(args.hparams) > 0: hp_dict = parse_hp_string(args.hparams) for key, value in hp_dict.items(): if not hasattr(args, key): setattr(args, key, value) else: if isinstance(value, dict): ori_v = getattr(args, key) ori_v.update(value) setattr(args, key, ori_v) else: setattr(args, key, value) args.debug = debug dict_args = convert_namespace_to_dict(args) if args.local_rank == 0: print(dict_args) return args def reload(args, conf_path): """重新加载配置文件,不覆盖已有的参数""" with open(conf_path, "r") as f: yaml_config = yaml.safe_load(f) # 遍历 YAML 配置,将其添加到 args(如果 argparse 里没有定义) for key, value in yaml_config.items(): if not hasattr(args, key): # argparse 没有的参数 setattr(args, key, value) elif getattr(args, key) is None: # argparse 有但值为空 setattr(args, key, value) return args def convert_namespace_to_dict(namespace): """将 argparse.Namespace 转为字典,并处理不可序列化对象""" result = {} for key, value in vars(namespace).items(): try: json.dumps(value) # 检查是否可序列化 result[key] = value except (TypeError, OverflowError): result[key] = str(value) # 将不可序列化的对象转为字符串表示 return result