import datetime
import functools
import math
import os
import random
import subprocess
import sys
import threading
import time
from collections import defaultdict, deque
from typing import Iterator, List, Tuple

import numpy as np
import pytz
import torch
import torch.distributed as tdist
import torch.nn.functional as F

import utils.dist as dist

os_system = functools.partial(subprocess.call, shell=True)

def echo(info):
    os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')

def os_system_get_stdout(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')

def os_system_get_stdout_stderr(cmd):
    cnt = 0
    while True:
        try:
            sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
        except subprocess.TimeoutExpired:
            cnt += 1
            print(f'[fetch free_port file] timeout cnt={cnt}')
        else:
            return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')

def is_pow2n(x):
    return x > 0 and (x & (x - 1) == 0)
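# Example: is_pow2n(8) -> True, since 0b1000 & 0b0111 == 0; is_pow2n(6) -> False,
# since 0b110 & 0b101 == 0b100 != 0. x & (x - 1) clears the lowest set bit, which
# leaves zero exactly when x is a power of two.
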
def time_str(fmt='[%m-%d %H:%M:%S]'):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)

class DistLogger(object):
    def __init__(self, lg):
        self._lg = lg
    
    @staticmethod
    def do_nothing(*args, **kwargs):
        pass
    
    def __getattr__(self, attr: str):
        return getattr(self._lg, attr) if self._lg is not None else DistLogger.do_nothing
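# Example: DistLogger(None).update(loss=0.1) silently does nothing, while
# DistLogger(real_lg).update(loss=0.1) forwards to real_lg.update — this keeps
# logging calls unconditional across ranks.
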
class TensorboardLogger(object):
    def __init__(self, log_dir, filename_suffix):
        try:
            import tensorflow_io as tfio  # optional: importing tensorflow_io can register extra tensorboard filesystems (best-effort)
        except ImportError:
            pass
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
        self.step = 0
    
    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1
    
    def loggable(self):
        # rate limit: log at step 0, then once every 500 steps
        return self.step == 0 or (self.step + 1) % 500 == 0
    
    def update(self, head='scalar', step=None, **kwargs):
        if step is None:
            step = self.step
            if not self.loggable(): return
        for k, v in kwargs.items():
            if v is None: continue
            if hasattr(v, 'item'): v = v.item()
            self.writer.add_scalar(f'{head}/{k}', v, step)
    
    def log_tensor_as_distri(self, tag, tensor1d, step=None):
        if step is None:
            step = self.step
            if not self.loggable(): return
        try:
            self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
        except Exception as e:
            print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
    
    def log_image(self, tag, img_chw, step=None):
        if step is None:
            step = self.step
            if not self.loggable(): return
        self.writer.add_image(tag, img_chw, step, dataformats='CHW')
    
    def flush(self):
        self.writer.flush()
    
    def close(self):
        self.writer.close()
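# Usage sketch (wiring is hypothetical, assuming utils.dist exposes an
# is_master()-style check):
#   tb_lg = DistLogger(TensorboardLogger('/tmp/tb', '.rk0') if dist.is_master() else None)
#   tb_lg.set_step(499)                   # (499+1) % 500 == 0, so loggable() is True
#   tb_lg.update(head='train', loss=0.5)  # writes train/loss; a no-op on other ranks
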
class Low_GPU_usage(object):
    # Apparently a no-op stub, kept so call sites can use `with Low_GPU_usage(...)` unconditionally.
    def __init__(self, files, sleep_secs, verbose):
        pass
    
    def early_stop(self):
        pass
    
    def __enter__(self):
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

class TouchingDaemonDontForgetToStartMe(threading.Thread):
    def __init__(self, files: List[str], sleep_secs: int, verbose=False):
        super().__init__(daemon=True)
        self.files = tuple(files)
        self.sleep_secs = sleep_secs
        self.is_finished = False
        self.verbose = verbose
        
        f_back = sys._getframe().f_back
        file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
        self.print_prefix = f' ({file_desc}, line{f_back.f_lineno:-4d}) @daemon@ '
    
    def finishing(self):
        self.is_finished = True
    
    def run(self) -> None:
        kw = {}
        if tdist.is_initialized(): kw['clean'] = True  # 'clean' is presumably consumed by a print() patched elsewhere (e.g. in utils.dist); the builtin print would reject it
        
        stt = time.time()
        if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **kw)
        while not self.is_finished:
            for f in self.files:
                if os.path.exists(f):
                    try:
                        os.utime(f)
                        fp = open(f, 'a')
                        fp.close()
                    except Exception:
                        pass
            time.sleep(self.sleep_secs)
        if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-stt:.1f} secs {self.files} per {self.sleep_secs}s. ', **kw)
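# Usage sketch (hypothetical path): keep checkpoint files' mtime fresh so that
# cluster cleanup jobs evicting stale files do not remove them mid-training:
#   toucher = TouchingDaemonDontForgetToStartMe(['/ckpt/latest.pth'], sleep_secs=120, verbose=True)
#   toucher.start()       # as the class name warns: it does nothing until started
#   ...                   # training loop
#   toucher.finishing()   # the daemon exits after its current sleep
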
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
    def __init__(self, window_size=30, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
    
    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n
    
    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]
    
    # NOTE: the accessors below must be properties — __str__, time_preds and the
    # fmt string all read them without calling.
    @property
    def median(self):
        return np.median(self.deque) if len(self.deque) else 0
    
    @property
    def avg(self):
        return sum(self.deque) / (len(self.deque) or 1)
    
    @property
    def global_avg(self):
        return self.total / (self.count or 1)
    
    @property
    def max(self):
        return max(self.deque) if len(self.deque) else 0
    
    @property
    def value(self):
        return self.deque[-1] if len(self.deque) else 0
    
    def time_preds(self, counts) -> Tuple[float, str, str]:
        remain_secs = counts * self.median
        return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
    
    def __str__(self):
        return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
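# Worked example (hypothetical values): windowed vs. global statistics.
#   sv = SmoothedValue(window_size=2)
#   for v in (1.0, 2.0, 3.0): sv.update(v)
#   sv.avg         # (2.0 + 3.0) / 2 = 2.5        -- only the last window_size values
#   sv.global_avg  # (1.0 + 2.0 + 3.0) / 3 = 2.0  -- every value ever seen
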
class MetricLogger(object):
    def __init__(self):
        self.meters = defaultdict(SmoothedValue)
        self.iter_end_t = time.time()
        self.log_iters = set()
        self.log_every_iter = False
    
    def update(self, **kwargs):
        # if it != 0 and it not in self.log_iters: return
        for k, v in kwargs.items():
            if v is None: continue
            if hasattr(v, 'item'): v = v.item()
            # assert isinstance(v, (float, int)), type(v)
            self.meters[k].update(v)
    
    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
    
    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            if len(meter.deque):
                loss_str.append("{}: {}".format(name, str(meter)))
        return ' '.join(loss_str)
    
    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()
    
    def add_meter(self, name, meter):
        self.meters[name] = meter
    
    def log_every(self, start_it, max_iters, itrt, log_freq, log_every_iter=False, header=''):  # also handles logging & skipping iterations before start_it
        start_it = start_it % max_iters
        self.log_iters = set(range(start_it, max_iters, log_freq))
        self.log_iters.add(start_it)
        self.log_iters.add(max_iters-1)
        self.log_iters.add(max_iters)
        self.log_every_iter = log_every_iter
        self.iter_end_t = time.time()
        self.iter_time = SmoothedValue(fmt='{value:.4f}')
        self.data_time = SmoothedValue(fmt='{value:.3f}')
        header_fmt = header + ': [{0:' + str(len(str(max_iters))) + 'd}/{1}]'
        
        start_time = time.time()
        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            for it in range(start_it, max_iters):
                obj = next(itrt)
                if it < start_it: continue
                self.data_time.update(time.time() - self.iter_end_t)
                yield it, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if self.log_every_iter or it in self.log_iters:
                    eta_seconds = self.iter_time.avg * (max_iters - it)
                    print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
                self.iter_end_t = time.time()
        else:
            if isinstance(itrt, int): itrt = range(itrt)
            for it, obj in enumerate(itrt):
                if it < start_it:
                    self.iter_end_t = time.time()
                    continue
                self.data_time.update(time.time() - self.iter_end_t)
                yield it, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if self.log_every_iter or it in self.log_iters:
                    eta_seconds = self.iter_time.avg * (max_iters - it)
                    print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
                self.iter_end_t = time.time()
        
        cost = time.time() - start_time
        cost_str = str(datetime.timedelta(seconds=int(cost)))
        print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True)
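# Usage sketch (hypothetical training loop; train_step and loader are assumed):
#   me = MetricLogger()
#   for it, batch in me.log_every(0, max_iters=1000, itrt=loader, log_freq=50, header='[Ep 5]'):
#       loss = train_step(batch)
#       me.update(loss=loss)
#   # every 50 iterations prints e.g. '[Ep 5]: [  50/1000] eta: ... loss: 0.1234 (0.2345) ...'
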
class NullDDP(torch.nn.Module):
    # Drop-in stand-in for DistributedDataParallel in single-process runs:
    # exposes .module and .require_backward_grad_sync but adds no synchronization.
    def __init__(self, module, *args, **kwargs):
        super(NullDDP, self).__init__()
        self.module = module
        self.require_backward_grad_sync = False
    
    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def build_2d_sincos_position_embedding(h, w, embed_dim, temperature=10000., sc=0, verbose=True):  # returns (1, h*w, embed_dim)
    # DiT: sc=0
    # DETR: sc=2?
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing='ij')
    if sc == 0:
        scale = 1
    elif sc == 1:
        scale = math.pi * 2 / w
    else:
        scale = 1 / w
    grid_w = scale * grid_w.reshape(h*w, 1)  # scale * [0, 0, 0, 1, 1, 1, 2, 2, 2]
    grid_h = scale * grid_h.reshape(h*w, 1)  # scale * [0, 1, 2, 0, 1, 2, 0, 1, 2]
    
    assert embed_dim % 4 == 0, f'Embed dimension ({embed_dim}) must be divisible by 4 for 2D sin-cos position embedding!'
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = (-math.log(temperature) * omega).exp()
    # omega == (1/T) ** (arange(pos_dim) / pos_dim), a vector dependent only on the channel dim
    out_w = grid_w * omega.view(1, pos_dim)  # out_w: scale * [0*ome, 0*ome, 0*ome, 1*ome, 1*ome, 1*ome, 2*ome, 2*ome, 2*ome]
    out_h = grid_h * omega.view(1, pos_dim)  # out_h: scale * [0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome]
    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    if verbose: print(f'[build_2d_sincos_position_embedding @ {h} x {w}] scale_type={sc}, temperature={temperature:g}, shape={pos_emb.shape}')
    return pos_emb  # (1, h*w, embed_dim)
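# Math note: with omega_i = temperature^(-i/pos_dim) for i = 0..pos_dim-1, each
# axis position p contributes sin(scale*p*omega_i) and cos(scale*p*omega_i), and
# the four blocks [sin_w | cos_w | sin_h | cos_h] are concatenated channel-wise:
# the standard fixed sin-cos recipe of "Attention Is All You Need", applied per axis.
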
if __name__ == '__main__':
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    cmap_div = sns.color_palette('icefire', as_cmap=True)
    scs = [0, 1, 2]
    temps = [20, 50, 100, 1000]
    reso = 3.0
    RR, CC = len(scs), len(temps)
    plt.figure(figsize=(CC * reso, RR * reso))
    for row, sc in enumerate(scs):
        for col, temp in enumerate(temps):
            name = f'sc={sc}, T={temp}'
            hw, C = 16, 512
            N = hw*hw
            pe = build_2d_sincos_position_embedding(hw, hw, C, temperature=temp, sc=sc, verbose=False)[0]    # (N, C) = (hw*hw, C)
            hw2 = 16
            N2 = hw2*hw2
            pe2 = build_2d_sincos_position_embedding(hw2, hw2, C, temperature=temp, sc=sc, verbose=False)[0] # (N2, C) = (hw2*hw2, C)
            # pe2 = pe2.flip(dims=(0,))
            bchw, bchw2 = F.normalize(pe.view(hw, hw, C).permute(2, 0, 1).unsqueeze(0), dim=1), F.normalize(pe2.view(hw2, hw2, C).permute(2, 0, 1).unsqueeze(0), dim=1)
            dis = [
                f'{F.mse_loss(bchw, F.interpolate(bchw2, size=bchw.shape[-2], mode=inter)).item():.3f}'
                for inter in ('bilinear', 'bicubic', 'nearest')
            ]
            dis += [
                f'{F.mse_loss(F.interpolate(bchw, size=bchw2.shape[-2], mode=inter), bchw2).item():.3f}'
                for inter in ('area', 'nearest')
            ]
            print(f'[{name:^20s}] dis: {dis}')
            """
            [      sc=0, T=20      ] dis: ['0.010', '0.011', '0.011', '0.009', '0.010']
            [     sc=0, T=100      ] dis: ['0.007', '0.007', '0.007', '0.006', '0.007']
            [     sc=0, T=1000     ] dis: ['0.005', '0.005', '0.005', '0.004', '0.005']
            [    sc=0, T=10000     ] dis: ['0.004', '0.004', '0.004', '0.003', '0.004']
            [      sc=1, T=20      ] dis: ['0.007', '0.008', '0.008', '0.007', '0.008']
            [     sc=1, T=100      ] dis: ['0.005', '0.005', '0.005', '0.005', '0.005']
            [     sc=1, T=1000     ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
            [    sc=1, T=10000     ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
            [      sc=2, T=20      ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [     sc=2, T=100      ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [     sc=2, T=1000     ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [    sc=2, T=10000     ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            """
            pe = torch.from_numpy(cmap_div(pe.T.numpy())[:, :, :3])  # (C, N, 3)
            tar_h, tar_w = 1024, 1024
            pe = pe.repeat_interleave(tar_w//pe.shape[0], dim=0).repeat_interleave(tar_h//pe.shape[1], dim=1)
            
            plt.subplot(RR, CC, 1+row*CC+col)
            plt.title(name)
            plt.xlabel('hxw'), plt.ylabel('C')
            plt.xticks([]), plt.yticks([])
            plt.imshow(pe.mul(255).round().clamp(0, 255).byte().numpy())
    
    plt.tight_layout(h_pad=0.02)
    plt.show()

def check_randomness(args):
    U = 16384
    t = torch.zeros(dist.get_world_size(), 4, dtype=torch.float32, device=args.device)
    t0 = torch.zeros(1, dtype=torch.float32, device=args.device).random_(U)
    t[dist.get_rank(), 0] = float(random.randrange(U))
    t[dist.get_rank(), 1] = float(np.random.randint(U))
    t[dist.get_rank(), 2] = float(torch.randint(0, U, (1,))[0])
    t[dist.get_rank(), 3] = float(t0[0])
    dist.allreduce(t)
    for rk in range(1, dist.get_world_size()):
        assert torch.allclose(t[rk - 1], t[rk]), f't={t}'
    del t0, t, U
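# Usage sketch: call once right after distributed init to assert that the
# python / numpy / torch-CPU / torch-CUDA RNGs draw identical values on every
# rank (i.e. seeds are synchronized); raises AssertionError otherwise.
# `args.device` is assumed to hold the local CUDA device.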