# makeavid-sd-jax / makeavid_sd / trainer_xla.py
# Provenance (from HuggingFace file-viewer scrape): author lopho,
# commit 149cc2d ("this isn't very nice."), file size 4.42 kB.
import os
os.environ['PJRT_DEVICE'] = 'TPU'
from tqdm.auto import tqdm
import torch
from torch.utils.data import DataLoader
from torch_xla.core import xla_model
from diffusers import UNetPseudo3DConditionModel
from dataset import load_dataset
class TempoTrainerXLA:
    """Fine-tune the temporal layers of a pseudo-3D UNet on a single XLA (TPU) device.

    By default all weights are frozen and only the temporal modules
    (``temporal_conv`` parameters, ``attn_temporal`` and ``norm_temporal``
    inside each transformer block) are trained, so the pretrained spatial
    image weights stay fixed.
    """

    def __init__(self,
            pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
            lr: float = 1e-4,
            dtype: torch.dtype = torch.float32,
    ) -> None:
        """Load the pretrained UNet, freeze all but the temporal layers and
        set up AdamW plus a linear LR warmup schedule.

        :param pretrained: model id or path passed to ``from_pretrained``
        :param lr: AdamW learning rate
        :param dtype: parameter dtype the model is cast to
        """
        self.dtype = dtype
        self.device: torch.device = xla_model.xla_device(0)
        unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
                pretrained,
                subfolder = 'unet'
        ).to(dtype = dtype, memory_format = torch.contiguous_format)
        unfreeze_all: bool = False
        unet = unet.train()
        if not unfreeze_all:
            # Freeze everything, then selectively re-enable the temporal modules.
            unet.requires_grad_(False)
            for name, param in unet.named_parameters():
                if 'temporal_conv' in name:
                    param.requires_grad_(True)
            for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
                if hasattr(block, 'attentions') and block.attentions is not None:
                    for attn_block in block.attentions:
                        for transformer_block in attn_block.transformer_blocks:
                            transformer_block.requires_grad_(False)
                            transformer_block.attn_temporal.requires_grad_(True)
                            transformer_block.norm_temporal.requires_grad_(True)
        else:
            unet.requires_grad_(True)
        self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
        #self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
        # Re-evaluated lazily so the optimizer only ever sees trainable params.
        self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
        self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
        def lr_warmup(warmup_steps: int = 0):
            # Linear warmup 0 -> 1 over `warmup_steps` steps, then constant 1.0.
            def lambda_lr(step: int) -> float:
                # Guard warmup_steps == 0: the original divided unconditionally,
                # which would raise ZeroDivisionError for the default argument.
                if warmup_steps > 0 and step < warmup_steps:
                    return step / warmup_steps
                return 1.0
            return lambda_lr
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)

    @torch.no_grad()
    def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
        """Run the training loop.

        Each batch dict must provide the tensors 'latent_model_input',
        'encoder_hidden_states', 'timesteps' and 'noise'.

        :param dataloader: yields batch dicts with the keys above
        :param epochs: number of passes over the dataloader
        :param log_every: update the progress-bar postfix every N steps
        :param save_every: NOTE(review): accepted but unused — no
            checkpointing is implemented anywhere in this class.
        """
        global_step: int = 0
        for epoch in range(epochs):
            pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
            for b in pbar:
                # Host -> device transfers happen under the outer no_grad;
                # autograd is re-enabled only for forward/backward/step.
                latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
                encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
                timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
                noise: torch.Tensor = b['noise'].to(device = self.device)
                with torch.enable_grad():
                    self.optim.zero_grad(set_to_none = True)
                    y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
                    loss = torch.nn.functional.mse_loss(noise, y)
                    loss.backward()
                    self.optim.step()
                    self.scheduler.step()
                # Materialize the lazy XLA graph once per step.
                xla_model.mark_step()
                if global_step % log_every == 0:
                    pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
                # BUG FIX: the original never incremented global_step, so the
                # modulo test above was always true and `log_every` had no effect.
                global_step += 1
def main() -> None:
    """Script entry point: construct the trainer and dataloader, then train."""
    checkpoint = 'lxj616/make-a-stable-diffusion-video-timelapse'
    data_root = './storage/dataset/tempofunk'
    train_dtype = torch.bfloat16
    trainer = TempoTrainerXLA(
            pretrained = checkpoint,
            lr = 1e-5,
            dtype = train_dtype
    )
    loader: DataLoader = load_dataset(
            dataset_path = data_root,
            pretrained = checkpoint,
            batch_size = 1,
            num_frames = 10,
            num_workers = 1,
            dtype = train_dtype
    )
    trainer.train(
            dataloader = loader,
            epochs = 1000,
            log_every = 1,
            save_every = 1000
    )
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()