Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import sys | |
| import tqdm | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from .utils import apply_model, average_metric, center_trim | |
def train_model(epoch,
                dataset,
                model,
                criterion,
                optimizer,
                augment,
                quantizer=None,
                diffq=0,
                repeat=1,
                device="cpu",
                seed=None,
                workers=4,
                world_size=1,
                batch_size=16):
    """Train `model` for one epoch, i.e. `repeat` full passes over `dataset`.

    Args:
        epoch: epoch number; shown in the progress bar and used to derive the
            starting epoch of the distributed sampler.
        dataset: dataset handed to a `DataLoader`; each batch is moved to
            `device` and passed through `augment`.
        model: module called on the mixture (`sources.sum(dim=1)`); its
            parameters are updated by `optimizer`.
        criterion: callable `criterion(estimates, sources)` returning the loss.
        optimizer: optimizer stepped and zeroed once per processed batch.
        augment: callable applied to every batch before mixing.
        quantizer: optional object exposing `model_size()`; when given, the
            returned value is added to the training loss, weighted by `diffq`.
        diffq: weight of the model size penalty in the training loss.
        repeat: number of passes over the loader.
        device: device on which batches are placed.
        seed: optional seed; when set, `seed * 1000` is added to the starting
            epoch of the distributed sampler.
        workers: number of `DataLoader` workers.
        world_size: number of distributed processes. When greater than 1, a
            `DistributedSampler` is used and `batch_size` is divided by it.
        batch_size: batch size (split across processes when distributed).

    Returns:
        A `(current_loss, model_size)` tuple: the running mean loss of the last
        pass (reduced with `average_metric` when distributed) and the last
        model size measured (0 when there is no quantizer).
    """
    distributed = world_size > 1
    if distributed:
        sampler = DistributedSampler(dataset)
        seed_offset = 0 if seed is None else seed * 1000
        sampler.set_epoch(epoch * repeat + seed_offset)
        # Each process only handles its share of the global batch.
        batch_size //= world_size
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=workers)
    else:
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)

    current_loss = 0
    model_size = 0
    for repetition in range(repeat):
        progress = tqdm.tqdm(loader,
                             ncols=120,
                             desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
                             leave=False,
                             file=sys.stdout,
                             unit=" batch")
        total_loss = 0
        for idx, sources in enumerate(progress):
            if len(sources) < batch_size:
                # Incomplete batches are skipped so that augment.Remix works properly.
                continue

            sources = augment(sources.to(device))
            mix = sources.sum(dim=1)

            estimates = model(mix)
            # Align the targets with the (possibly shorter) model output.
            sources = center_trim(sources, estimates)
            loss = criterion(estimates, sources)

            model_size = quantizer.model_size() if quantizer is not None else 0
            train_loss = loss + diffq * model_size
            train_loss.backward()

            # Global L2 norm of the gradients, for display only.
            squared_norm = sum(param.grad.data.norm()**2
                               for param in model.parameters()
                               if param.grad is not None)
            grad_norm = squared_norm**0.5

            optimizer.step()
            optimizer.zero_grad()

            if quantizer is not None:
                # Convert to a plain Python number once backward is done.
                model_size = model_size.item()

            total_loss += loss.item()
            # The denominator is the position in the loader, not the count of processed batches.
            current_loss = total_loss / (1 + idx)
            progress.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
                                 grad=f"{grad_norm:.5f}")

            # Drop references to the large tensors before fetching the next batch.
            del sources, mix, estimates, loss, train_loss

        if distributed:
            # Advance the sampler epoch so the next pass uses a new shuffle.
            sampler.epoch += 1

    if distributed:
        current_loss = average_metric(current_loss)
    return current_loss, model_size
def validate_model(epoch,
                   dataset,
                   model,
                   criterion,
                   device="cpu",
                   rank=0,
                   world_size=1,
                   shifts=0,
                   overlap=0.25,
                   split=False):
    """Compute the mean validation loss of `model` over this process' share of `dataset`.

    Args:
        epoch: epoch number, only used in the progress bar description.
        dataset: indexable dataset; for each item, entry 0 along the first
            axis is used as the mixture and the remaining entries as targets.
        model: model forwarded to `apply_model`.
        criterion: callable `criterion(estimates, sources)` returning the loss.
        device: device on which each track is placed.
        rank: index of the current process; tracks `rank, rank + world_size, ...`
            are evaluated.
        world_size: number of distributed processes.
        shifts: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        split: forwarded to `apply_model`.

    Returns:
        The loss averaged over the evaluated tracks, combined across processes
        with `average_metric` when `world_size > 1`.
    """
    indexes = range(rank, len(dataset), world_size)
    track_count = len(indexes)
    progress = tqdm.tqdm(indexes,
                         ncols=120,
                         desc=f"[{epoch:03d}] valid",
                         leave=False,
                         file=sys.stdout,
                         unit=" track")

    current_loss = 0
    for index in progress:
        # Keep at most the first 15_000_000 samples of the last axis; per the
        # original author's note this is about five minutes and avoids OOM on
        # --upsample models.
        streams = dataset[index][..., :15_000_000].to(device)
        mix, sources = streams[0], streams[1:]

        estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
        loss = criterion(estimates, sources)
        # Accumulate the mean directly: every track weighs 1 / track_count.
        current_loss += loss.item() / track_count
        del estimates, streams, sources

    if world_size > 1:
        current_loss = average_metric(current_loss, track_count)
    return current_loss