|
|
import argparse |
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument( |
|
|
"--config_path", type=str, |
|
|
default='', |
|
|
help="Path to config file" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--optimal_transport_method", |
|
|
type=str, |
|
|
default="exact", |
|
|
help="Use optimal transport in CFM training", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--split_ratios", |
|
|
nargs=2, |
|
|
type=float, |
|
|
default=[0.9, 0.1], |
|
|
help="Split ratios for training/validation data in CFM training", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--accelerator", type=str, default="cpu", help="Training accelerator" |
|
|
) |
|
|
parser.add_argument("--date", type=str) |
|
|
parser.add_argument("--seed", default=2, type=int) |
|
|
parser.add_argument("--device", default="cuda:1", type=str) |
|
|
parser.add_argument("--molecule", default="aldp", type=str) |
|
|
parser.add_argument('--wandb', action='store_true', default=False) |
|
|
parser.add_argument('--unseen', action='store_true', default=False) |
|
|
parser.add_argument('--run_name', default=None, type=str) |
|
|
|
|
|
parser.add_argument("--save_dir", default="", type=str) |
|
|
parser.add_argument("--root_dir", default="", type=str) |
|
|
|
|
|
parser.add_argument("--bias", default="force", type=str) |
|
|
|
|
|
parser.add_argument("--start_state", default="c5", type=str) |
|
|
parser.add_argument("--end_state", default="c7ax", type=str) |
|
|
parser.add_argument("--num_steps", default=100, type=int) |
|
|
|
|
|
parser.add_argument("--sigma", default=0.1, type=float) |
|
|
parser.add_argument("--num_samples", default=16, type=int) |
|
|
parser.add_argument("--temperature", default=300, type=float) |
|
|
parser.add_argument("--friction", default=2.0, type=float) |
|
|
parser.add_argument("--rbf", action='store_true', default=False) |
|
|
parser.add_argument("--use_delta_to_target", action='store_true', default=False) |
|
|
parser.add_argument("--use_gnn", action='store_true', default=False) |
|
|
|
|
|
parser.add_argument("--start_temperature", default=600, type=float) |
|
|
parser.add_argument("--end_temperature", default=300, type=float) |
|
|
parser.add_argument("--num_rollouts", default=1000, type=int) |
|
|
parser.add_argument("--trains_per_rollout", default=1000, type=int) |
|
|
parser.add_argument("--log_z_lr", default=1e-3, type=float) |
|
|
parser.add_argument("--policy_lr", default=1e-4, type=float) |
|
|
parser.add_argument("--batch_size", default=64, type=int) |
|
|
parser.add_argument("--buffer_size", default=1000, type=int) |
|
|
parser.add_argument("--max_grad_norm", default=1, type=int) |
|
|
parser.add_argument("--control_variate", default="global", type=str) |
|
|
parser.add_argument("--self_normalize", action='store_true', default=False) |
|
|
|
|
|
parser.add_argument("--objective", default="ce", type=str) |
|
|
parser.add_argument("--vel_conditioned", action='store_true', default=False) |
|
|
parser.add_argument("--dir_only", action='store_true', default=False) |
|
|
|
|
|
|
|
|
parser.add_argument("--num_particles", default=16, type=int) |
|
|
|
|
|
parser.add_argument("--kT", type=float, default=0.0) |
|
|
|
|
|
parser = datasets_parser(parser) |
|
|
|
|
|
|
|
|
parser = metric_parser(parser) |
|
|
|
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
def datasets_parser(parser): |
|
|
parser.add_argument("--dim", type=int, default=50, help="Dimension of data") |
|
|
|
|
|
parser.add_argument( |
|
|
"--data_type", |
|
|
type=str, |
|
|
default="tahoe", |
|
|
help="Type of data, now wither scrna or one of toys", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--data_name", |
|
|
type=str, |
|
|
default="tahoe", |
|
|
help="Path to the dataset", |
|
|
) |
|
|
return parser |
|
|
|
|
|
|
|
|
def metric_parser(parser): |
|
|
parser.add_argument( |
|
|
"--n_centers", |
|
|
type=int, |
|
|
default=300, |
|
|
help="Number of centers for RBF network", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--kappa", |
|
|
type=float, |
|
|
default=1.5, |
|
|
help="Kappa parameter for RBF network", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--rho", |
|
|
type=float, |
|
|
default=-2.75, |
|
|
help="Rho parameter in Riemanian Velocity Calculation", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--velocity_metric", |
|
|
type=str, |
|
|
default="rbf", |
|
|
help="Metric for velocity calculation", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--gamma", |
|
|
nargs="+", |
|
|
type=float, |
|
|
default=0.2, |
|
|
help="Gamma parameter in Riemanian Velocity Calculation", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--metric_epochs", |
|
|
type=int, |
|
|
default=200, |
|
|
help="Number of epochs for metric learning", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--metric_patience", |
|
|
type=int, |
|
|
default=25, |
|
|
help="Patience for metric learning", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--metric_lr", |
|
|
type=float, |
|
|
default=1e-2, |
|
|
help="Learning rate for metric learning", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--alpha_metric", |
|
|
type=float, |
|
|
default=1.0, |
|
|
help="Alpha parameter for metric learning", |
|
|
) |
|
|
return parser |
|
|
|
|
|
|