# MMP_Diffusion/MMP_Diffusion_Lora_train.py
import io
import logging
import math
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import shutil
import sys
sys.path.append('./')
from pathlib import Path
import accelerate
import datasets
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.utils import ContextManagers
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel
from models.unet_2d_condition import UNet2DLoRAConditionModel
from models.lora import add_lora_to_model
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
from MMP_Diffusion_Lora_config import parse_args, import_model_class_from_model_name_or_path
from peft.utils import get_peft_model_state_dict
from diffusers.utils import convert_state_dict_to_diffusers
from models.visual_prompts import EmotionEmbedding, EmotionEmbedding2
import copy
if is_wandb_available():
import wandb
## SDXL
import functools
import gc
from torchvision.transforms.functional import crop
from transformers import AutoTokenizer
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.20.0")
logger = get_logger(__name__, log_level="INFO")
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt_sdxl(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
prompt_embeds_list = []
prompt_batch = batch[caption_column]
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_input_ids = tokenizer(
captions,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids
with torch.no_grad():
prompt_embeds = text_encoder(
text_input_ids.to('cuda'),
output_hidden_states=True,
)
        # We are only interested in the pooled output of the final text encoder
        # (CLIPTextModelWithProjection). Its output exposes
        # ['text_embeds', 'last_hidden_state', 'hidden_states']; text_embeds has shape [bs, 1280].
if isinstance(text_encoder, CLIPTextModel):
pass
elif isinstance(text_encoder, CLIPTextModelWithProjection):
pooled_prompt_embeds = prompt_embeds[0]
# "2" because SDXL always indexes from the penultimate layer.
# torch.Size([32, 77, 768/1280])
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
# torch.Size([32, 77, 768+1280=2048])
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# torch.Size([32, 1280])
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return {
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
}
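# For SDXL-base, the returned "prompt_embeds" are [bs, 77, 2048] (the 768- and 1280-dim hidden
# states of the two text encoders concatenated) and "pooled_prompt_embeds" are [bs, 1280] from the
# projection encoder, matching the shape notes in the comments above.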
def init_emotion_prompts(visual_prompts_dir, is_sdxl=True, prompt_len=16):
emotions = ["amusement", "anger", "awe", "contentment",
"disgust", "excitement", "fear", "sadness"]
if is_sdxl:
output_dim = 2048
else:
output_dim = 768
feature_names = ["clip", "vgg", "dinov2"]
visual_prompts = EmotionEmbedding(emotions, visual_prompts_dir,
feature_names, output_dim=output_dim, prompt_len=prompt_len)
return visual_prompts
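# Note: EmotionEmbedding / EmotionEmbedding2 live in models/visual_prompts.py, which is not shown
# here. Based on how they are constructed and called below, they are assumed to map the
# pre-extracted clip/vgg/dinov2 emotion features (or a 2048-d input) to learnable visual-prompt
# tokens of width output_dim (2048 for SDXL, 768 for SD1.5) and to accept a list of emotion names
# in their forward pass.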
def init_emotion_prompts2(is_sdxl=True):
emotions = ["amusement", "anger", "awe", "contentment",
"disgust", "excitement", "fear", "sadness"]
if is_sdxl:
output_dim = 2048
else:
output_dim = 768
input_dim = 2048
visual_prompts = EmotionEmbedding2(emotions, input_dim, output_dim=output_dim)
return visual_prompts
def random_sample_emotions(anchor_emotions):
emotions = ["amusement", "anger", "awe", "contentment", "disgust",
"excitement", "fear", "sadness"]
random_emotions = []
for anchor in anchor_emotions:
available_emotions = [emotion for emotion in emotions if emotion != anchor]
random_choice = random.choice(available_emotions)
random_emotions.append(random_choice)
return random_emotions
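# Illustrative example (outputs are random, so they vary):
#   random_sample_emotions(["fear", "awe"]) -> e.g. ["amusement", "contentment"]
# Each returned emotion is drawn uniformly from the seven emotions other than its anchor.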
def main():
args = parse_args()
#### START ACCELERATOR BOILERPLATE ###
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed + accelerator.process_index) # added in + term, untested
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
### END ACCELERATOR BOILERPLATE
### START DIFFUSION BOILERPLATE ###
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path,
subfolder="scheduler")
# SDXL has two text encoders
if args.sdxl:
tokenizer_and_encoder_name = args.pretrained_model_name_or_path
tokenizer_one = AutoTokenizer.from_pretrained(tokenizer_and_encoder_name, subfolder="tokenizer", revision=args.revision, use_fast=False)
tokenizer_two = AutoTokenizer.from_pretrained(tokenizer_and_encoder_name, subfolder="tokenizer_2", revision=args.revision, use_fast=False)
else:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision)
# Not sure if we're hitting this at all
def deepspeed_zero_init_disabled_context_manager():
"""
returns either a context list that includes one that will disable zero.Init or an empty context list
"""
deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
if deepspeed_plugin is None:
return []
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
# SDXL has two text encoders
if args.sdxl:
# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(tokenizer_and_encoder_name, args.revision, subfolder="text_encoder")
text_encoder_cls_two = import_model_class_from_model_name_or_path(tokenizer_and_encoder_name, args.revision, subfolder="text_encoder_2")
text_encoder_one = text_encoder_cls_one.from_pretrained(tokenizer_and_encoder_name, revision=args.revision, subfolder="text_encoder")
text_encoder_two = text_encoder_cls_two.from_pretrained(tokenizer_and_encoder_name, revision=args.revision, subfolder="text_encoder_2")
text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]
else:
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
# Can custom-select VAE (used in original SDXL tuning)
vae_path = (
args.pretrained_model_name_or_path
if args.pretrained_vae_model_name_or_path is None
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
# clone of model
ref_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet", revision=args.revision
)
unet = UNet2DLoRAConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
print("======== init_emotion_prompts ================")
visual_prompts = init_emotion_prompts(args.visual_prompts_dir, is_sdxl=args.sdxl, prompt_len=args.prompt_len).to(accelerator.device)
# visual_prompts = init_emotion_prompts2(is_sdxl=args.sdxl).to(accelerator.device)
print("======== init_emotion_prompts done ================")
# Freeze vae, text_encoder(s), reference unet
vae.requires_grad_(False)
if args.sdxl:
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
else:
text_encoder.requires_grad_(False)
if args.train_method == 'dpo':
ref_unet.requires_grad_(False)
# if args.use_lora:
# unet.requires_grad_(False)
# args.lora_rank default 32
lora_p, negation = add_lora_to_model(unet, dropout=0.1, lora_rank=args.lora_rank, scale=1.0)
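    # add_lora_to_model comes from models/lora.py (not shown here); it is assumed to inject LoRA
    # adapters of rank args.lora_rank into the UNet and to return the newly added LoRA parameters
    # (lora_p) plus a second value (negation) that is unused below. Parameters whose names contain
    # "lora" get their own learning rate (args.learning_rate_lora) in the optimizer setup.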
# xformers efficient attention
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# BRAM NOTE: We're using >=0.16.0. Below was a bit of a bug hive. I hacked around it, but ideally ref_unet wouldn't
# be getting passed here
#
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
print("save_model_hook")
for i in range(len(models)):
print(models[i].__class__.__name__)
if len(models) > 1:
                assert args.train_method == 'dpo'  # 2nd prepared model is the EmotionEmbedding (visual prompts)
if args.sdxl:
# UNet2DLoRAConditionModel
models[0].save_pretrained(os.path.join(output_dir, 'unet_with_lora'))
weights.pop()
# EmotionEmbedding
torch.save(models[1].state_dict(), os.path.join(output_dir, "EmotionEmbedding.pth"))
weights.pop()
def load_model_hook(models, input_dir):
print("load_model_hook")
for i in range(len(models)):
print(models[i].__class__.__name__)
if len(models) > 1:
                assert args.train_method == 'dpo'  # 2nd prepared model is the EmotionEmbedding (visual prompts)
if args.sdxl:
# UNet2DLoRAConditionModel
model = models.pop(0)
from safetensors.torch import load_file
                    # load the two sharded safetensors files
state_dict_1 = load_file(os.path.join(input_dir, 'unet_with_lora', 'diffusion_pytorch_model-00001-of-00002.safetensors'))
state_dict_2 = load_file(os.path.join(input_dir, 'unet_with_lora', 'diffusion_pytorch_model-00002-of-00002.safetensors'))
                    # merge the two state dicts
state_dict = {**state_dict_1, **state_dict_2}
model.load_state_dict(state_dict)
# EmotionEmbedding
model = models.pop(0)
state_dict = torch.load(os.path.join(input_dir, "EmotionEmbedding.pth"), weights_only=True)
model.load_state_dict(state_dict)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.gradient_checkpointing or args.sdxl: # (args.sdxl and ('turbo' not in args.pretrained_model_name_or_path) ):
print("Enabling gradient checkpointing, either because you asked for this or because you're using SDXL")
unet.enable_gradient_checkpointing()
# Bram Note: haven't touched
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
unet_params = []
lora_params = []
for name, param in unet.named_parameters():
if 'lora' in name.lower():
lora_params.append(param)
else:
if param.requires_grad:
unet_params.append(param)
# if args.use_adafactor or args.sdxl:
print("Using Adafactor either because you asked for it or you're using SDXL")
param_groups = [
{
"params": unet_params,
"lr": args.learning_rate_unet,
},
{
"params": lora_params,
"lr": args.learning_rate_lora,
},
{
"params": visual_prompts.parameters(),
"lr": args.learning_rate_prompts,
}
]
optimizer = transformers.Adafactor(
param_groups,
weight_decay=args.adam_weight_decay,
clip_threshold=1.0,
scale_parameter=False,
relative_step=False
)
# else:
# optimizer = torch.optim.AdamW([
# {"params": unet_params, "lr": args.learning_rate,
# "beta": (args.adam_beta1, args.adam_beta2), "weight_decay": args.adam_weight_decay,
# "eps": args.adam_epsilon},
# {"params": lora_params, "lr": args.learning_rate*5,
# "beta": (args.adam_beta1, args.adam_beta2), "weight_decay": args.adam_weight_decay,
# "eps": args.adam_epsilon},
# {"params": visual_prompts.parameters(), "lr": args.learning_rate*5,
# "beta": (args.adam_beta1, args.adam_beta2), "weight_decay": args.adam_weight_decay,
# "eps": args.adam_epsilon}
# ])
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
dataset = load_dataset(path='parquet', data_files=args.dataset_path)
caption_column = args.caption_column
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if random.random() < args.proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
#### START PREPROCESSING/COLLATION ####
if args.train_method == 'dpo':
print("Ignoring image_column variable, reading from jpg_0 and jpg_1")
def preprocess_train(examples):
all_pixel_values = []
for col_name in ['jpg_0', 'jpg_1']:
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB")
for im_bytes in examples[col_name]]
pixel_values = [train_transforms(image) for image in images]
all_pixel_values.append(pixel_values)
# DOUBLE win images for visual prompts optimization
# all_pixel_values
# [[jpg_0,...],[jpg_1,...]]
# => [[jpg_0,...],[jpg_1,...],[jpg_0,...]]
all_pixel_values.append(copy.deepcopy(all_pixel_values[0]))
            # Triple on the channel dim: jpg_w, then jpg_l, then jpg_w again
            # im_tup_iterator yields tuples like (jpg_0, jpg_1, jpg_0)
im_tup_iterator = zip(*all_pixel_values)
combined_pixel_values = []
            # item = (preferred, dispreferred, preferred-again), label_0
            for im_tup, label_0 in zip(im_tup_iterator, examples['label_0']):
                # each element: torch.Size([3, H, W])
                if label_0 == 0 and (not args.choice_model):  # don't flip here if a choice_model supplies AI feedback
                    # label_0 == 0 means jpg_1 is the preferred image, so reorder to (win, lose, win).
                    # A plain im_tup[::-1] would leave the palindromic (jpg_0, jpg_1, jpg_0) tuple unchanged.
                    im_tup = (im_tup[1], im_tup[0], im_tup[1])
# [3+3+3, 512, 512]
combined_im = torch.cat(im_tup, dim=0) # no batch dim
combined_pixel_values.append(combined_im)
# [[9, 512, 512],...]
examples["pixel_values"] = combined_pixel_values
# SDXL takes raw prompts
if not args.sdxl:
examples["input_ids"] = tokenize_captions(examples)
return examples
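        # Resulting "pixel_values" layout per example (9 channels stacked along dim 0):
        #   channels 0-2: preferred image (y_w)
        #   channels 3-5: dispreferred image (y_l)
        #   channels 6-8: y_w again, later paired with a randomly sampled wrong emotion prompt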
def collate_fn(examples):
# [bs, 9, 512, 512]
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
return_d = {"pixel_values": pixel_values}
return_d["emotions"] = [example["emotion"] for example in examples]
# SDXL takes raw prompts
if args.sdxl:
return_d["caption"] = [example["caption"] for example in examples]
else:
return_d["input_ids"] = torch.stack([example["input_ids"] for example in examples])
if args.choice_model:
                # If using AI feedback (AIF), also pass the raw images so the choice model can decide whether to flip the pixel values
for k in ['jpg_0', 'jpg_1']:
return_d[k] = [Image.open(io.BytesIO( example[k])).convert("RGB")
for example in examples]
return_d["caption"] = [example["caption"] for example in examples]
return return_d
### DATASET #####
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset[args.split] = dataset[args.split].shuffle(seed=args.seed).select(range(args.max_train_samples))
train_dataset = dataset[args.split].with_transform(preprocess_train)
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=(args.split=='train'),
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
drop_last=True
)
##### END BIG OLD DATASET BLOCK #####
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
#### START ACCELERATOR PREP ####
unet, visual_prompts, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, visual_prompts, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
    # Move text encoder(s) and vae to GPU and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
if args.sdxl:
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
# print("offload vae (this actually stays as CPU)")
# vae = accelerate.cpu_offload(vae)
# print("Offloading text encoders to cpu")
text_encoder_one = accelerate.cpu_offload(text_encoder_one)
text_encoder_two = accelerate.cpu_offload(text_encoder_two)
if args.train_method == 'dpo':
ref_unet.to(accelerator.device, dtype=weight_dtype)
# print("offload ref_unet")
# ref_unet = accelerate.cpu_offload(ref_unet)
else:
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.train_method == 'dpo':
ref_unet.to(accelerator.device, dtype=weight_dtype)
### END ACCELERATOR PREP ###
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
init_kwargs = {
"wandb": {
"name": args.tracker_run_name,
}
}
accelerator.init_trackers(args.tracker_project_name, tracker_config, init_kwargs)
# Training initialization
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Bram Note: This was pretty janky to wrangle to look proper but works to my liking now
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
#### START MAIN TRAINING LOOP #####
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
implicit_acc_accumulated_d, implicit_acc_accumulated_c = 0.0, 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step and (not args.hard_skip_resume):
if step % args.gradient_accumulation_steps == 0:
print(f"Dummy processing step {step}, will start training at {resume_step}")
continue
with accelerator.accumulate(unet):
# Convert images to latent space
if args.train_method == 'dpo':
                    # [bs, 9, 512, 512] =>
                    # [[bs, 3, 512, 512] * 3] =>
                    # [bs*3, 3, 512, 512]
feed_pixel_values = torch.cat(batch["pixel_values"].chunk(3, dim=1))
elif args.train_method == 'sft':
feed_pixel_values = batch["pixel_values"]
#### Diffusion Stuff ####
# encode pixels --> latents
with torch.no_grad():
latents = vae.encode(feed_pixel_values.to(weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
if args.train_method == 'dpo':
# make timesteps and noise same for pairs in DPO
# [bs] => [1/3bs, 1/3bs, 1/3bs] => [1/3bs] => [bs]
timesteps = timesteps.chunk(3)[0].repeat(3)
noise = noise.chunk(3)[0].repeat(3, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
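                # DDPM forward process: noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise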
### START PREP BATCH ###
if args.sdxl:
# Get the text embedding for conditioning
with torch.no_grad():
# Need to compute "time_ids" https://github.com/huggingface/diffusers/blob/v0.20.0-release/examples/text_to_image/train_text_to_image_sdxl.py#L969
# for SDXL-base these are torch.tensor([args.resolution, args.resolution, *crop_coords_top_left, *target_size))
add_time_ids = torch.tensor([args.resolution,
args.resolution,
0,
0,
args.resolution,
args.resolution],
dtype=weight_dtype,
device=accelerator.device)[None, :].repeat(timesteps.size(0), 1)
prompt_batch = encode_prompt_sdxl(batch,
text_encoders,
tokenizers,
args.proportion_empty_prompts,
caption_column,
is_train=True,
)
if args.train_method == 'dpo':
prompt_batch["prompt_embeds"] = prompt_batch["prompt_embeds"].repeat(3, 1, 1)
prompt_batch["pooled_prompt_embeds"] = prompt_batch["pooled_prompt_embeds"].repeat(3, 1)
unet_added_conditions = {"time_ids": add_time_ids,
"text_embeds": prompt_batch["pooled_prompt_embeds"]}
else: # sd1.5
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                    if args.train_method == 'dpo':
                        # repeat to match the tripled latents (y_w, y_l, y_w with a random emotion prompt)
                        encoder_hidden_states = encoder_hidden_states.repeat(3, 1, 1)
emotion_visual_prompts = visual_prompts(batch['emotions'])
if args.train_method == 'dpo':
random_emotions = random_sample_emotions(batch['emotions'])
random_emotion_visual_prompts = visual_prompts(random_emotions)
emotion_visual_prompts = torch.cat([emotion_visual_prompts, emotion_visual_prompts, random_emotion_visual_prompts], dim=0)
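                    # The emotion prompts now line up with the tripled latents:
                    # [correct emotion with y_w, correct emotion with y_l, random wrong emotion with the extra y_w copy]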
#### END PREP BATCH ####
assert noise_scheduler.config.prediction_type == "epsilon"
target = noise
# Make the prediction from the model we're learning
model_batch_args = (
noisy_latents,
timesteps,
prompt_batch["prompt_embeds"] if args.sdxl else encoder_hidden_states
)
lora_model_batch_args = (
noisy_latents,
timesteps,
prompt_batch["prompt_embeds"] if args.sdxl else encoder_hidden_states,
emotion_visual_prompts
)
added_cond_kwargs = unet_added_conditions if args.sdxl else None
model_pred = unet(
*lora_model_batch_args,
added_cond_kwargs = added_cond_kwargs
).sample
#### START LOSS COMPUTATION ####
if args.train_method == 'sft': # SFT, casting for F.mse_loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.train_method == 'dpo':
                    # model_pred and ref_pred are (3 * LBS) x 4 x latent_spatial_dim x latent_spatial_dim
                    # per-sample losses are of length 3 * LBS
                    # first third of the batch is the preferred sample (y_w), second third the dispreferred (y_l),
                    # final third is y_w again, conditioned on a randomly sampled wrong emotion prompt
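                    # Sketch of what the code below computes, with per-sample denoising error e(x) = ||model_pred - target||^2:
                    #   inside_term_d = -0.5 * beta_dpo * [(e_w - e_l) - (e_w_ref - e_l_ref)]   (standard Diffusion-DPO term)
                    #   inside_term_c = -0.5 * beta_dpo * (e_w - e_w_wrong_emotion)             (emotion-contrastive term, no reference model)
                    #   loss = -0.5 * (logsigmoid(inside_term_d).mean() + logsigmoid(inside_term_c).mean())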
model_losses = (model_pred - target).pow(2).mean(dim=[1,2,3])
model_losses_w, model_losses_l_d, model_losses_l_c = model_losses.chunk(3)
# below for logging purposes
raw_model_loss = (model_losses_w.mean() + model_losses_l_d.mean() + model_losses_l_c.mean()) / 3
model_diff_d = model_losses_w - model_losses_l_d # These are both LBS (as is t)
model_diff_c = model_losses_w - model_losses_l_c
with torch.no_grad(): # Get the reference policy (unet) prediction
ref_pred = ref_unet(
*model_batch_args,
added_cond_kwargs = added_cond_kwargs
).sample.detach()
ref_losses = (ref_pred - target).pow(2).mean(dim=[1,2,3])
ref_losses_w, ref_losses_l_d, ref_losses_l_c = ref_losses.chunk(3)
ref_diff = ref_losses_w - ref_losses_l_d
raw_ref_loss = ref_losses.mean()
scale_term = -0.5 * args.beta_dpo # beta_dpo = 5000
inside_term_d = scale_term * (model_diff_d - ref_diff)
implicit_acc_d = (inside_term_d > 0).sum().float() / inside_term_d.size(0)
                    # the scale_term may need to be adjusted
# inside_term_c = -1 * model_diff_c
inside_term_c = scale_term * model_diff_c
implicit_acc_c = (inside_term_c > 0).sum().float() / inside_term_c.size(0)
loss = -1 * 0.5 * (F.logsigmoid(inside_term_d).mean() + F.logsigmoid(inside_term_c).mean())
#### END LOSS COMPUTATION ###
# Gather the losses across all processes for logging
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Also gather:
# - model MSE vs reference MSE (useful to observe divergent behavior)
# - Implicit accuracy
if args.train_method == 'dpo':
avg_model_mse = accelerator.gather(raw_model_loss.repeat(args.train_batch_size)).mean().item()
avg_ref_mse = accelerator.gather(raw_ref_loss.repeat(args.train_batch_size)).mean().item()
avg_acc_d = accelerator.gather(implicit_acc_d).mean().item()
avg_acc_c = accelerator.gather(implicit_acc_c).mean().item()
implicit_acc_accumulated_d += avg_acc_d / args.gradient_accumulation_steps
implicit_acc_accumulated_c += avg_acc_c / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
                    if not args.use_adafactor:  # Adafactor clips gradients itself; could also be done here to reduce code
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                    # # Print gradients to check which parameters actually receive them
# for name, param in unet.named_parameters():
# # if "mid_block.attentions.0.transformer_blocks" in name and "lora" in name:
# if param.grad is not None:
# print(f"{name} has gradient ✅, grad mean: {param.grad.mean().item()}")
# else:
# print(f"{name} has NO gradient ❌")
# for name, param in visual_prompts.named_parameters():
# if param.grad is not None:
# print(f"{name} has gradient ✅, grad mean: {param.grad.mean().item()}")
# else:
# print(f"{name} has NO gradient ❌")
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has just performed an optimization step, if so do "end of batch" logging
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
if args.train_method == 'dpo':
accelerator.log({"model_mse_unaccumulated": avg_model_mse}, step=global_step)
accelerator.log({"ref_mse_unaccumulated": avg_ref_mse}, step=global_step)
accelerator.log({"avg_acc_d": implicit_acc_accumulated_d}, step=global_step)
accelerator.log({"avg_acc_c": implicit_acc_accumulated_c}, step=global_step)
train_loss = 0.0
implicit_acc_accumulated_d, implicit_acc_accumulated_c = 0.0, 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logger.info("Pretty sure saving/loading is fixed but proceed cautiously")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
if args.train_method == 'dpo':
logs["implicit_acc_d"] = avg_acc_d
logs["implicit_acc_c"] = avg_acc_c
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
# Create the pipeline using the trained modules and save it.
# This will save to top level of output_dir instead of a checkpoint directory
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.sdxl:
# Serialize pipeline.
if args.use_lora:
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unet)
)
StableDiffusionXLPipeline.save_lora_weights(
save_directory=os.path.join(args.output_dir, 'lora_weights_64'),
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info("Saved LoRA Model to {}".format(os.path.join(args.output_dir, 'lora_weights_64')))
else:
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.save_pretrained(args.output_dir)
logger.info("Saved Model to {}".format(args.output_dir))
else:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
revision=args.revision,
)
if not args.use_lora: pipeline.save_pretrained(args.output_dir)
accelerator.end_training()
if __name__ == "__main__":
main()