Spaces:
Sleeping
Sleeping
import os | |
import copy | |
import pytorch_lightning as pl | |
from vlmo.config import ex | |
from vlmo.modules import VLMo | |
from vlmo.datamodules.multitask_datamodule import MTDataModule | |
from pytorch_lightning.plugins import environments as pl_env | |
from pytorch_lightning.utilities.distributed import rank_zero_info | |
class OMPIClusterEnvironment(pl_env.ClusterEnvironment):
    """Cluster environment that takes its topology from ``OMPI_COMM_WORLD_*``.

    World size, global rank and local rank are read from the
    ``OMPI_COMM_WORLD_SIZE`` / ``OMPI_COMM_WORLD_RANK`` /
    ``OMPI_COMM_WORLD_LOCAL_RANK`` environment variables. The rendezvous
    endpoint is read from ``MASTER_ADDR`` / ``MASTER_PORT``. Every accessor
    raises ``KeyError`` if its variable is missing, except ``node_rank``,
    which falls back to 0.
    """

    def __init__(self):
        super().__init__()

    @property
    def creates_processes_externally(self) -> bool:
        """Tell Lightning that worker processes are launched outside of it.

        The base class declares this as a property and reads it as an
        attribute. Defining it as a plain method would make that read yield
        a bound method, which is truthy whatever the method returns.
        """
        return True

    def world_size(self) -> int:
        """Return the total number of processes (``OMPI_COMM_WORLD_SIZE``)."""
        return int(os.environ["OMPI_COMM_WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        """Ignore the request; the world size always comes from the environment."""

    def global_rank(self) -> int:
        """Return this process's global rank (``OMPI_COMM_WORLD_RANK``)."""
        return int(os.environ["OMPI_COMM_WORLD_RANK"])

    def set_global_rank(self, rank: int) -> None:
        """Ignore the request; the global rank always comes from the environment."""

    def local_rank(self) -> int:
        """Return this process's rank on its node (``OMPI_COMM_WORLD_LOCAL_RANK``)."""
        return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])

    def node_rank(self) -> int:
        """Return ``NODE_RANK`` from the environment, or 0 when it is not set."""
        return int(os.environ.get("NODE_RANK", 0))

    def master_address(self) -> str:
        """Return the rendezvous host (``MASTER_ADDR``)."""
        return os.environ["MASTER_ADDR"]

    def master_port(self) -> int:
        """Return the rendezvous port (``MASTER_PORT``) as an integer."""
        return int(os.environ["MASTER_PORT"])
def get_cluster_plugin(num_gpus=1, num_nodes=1):
    """Choose the cluster environment plugin for the current launch setup.

    Args:
        num_gpus: Either a GPU count or a sized collection of device ids.
            A collection is reduced to its length before being compared.
        num_nodes: Number of nodes taking part in the run.

    Returns:
        An ``OMPIClusterEnvironment`` when running on more than one node, or
        on exactly one node with ``OMPI_COMM_WORLD_SIZE`` present in the
        environment. Otherwise a ``LightningEnvironment`` when at least one
        GPU is requested, and ``None`` when neither condition holds.
    """
    # A device list such as [0, 1] cannot be compared with an int
    # (TypeError), so reduce it to a count first.
    gpu_count = num_gpus if isinstance(num_gpus, int) else len(num_gpus)

    launched_by_ompi = "OMPI_COMM_WORLD_SIZE" in os.environ
    if num_nodes > 1 or (num_nodes == 1 and launched_by_ompi):
        rank_zero_info("ClusterPlugin: using OMPI Cluster Environment")
        return OMPIClusterEnvironment()

    if gpu_count >= 1:
        rank_zero_info("ClusterPlugin: using Lightning Cluster Environment")
        return pl_env.LightningEnvironment()

    return None
