Spaces:
Running
Running
| from tqdm import tqdm | |
| from dkm.utils.utils import to_cuda | |
def train_step(train_batch, model, objective, optimizer, **kwargs):
    """Run one optimization step: forward, loss, backward, parameter update.

    Args:
        train_batch: one batch of training data, already on the target device.
        model: callable that maps the batch to model output.
        objective: callable computing a scalar loss from (output, batch); the
            returned object must support ``.backward()`` and ``.item()``.
        optimizer: optimizer with ``zero_grad()`` and ``step()``.
        **kwargs: extra arguments (e.g. ``lr_scheduler``, ``n``) accepted for
            caller convenience but intentionally unused here.

    Returns:
        dict with ``"train_out"`` (raw model output) and ``"train_loss"``
        (the scalar loss value).
    """
    optimizer.zero_grad()
    out = model(train_batch)
    # Renamed from `l` to `loss`: single-letter `l` is ambiguous (PEP 8 / E741).
    loss = objective(out, train_batch)
    loss.backward()
    optimizer.step()
    return {"train_out": out, "train_loss": loss.item()}
def train_k_steps(
    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
):
    """Run ``k`` consecutive training steps, numbered ``n_0`` .. ``n_0 + k - 1``.

    Pulls one batch per step from ``dataloader`` (an iterator), moves it to
    the GPU, performs a single optimization step, and advances the learning
    rate scheduler once per step. Shows a tqdm progress bar unless
    ``progress_bar`` is False.
    """
    step_ids = range(n_0, n_0 + k)
    for step in tqdm(step_ids, disable=not progress_bar):
        raw_batch = next(dataloader)
        # Re-assert training mode every step in case evaluation toggled it.
        model.train(True)
        cuda_batch = to_cuda(raw_batch)
        train_step(
            train_batch=cuda_batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            n=step,
        )
        lr_scheduler.step()
def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
):
    """Train the model for one full pass over ``dataloader``.

    Puts the model in training mode, then for each batch: moves it to the
    GPU, runs one optimization step, and advances the learning rate
    scheduler. The progress bar refreshes at most every 5 seconds.

    Returns:
        dict echoing back ``model``, ``optimizer``, ``lr_scheduler`` and
        ``epoch`` for checkpointing by the caller.
    """
    model.train(True)
    print(f"At epoch {epoch}")
    batch_iter = tqdm(dataloader, mininterval=5.0)
    for raw_batch in batch_iter:
        cuda_batch = to_cuda(raw_batch)
        train_step(
            train_batch=cuda_batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
        )
        # NOTE(review): scheduler steps once per batch here (vs. once per
        # k steps in train_k_steps) — presumably intentional; confirm.
        lr_scheduler.step()
    return dict(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epoch=epoch,
    )
def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
):
    """Train for every epoch from ``start_epoch`` through ``end_epoch`` inclusive."""
    current = start_epoch
    while current <= end_epoch:
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            epoch=current,
        )
        current += 1