|  | import datetime | 
					
						
						|  | import argparse, importlib | 
					
						
						|  | from pytorch_lightning import seed_everything | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.distributed as dist | 
					
						
						|  |  | 
					
						
						|  | def setup_dist(local_rank): | 
					
						
						|  | if dist.is_initialized(): | 
					
						
						|  | return | 
					
						
						|  | torch.cuda.set_device(local_rank) | 
					
						
						|  | torch.distributed.init_process_group('nccl', init_method='env://') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_dist_info(): | 
					
						
						|  | if dist.is_available(): | 
					
						
						|  | initialized = dist.is_initialized() | 
					
						
						|  | else: | 
					
						
						|  | initialized = False | 
					
						
						|  | if initialized: | 
					
						
						|  | rank = dist.get_rank() | 
					
						
						|  | world_size = dist.get_world_size() | 
					
						
						|  | else: | 
					
						
						|  | rank = 0 | 
					
						
						|  | world_size = 1 | 
					
						
						|  | return rank, world_size | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == '__main__': | 
					
						
						|  | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  | parser.add_argument("--module", type=str, help="module name", default="inference") | 
					
						
						|  | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) | 
					
						
						|  | args, unknown = parser.parse_known_args() | 
					
						
						|  | inference_api = importlib.import_module(args.module, package=None) | 
					
						
						|  |  | 
					
						
						|  | inference_parser = inference_api.get_parser() | 
					
						
						|  | inference_args, unknown = inference_parser.parse_known_args() | 
					
						
						|  |  | 
					
						
						|  | seed_everything(inference_args.seed) | 
					
						
						|  | setup_dist(args.local_rank) | 
					
						
						|  | torch.backends.cudnn.benchmark = True | 
					
						
						|  | rank, gpu_num = get_dist_info() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) | 
					
						
						|  | inference_api.run_inference(inference_args, gpu_num, rank) |