from pathlib import Path
import torch, os
from tqdm import tqdm
import pickle
import argparse
import logging, datetime
import torch.distributed as dist
from config import MyParser
from steps import trainer
from copy_codebase import copy_codebase
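
# Read per-process rank info (local rank, global rank, world size) from the
# environment variables set by torchrun / torch.distributed.launch.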
def world_info_from_env():
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, global_rank, world_size

if __name__ == "__main__":
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.cuda.empty_cache()

    args = MyParser().parse_args()
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"exp_dir: {str(exp_dir)}")
    if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)):
        if not os.path.exists("%s/bundle.pth" % args.exp_dir):
            os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth")
        resume = args.resume
        assert bool(args.exp_dir)
        with open("%s/args.pkl" % args.exp_dir, "rb") as f:
            old_args = pickle.load(f)
        new_args = vars(args)
        old_args = vars(old_args)
        for key in new_args:
            if key not in old_args or old_args[key] != new_args[key]:
                old_args[key] = new_args[key]
        args = argparse.Namespace(**old_args)
        args.resume = resume
    else:
        args.resume = False
        with open("%s/args.pkl" % args.exp_dir, "wb") as f:
            pickle.dump(args, f)
    # Use a longer process-group timeout so that slow phases (e.g. generation during
    # evaluation) on one rank do not trip the collective-communication watchdog.
    timeout = datetime.timedelta(seconds=7200)  # 2 hours
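
    # Initialize the distributed process group. In multi-node mode the backend and
    # rendezvous URL come from the parsed config; otherwise default to NCCL with the
    # env:// rendezvous set up by the launcher.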
    if args.multinodes:
        _local_rank, _, _ = world_info_from_env()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
    else:
        dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)
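
    # Optionally keep Weights & Biases fully offline (logs stay local, no network sync).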
    if args.local_wandb:
        os.environ["WANDB_MODE"] = "offline"
    rank = dist.get_rank()
    if rank == 0:
        logging.info(args)
        logging.info(f"exp_dir: {str(exp_dir)}")
    world_size = dist.get_world_size()
    local_rank = int(_local_rank) if args.multinodes else rank
    num_devices = torch.cuda.device_count()
    logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
    for device_idx in range(num_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        logging.info(f"Device {device_idx}: {device_name}")
    torch.cuda.set_device(local_rank)
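
    # On rank 0 only: snapshot the codebase into the experiment directory so the run
    # can be reproduced from the exact source that produced it.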
    if rank == 0:
        user_dir = os.path.expanduser("~")
        codebase_name = "VoiceStar"
        now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"))
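
    # Build the trainer and start the training loop.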
    my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
    my_trainer.train()
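
# Example launch (a sketch, not the project's documented CLI): the script and flag
# names below are assumptions -- the actual flags are defined by config.MyParser and
# the repository layout. torchrun provides the LOCAL_RANK / RANK / WORLD_SIZE
# environment variables that world_info_from_env() and the env:// init method expect.
#
#   torchrun --nproc_per_node=4 main.py --exp_dir /path/to/exp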