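"""Training / evaluation entry point for VLMo.

Builds the multitask data module and the VLMo module from the Sacred config
(`ex`) and runs distributed training (or testing, when `test_only` is set) via
PyTorch Lightning.
"""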
import os
import copy
import pytorch_lightning as pl
from vlmo.config import ex
from vlmo.modules import VLMo
from vlmo.datamodules.multitask_datamodule import MTDataModule
from pytorch_lightning.plugins import environments as pl_env
from pytorch_lightning.utilities.distributed import rank_zero_info
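
# Cluster environment for jobs launched externally with Open MPI (e.g. mpirun):
# rank and world-size information is read from the OMPI_COMM_WORLD_* environment
# variables rather than being set up by Lightning itself.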
class OMPIClusterEnvironment(pl_env.ClusterEnvironment):
def __init__(self):
super().__init__()
# def creates_children(self) -> bool:
# # return True if the cluster is managed (you don't launch processes yourself)
# assert (
# "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
# ) # this cluster is managed
# return True
@property
def creates_processes_externally(self):
return True
def world_size(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_SIZE"])
def set_world_size(self, size: int):
pass
def global_rank(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_RANK"])
def set_global_rank(self, rank: int):
pass
def local_rank(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
def node_rank(self) -> int:
if "NODE_RANK" in os.environ:
return int(os.environ["NODE_RANK"])
else:
return 0
def master_address(self) -> str:
return os.environ["MASTER_ADDR"]
def master_port(self) -> int:
return int(os.environ["MASTER_PORT"])
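
# Choose the cluster environment: OMPI when training on multiple nodes or when
# launched under mpirun (OMPI_COMM_WORLD_SIZE is set), otherwise Lightning's
# default single-node environment.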
def get_cluster_plugin(num_gpus=1, num_nodes=1):
if num_nodes > 1 or (
num_nodes == 1 and "OMPI_COMM_WORLD_SIZE" in os.environ
):
rank_zero_info("ClusterPlugin: using OMPI Cluster Environment")
return OMPIClusterEnvironment()
if num_gpus >= 1:
rank_zero_info("ClusterPlugin: using Lightning Cluster Environment")
return pl_env.LightningEnvironment()
return None
@ex.automain
def main(_config):
_config = copy.deepcopy(_config)
pl.seed_everything(_config["seed"])
dm = MTDataModule(_config, dist=True)
model = VLMo(_config)
    exp_name = _config["exp_name"]
os.makedirs(_config["log_dir"], exist_ok=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
save_top_k=-1,
verbose=True,
monitor="val/the_metric",
mode="max",
save_last=True,
)
    # Run name encodes the experiment, seed, and the checkpoint it was initialized
    # from (the trailing ".ckpt" is stripped from the load_path filename).
    run_name = f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}'
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=run_name,
    )
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]
num_gpus = (
_config["num_gpus"]
if isinstance(_config["num_gpus"], int)
else len(_config["num_gpus"])
)
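    # Derive the gradient-accumulation factor so that the effective batch size
    # equals _config["batch_size"] across all GPUs and nodes.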
grad_steps = _config["batch_size"] // (
_config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
)
rank_zero_info("grad_steps: {}".format(grad_steps))
    max_steps = _config["max_steps"]  # None means no step limit; max_epochs governs instead
resume_ckpt = None
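    # Resume from the newest existing `last.ckpt` under the logger's version_*
    # directories (the loop keeps the highest version index that exists).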
if _config["resume_during_training"]:
        for index in range(100):
            ckpt_path = os.path.join(
                _config["log_dir"], run_name, f"version_{index}/checkpoints/last.ckpt"
            )
            if os.path.exists(ckpt_path):
                resume_ckpt = ckpt_path
rank_zero_info("resume_ckpt: {}".format(resume_ckpt))
    # Pass the resolved GPU count (an int) so the >= comparison inside
    # get_cluster_plugin also works when _config["num_gpus"] is a list of device ids.
    cluster_plugin = get_cluster_plugin(num_gpus, _config["num_nodes"])
plugin_list = [cluster_plugin]
rank_zero_info("plugin_list: {}".format(plugin_list))
if _config["use_sharded_training"]:
rank_zero_info("Using ddp sharded")
distributed_strategy = "ddp_sharded"
else:
distributed_strategy = "ddp"
trainer = pl.Trainer(
gpus=_config["num_gpus"],
num_nodes=_config["num_nodes"],
precision=_config["precision"],
accelerator="gpu",
strategy=distributed_strategy,
benchmark=True,
deterministic=True,
max_epochs=_config["max_epoch"] if max_steps is None else 1000,
max_steps=max_steps,
callbacks=callbacks,
logger=logger,
# prepare_data_per_node=False,
replace_sampler_ddp=False,
accumulate_grad_batches=grad_steps,
log_every_n_steps=10,
flush_logs_every_n_steps=10,
resume_from_checkpoint=resume_ckpt,
weights_summary="top",
fast_dev_run=_config["fast_dev_run"],
val_check_interval=_config["val_check_interval"],
plugins=plugin_list,
)
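    # Text-only MLM pretraining: freeze the full model, then unfreeze only the
    # text-specific parameters (text embeddings, text MLP/norm branches, the MLM
    # head, relative position bias tables, and the final transformer norm).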
if _config["loss_names"]["textmlm"] > 0:
for param in model.parameters():
param.requires_grad = False
        text_keys = [
            "text_embeddings", "token_type_embeddings", "mlp_text", "norm2_text",
            "mlm_score", "relative_position_bias_table", "transformer.norm",
        ]
        for name, param in model.named_parameters():
            if any(key in name for key in text_keys):
                param.requires_grad = True
for name, param in model.named_parameters():
rank_zero_info("{}\t{}".format(name, param.requires_grad))
if not _config["test_only"]:
trainer.fit(model, datamodule=dm)
else:
trainer.test(model, datamodule=dm)
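
# Example launch (illustrative only; the actual task/config names live in vlmo/config.py):
#   python run.py with <task_config_name> num_gpus=8 num_nodes=1 \
#       per_gpu_batchsize=16 load_path=/path/to/pretrained.ckpt
# Sacred's `with key=value` syntax overrides the corresponding entries in _config.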