def main(_config):
    """Build the data module, model and trainer from ``_config``, then fit or test.

    The configuration is deep-copied before use, so the caller's mapping is
    not modified. Keys read here: ``seed``, ``exp_name``, ``log_dir``,
    ``load_path``, ``num_gpus`` (int or sized collection), ``num_nodes``,
    ``batch_size``, ``per_gpu_batchsize``, ``max_steps``, ``max_epoch``,
    ``resume_during_training``, ``use_sharded_training``, ``precision``,
    ``fast_dev_run``, ``val_check_interval``, ``loss_names`` and
    ``test_only``.

    Args:
        _config: Mapping holding the experiment configuration.
    """
    # NOTE(review): `ex` is imported at the top of the file but never used in
    # the visible code; presumably this function is meant to be registered
    # with it (e.g. via a decorator) -- confirm against the upstream file.
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = MTDataModule(_config, dist=True)
    model = VLMo(_config)

    exp_name = f'{_config["exp_name"]}'
    os.makedirs(_config["log_dir"], exist_ok=True)

    # One name shared by the logger and the resume lookup, so both always
    # point at the same directory. The last five characters of the checkpoint
    # file name are dropped (presumably the ".ckpt" suffix).
    run_name = (
        f'{exp_name}_seed{_config["seed"]}'
        f'_from_{_config["load_path"].split("/")[-1][:-5]}'
    )

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=-1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    logger = pl.loggers.TensorBoardLogger(_config["log_dir"], name=run_name)
    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lr_callback]

    # `num_gpus` may be a count or a collection of device ids.
    num_gpus = (
        _config["num_gpus"]
        if isinstance(_config["num_gpus"], int)
        else len(_config["num_gpus"])
    )

    # Accumulate gradients so the effective batch size reaches `batch_size`.
    # Floor division gives 0 when one step across all devices already exceeds
    # the target, so clamp to 1 (no accumulation).
    per_step_batch = _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
    grad_steps = max(1, _config["batch_size"] // per_step_batch)
    rank_zero_info("grad_steps: {}".format(grad_steps))

    max_steps = _config["max_steps"]

    resume_ckpt = None
    if _config["resume_during_training"]:
        # Deliberately no early exit: the highest existing version wins.
        for index in range(100):
            ckpt_path = os.path.join(
                _config["log_dir"],
                run_name,
                "version_{}/checkpoints/last.ckpt".format(index),
            )
            if os.path.exists(ckpt_path):
                resume_ckpt = ckpt_path
    rank_zero_info("resume_ckpt: {}".format(resume_ckpt))

    # Pass the normalised count: a raw device list would fail the numeric
    # comparison inside get_cluster_plugin.
    cluster_plugin = get_cluster_plugin(num_gpus, _config["num_nodes"])
    # get_cluster_plugin may return None; never forward that to the Trainer.
    plugin_list = [cluster_plugin] if cluster_plugin is not None else []
    rank_zero_info("plugin_list: {}".format(plugin_list))

    if _config["use_sharded_training"]:
        rank_zero_info("Using ddp sharded")
        distributed_strategy = "ddp_sharded"
    else:
        distributed_strategy = "ddp"

    trainer = pl.Trainer(
        gpus=_config["num_gpus"],
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        accelerator="gpu",
        strategy=distributed_strategy,
        benchmark=True,
        deterministic=True,
        # With a step budget the epoch limit is set high so that max_steps
        # is what ends training.
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        # prepare_data_per_node=False,
        replace_sampler_ddp=False,
        accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        flush_logs_every_n_steps=10,
        resume_from_checkpoint=resume_ckpt,
        weights_summary="top",
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
        plugins=plugin_list,
    )

    if _config["loss_names"]["textmlm"] > 0:
        # When the textmlm loss is enabled, freeze everything and re-enable
        # only parameters whose name contains one of these substrings.
        trainable_keys = [
            "text_embeddings",
            "token_type_embeddings",
            "mlp_text",
            "norm2_text",
            "mlm_score",
            "relative_position_bias_table",
            "transformer.norm",
        ]
        for param in model.parameters():
            param.requires_grad = False

        for name, param in model.named_parameters():
            if any(key in name for key in trainable_keys):
                param.requires_grad = True
            rank_zero_info("{}\t{}".format(name, param.requires_grad))

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)