import os import math import torch import numpy as np import matplotlib.pyplot as plt from torch.utils.data import DataLoader, Sampler from collections import defaultdict from torch.optim.lr_scheduler import LambdaLR from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler from accelerate import Accelerator from datasets import load_from_disk from tqdm import tqdm from PIL import Image,ImageOps import wandb import random import gc from accelerate.state import DistributedType from torch.distributed import broadcast_object_list from torch.utils.checkpoint import checkpoint from diffusers.models.attention_processor import AttnProcessor2_0 from datetime import datetime import bitsandbytes as bnb # --------------------------- Параметры --------------------------- save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist" batch_size = 50 #30 #26 #45 #11 #45 #555 #35 #7 base_learning_rate = 2.5e-5 #4e-6 #2e-5 #4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5 min_learning_rate = 2.5e-5 #2e-5 num_epochs = 1 #2 #36 #18 project = "sdxxs" use_wandb = True save_model = True limit = 0 #200000 #0 checkpoints_folder = "" # Параметры для диффузии n_diffusion_steps = 40 samples_to_generate = 12 guidance_scale = 5 sample_interval_share = 20 # samples/save per epoch # Папки для сохранения результатов generated_folder = "samples" os.makedirs(generated_folder, exist_ok=True) # Настройка seed для воспроизводимости current_date = datetime.now() seed = int(current_date.strftime("%Y%m%d")) fixed_seed = True if fixed_seed: torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # --------------------------- Параметры LoRA --------------------------- # pip install peft lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель) lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб print("init") # Включение Flash Attention 2/SDPA torch.backends.cuda.enable_flash_sdp(True) # --------------------------- Инициализация Accelerator -------------------- dtype = torch.bfloat16 accelerator = Accelerator(mixed_precision="bf16") device = accelerator.device gen = torch.Generator(device=device) gen.manual_seed(seed) # --------------------------- Инициализация WandB --------------------------- if use_wandb and accelerator.is_main_process: wandb.init(project=project+lora_name, config={ "batch_size": batch_size, "base_learning_rate": base_learning_rate, "num_epochs": num_epochs, "n_diffusion_steps": n_diffusion_steps, "samples_to_generate": samples_to_generate, "dtype": str(dtype) }) # --------------------------- Загрузка датасета --------------------------- class ResolutionBatchSampler(Sampler): """Сэмплер, который группирует примеры по одинаковым размерам""" def __init__(self, dataset, batch_size, shuffle=True, drop_last=False): self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last # Группируем примеры по размерам self.size_groups = defaultdict(list) try: widths = dataset["width"] heights = dataset["height"] except KeyError: widths = [0] * len(dataset) heights = [0] * len(dataset) for i, (w, h) in enumerate(zip(widths, heights)): size = (w, h) self.size_groups[size].append(i) # Печатаем статистику по размерам print(f"Найдено {len(self.size_groups)} уникальных размеров:") for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True): width, height = size print(f" {width}x{height}: {len(indices)} примеров") # Формируем батчи self.reset() def reset(self): """Сбрасывает и перемешивает индексы""" self.batches = [] for size, indices in self.size_groups.items(): if self.shuffle: indices_copy = indices.copy() random.shuffle(indices_copy) else: indices_copy = indices # Разбиваем на батчи for i in range(0, len(indices_copy), self.batch_size): batch_indices = indices_copy[i:i + self.batch_size] # Пропускаем неполные батчи если drop_last=True if self.drop_last and len(batch_indices) < self.batch_size: continue self.batches.append(batch_indices) # Перемешиваем батчи между собой if self.shuffle: random.shuffle(self.batches) def __iter__(self): self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи return iter(self.batches) def __len__(self): return len(self.batches) # Функция для выборки фиксированных семплов по размерам def get_fixed_samples_by_resolution(dataset, samples_per_group=1): """Выбирает фиксированные семплы для каждого уникального разрешения""" # Группируем по размерам size_groups = defaultdict(list) try: widths = dataset["width"] heights = dataset["height"] except KeyError: widths = [0] * len(dataset) heights = [0] * len(dataset) for i, (w, h) in enumerate(zip(widths, heights)): size = (w, h) size_groups[size].append(i) # Выбираем фиксированные примеры из каждой группы fixed_samples = {} for size, indices in size_groups.items(): # Определяем сколько семплов брать из этой группы n_samples = min(samples_per_group, len(indices)) if len(size_groups)==1: n_samples = samples_to_generate if n_samples == 0: continue # Выбираем случайные индексы sample_indices = random.sample(indices, n_samples) samples_data = [dataset[idx] for idx in sample_indices] # Собираем данные latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device) embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device) texts = [item["text"] for item in samples_data] # Сохраняем для этого размера fixed_samples[size] = (latents, embeddings, texts) print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям") return fixed_samples if limit > 0: dataset = load_from_disk(save_path).select(range(limit)) else: dataset = load_from_disk(save_path) def collate_fn(batch): # Преобразуем список в тензоры и перемещаем на девайс latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device) embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device) return latents, embeddings # Используем наш ResolutionBatchSampler batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True) dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn) print("Total samples",len(dataloader)) dataloader = accelerator.prepare(dataloader) # --------------------------- Загрузка моделей --------------------------- # VAE загружается на CPU для экономии GPU-памяти vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype) # DDPMScheduler с V_Prediction и Zero-SNR scheduler = DDPMScheduler( num_train_timesteps=1000, # Полный график шагов для обучения prediction_type="v_prediction", # V-Prediction rescale_betas_zero_snr=True, # Включение Zero-SNR timestep_spacing="leading", # Добавляем улучшенное распределение шагов #steps_offset=1 # Избегаем проблем с нулевым timestep ) # Инициализация переменных для возобновления обучения start_epoch = 0 global_step = 0 # Расчёт общего количества шагов total_training_steps = (len(dataloader) * num_epochs) # Get the world size world_size = accelerator.state.num_processes print(f"World Size: {world_size}") # Опция загрузки модели из последнего чекпоинта (если существует) latest_checkpoint = os.path.join(checkpoints_folder, project) if os.path.isdir(latest_checkpoint): print("Загружаем UNet из чекпоинта:", latest_checkpoint) unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype) unet.enable_gradient_checkpointing() unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers try: unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor print("SDPA включен через set_attn_processor.") except Exception as e: print(f"Ошибка при включении SDPA: {e}") print("Попытка использовать enable_xformers_memory_efficient_attention.") unet.set_use_memory_efficient_attention_xformers(True) if project == "sdxxs": # Замораживаем все параметры модели for param in unet.parameters(): param.requires_grad = False # Список параметров, которые вы хотите тренировать target_params = [ "down_blocks.3.downsamplers.0.conv.bias", "down_blocks.3.downsamplers.0.conv.weight", "down_blocks.4.", "mid_block.attentions.0.", "up_blocks.0" ] # Размораживаем только целевые параметры for name, param in unet.named_parameters(): for target in target_params: if name.startswith(target): param.requires_grad = True break # Определяем параметры для оптимизации trainable_params = [p for p in unet.parameters() if p.requires_grad] lora_params_count = sum(p.numel() for p in trainable_params) print(f"Количество обучаемых параметров (как бля LoRA): {lora_params_count:,}") if lora_name: print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---") from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from peft.tuners.lora import LoraModel import os # 1. Замораживаем все параметры UNet unet.requires_grad_(False) print("Параметры базового UNet заморожены.") # 2. Создаем конфигурацию LoRA lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], ) unet.add_adapter(lora_config) # 3. Оборачиваем UNet в PEFT-модель from peft import get_peft_model peft_unet = get_peft_model(unet, lora_config) # 4. Получаем параметры для оптимизации params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad) # 5. Выводим информацию о количестве параметров if accelerator.is_main_process: lora_params_count = sum(p.numel() for p in params_to_optimize) total_params_count = sum(p.numel() for p in unet.parameters()) print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}") print(f"Общее количество параметров UNet: {total_params_count:,}") # 6. Путь для сохранения lora_save_path = os.path.join("lora", lora_name) os.makedirs(lora_save_path, exist_ok=True) # 7. Функция для сохранения def save_lora_checkpoint(model): if accelerator.is_main_process: print(f"Сохраняем LoRA адаптеры в {lora_save_path}") from peft.utils.save_and_load import get_peft_model_state_dict # Получаем state_dict только LoRA lora_state_dict = get_peft_model_state_dict(model) # Сохраняем веса torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin")) # Сохраняем конфиг model.peft_config["default"].save_pretrained(lora_save_path) # SDXL must be compatible from diffusers import StableDiffusionXLPipeline StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict) # --------------------------- Оптимизатор --------------------------- # Определяем параметры для оптимизации if lora_name or project=="sdxxs": # Если используется LoRA, оптимизируем только параметры LoRA trainable_params = [p for p in unet.parameters() if p.requires_grad] else: # Иначе оптимизируем все параметры trainable_params = list(unet.parameters()) # [1] Создаем словарь оптимизаторов (fused backward) optimizer_dict = { p: bnb.optim.AdamW8bit( [p], # Каждый параметр получает свой оптимизатор lr=base_learning_rate, betas=(0.9, 0.999), weight_decay=1e-5, eps=1e-8 ) for p in trainable_params } # [2] Определяем hook для применения оптимизатора сразу после накопления градиента def optimizer_hook(param): optimizer_dict[param].step() optimizer_dict[param].zero_grad(set_to_none=True) # [3] Регистрируем hook для trainable параметров модели for param in trainable_params: param.register_post_accumulate_grad_hook(optimizer_hook) # Подготовка через Accelerator unet, optimizer = accelerator.prepare(unet, optimizer_dict) # --------------------------- Фиксированные семплы для генерации --------------------------- # Примеры фиксированных семплов по размерам fixed_samples = get_fixed_samples_by_resolution(dataset) @torch.no_grad() def generate_and_save_samples(fixed_samples,step): """ Генерирует семплы для каждого из разрешений и сохраняет их. Args: step: Текущий шаг обучения fixed_samples: Словарь, где ключи - размеры (width, height), а значения - кортежи (latents, embeddings) """ try: original_model = accelerator.unwrap_model(unet) # Перемещаем VAE на device для семплирования vae.to(accelerator.device, dtype=dtype) # Устанавливаем количество diffusion шагов scheduler.set_timesteps(n_diffusion_steps) all_generated_images = [] size_info = [] # Для хранения информации о размере для каждого изображения all_captions = [] # Проходим по всем группам размеров for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items(): width, height = size size_info.append(f"{width}x{height}") #print(f"Генерация {sample_latents.shape[0]} изображений размером {width}x{height}") # Инициализируем латенты случайным шумом для этой группы noise = torch.randn( sample_latents.shape, generator=gen, device=sample_latents.device, dtype=sample_latents.dtype ) # Начинаем с шума current_latents = noise.clone() # Подготовка текстовых эмбеддингов для guidance if guidance_scale > 0: empty_embeddings = torch.zeros_like(sample_text_embeddings) text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0) else: text_embeddings = sample_text_embeddings # Генерация изображений for t in scheduler.timesteps: # Подготовка входных данных для UNet if guidance_scale > 0: latent_model_input = torch.cat([current_latents] * 2) latent_model_input = scheduler.scale_model_input(latent_model_input, t) else: latent_model_input = scheduler.scale_model_input(current_latents, t) # Предсказание шума noise_pred = original_model(latent_model_input, t, text_embeddings).sample # Применение guidance scale if guidance_scale > 0: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # Обновление латентов current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample # Декодирование через VAE latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor latent = latent.to(accelerator.device, dtype=dtype) decoded = vae.decode(latent).sample # Преобразуем тензоры в PIL-изображения и сохраняем for img_idx, img_tensor in enumerate(decoded): img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) pil_img = Image.fromarray((img * 255).astype("uint8")) # Определяем максимальные ширину и высоту max_width = max(size[0] for size in fixed_samples.keys()) max_height = max(size[1] for size in fixed_samples.keys()) max_width = max(255,max_width) max_height = max(255,max_height) # Добавляем padding, чтобы изображение стало размером max_width x max_height padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white') all_generated_images.append(padded_img) caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else "" all_captions.append(caption_text) # Сохраняем с информацией о размере в имени файла save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg" pil_img.save(save_path, "JPEG", quality=96) # Отправляем изображения на WandB с информацией о размере if use_wandb and accelerator.is_main_process: wandb_images = [ wandb.Image(img, caption=f"{all_captions[i]}") for i, img in enumerate(all_generated_images) ] wandb.log({"generated_images": wandb_images, "global_step": step}) finally: # Гарантированное перемещение VAE обратно на CPU vae.to("cpu") if original_model is not None: del original_model # Очистка всех тензоров for var in list(locals().keys()): if isinstance(locals()[var], torch.Tensor): del locals()[var] torch.cuda.empty_cache() gc.collect() # --------------------------- Генерация сэмплов перед обучением --------------------------- if accelerator.is_main_process: if save_model: print("Генерация сэмплов до старта обучения...") generate_and_save_samples(fixed_samples,0) # Модифицируем функцию сохранения модели для поддержки LoRA def save_checkpoint(unet): if accelerator.is_main_process: if lora_name: # Сохраняем только LoRA адаптеры save_lora_checkpoint(unet) else: # Сохраняем полную модель accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}")) # --------------------------- Тренировочный цикл --------------------------- # Для логирования среднего лосса каждые % эпохи if accelerator.is_main_process: print(f"Total steps per GPU: {total_training_steps}") print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}") epoch_loss_points = [] progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step") # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи) steps_per_epoch = len(dataloader) sample_interval = max(1, steps_per_epoch // sample_interval_share) # Начинаем с указанной эпохи (полезно при возобновлении) for epoch in range(start_epoch, start_epoch + num_epochs): batch_losses = [] unet.train() for step, (latents, embeddings) in enumerate(dataloader): with accelerator.accumulate(unet): if save_model == False and step == 3 : used_gb = torch.cuda.max_memory_allocated() / 1024**3 print(f"Шаг {step}: {used_gb:.2f} GB") # Forward pass noise = torch.randn_like(latents) timesteps = torch.randint( 1, # Начинаем с 1, не с 0 scheduler.config.num_train_timesteps, (latents.shape[0],), device=device ).long() # Добавляем шум к латентам noisy_latents = scheduler.add_noise(latents, noise, timesteps) # Получаем предсказание шума - кастим в bf16 noise_pred = unet(noisy_latents, timesteps, embeddings).sample.to(dtype=torch.bfloat16) # Используем целевое значение v_prediction target = scheduler.get_velocity(latents, noise, timesteps) # Считаем лосс loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float()) # Делаем backward через Accelerator accelerator.backward(loss) # Увеличиваем счетчик глобальных шагов global_step += 1 # Обновляем прогресс-бар progress_bar.update(1) # Логируем метрики if accelerator.is_main_process: current_lr = base_learning_rate batch_losses.append(loss.detach().item()) # Логируем в Wandb if use_wandb: wandb.log({ "loss": loss.detach().item(), "learning_rate": current_lr, "epoch": epoch, "global_step": global_step }) # Генерируем сэмплы с заданным интервалом if global_step % sample_interval == 0: if save_model: save_checkpoint(unet) generate_and_save_samples(fixed_samples,global_step) # Выводим текущий лосс avg_loss = np.mean(batch_losses[-sample_interval:]) #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}") if use_wandb: wandb.log({"intermediate_loss": avg_loss}) # По окончании эпохи if accelerator.is_main_process: avg_epoch_loss = np.mean(batch_losses) print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}") if use_wandb: wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1}) # Завершение обучения - сохраняем финальную модель if accelerator.is_main_process: print("Обучение завершено! Сохраняем финальную модель...") # Сохраняем основную модель #if save_model: save_checkpoint(unet) print("Готово!")