recoilme commited on
Commit
1e17a4b
·
1 Parent(s): 5bf4bad
samples/sdxxs_448x768_0.jpg ADDED

Git LFS Details

  • SHA256: 7ba99baebee405fde5e43afcfed7a84d52c64bd0fcd1eecbabd7fbabfca1498c
  • Pointer size: 130 Bytes
  • Size of remote file: 75.6 kB
samples/sdxxs_512x768_0.jpg ADDED

Git LFS Details

  • SHA256: 4e29a0bf6aa86aea03baab37210a4d79fb0de5c91720438b88078bd524506993
  • Pointer size: 130 Bytes
  • Size of remote file: 70.9 kB
samples/sdxxs_576x768_0.jpg ADDED

Git LFS Details

  • SHA256: b548963f2344d00fc7ee2e5f125bae542e73bbc9305d43b98d2c305c69b183e7
  • Pointer size: 130 Bytes
  • Size of remote file: 77.6 kB
samples/sdxxs_640x768_0.jpg ADDED

Git LFS Details

  • SHA256: 2df3474a5bacc5bb8520eebadd7fe084a0dfd84927a09ce1c3a1d02a7e547383
  • Pointer size: 131 Bytes
  • Size of remote file: 128 kB
samples/sdxxs_704x768_0.jpg ADDED

Git LFS Details

  • SHA256: 618961682ddf88d24b8edbd66e2752463b6038ab731e624e4859a644e878dde7
  • Pointer size: 130 Bytes
  • Size of remote file: 74.9 kB
samples/sdxxs_768x448_0.jpg ADDED

Git LFS Details

  • SHA256: 7e114da71a3e32d9685b96ddbc383e33707c2afee48f9f0e65fdd72763a1ef74
  • Pointer size: 130 Bytes
  • Size of remote file: 27.7 kB
samples/sdxxs_768x512_0.jpg ADDED

Git LFS Details

  • SHA256: ab3a97d0faa406fe0efc168d77074873cf2d347cdda7e02b5a3a9ec06a35cc4c
  • Pointer size: 130 Bytes
  • Size of remote file: 51 kB
samples/sdxxs_768x576_0.jpg ADDED

Git LFS Details

  • SHA256: 19ba56257a9c46122e4ab27e9962f4e0d589035e9363a574e5d08d16028c67ef
  • Pointer size: 130 Bytes
  • Size of remote file: 56.2 kB
samples/sdxxs_768x640_0.jpg ADDED

Git LFS Details

  • SHA256: 91fcc5a057d4ddd14c687c2edcb2496e9cfeac785dcdf5deefc53006490d01c6
  • Pointer size: 130 Bytes
  • Size of remote file: 87.4 kB
samples/sdxxs_768x704_0.jpg ADDED

Git LFS Details

  • SHA256: 16105b3b909c3d1ffe4023df7fddf1b88f50b505b77ae0e6c69aaf282712e51a
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
samples/sdxxs_768x768_0.jpg ADDED

Git LFS Details

  • SHA256: c0bf340b59755a9c2c1ae0f53f37e58b776b55c3d070f85373f753d1e5141945
  • Pointer size: 130 Bytes
  • Size of remote file: 80.6 kB
sdxxs/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b371d4886bb1aef3cfc1186b247bc5c6ad0a0d741bd41a3d68e78a88814e82
3
+ size 1965
sdxxs/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da4b7b0e9675b368ab31ce4c30bb590101df73262b466d0781b464ed68bc83ba
3
+ size 7010929912
src/sdxs_create.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dc6987bbccaaedf529bc2d6b55bae0f28e5161fe96dcd48970aa92fb92a55664
3
- size 9793
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:107ff7729ffd04b3e2fd6c4305f74f910d7e7e23141ba71836d055a497fffebd
3
+ size 9787
src/sdxs_sdxxs_transfer.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a7c11b012028e6ed060f3068d08adeb229bdad4b6b8cd2b4e0d50346cafcf02f
3
- size 235161
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57db74be42dbf73bc551cf86b4302dd2717555280e26325e599ca89f51b4916e
3
+ size 168192
train_sdxxs.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from collections import defaultdict
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
10
+ from accelerate import Accelerator
11
+ from datasets import load_from_disk
12
+ from tqdm import tqdm
13
+ from PIL import Image,ImageOps
14
+ import wandb
15
+ import random
16
+ import gc
17
+ from accelerate.state import DistributedType
18
+ from torch.distributed import broadcast_object_list
19
+ from torch.utils.checkpoint import checkpoint
20
+ from diffusers.models.attention_processor import AttnProcessor2_0
21
+ from datetime import datetime
22
+ import bitsandbytes as bnb
23
+
24
+ # --------------------------- Параметры ---------------------------
25
+ save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
+ batch_size = 60 #30 #26 #45 #11 #45 #555 #35 #7
27
+ base_learning_rate = 2.5e-5 #4e-6 #2e-5 #4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
+ min_learning_rate = 2.5e-5 #2e-5
29
+ num_epochs = 1 #2 #36 #18
30
+ project = "sdxxs"
31
+ use_wandb = True
32
+ save_model = True
33
+ limit = 0 #200000 #0
34
+ checkpoints_folder = ""
35
+
36
+ # Параметры для диффузии
37
+ n_diffusion_steps = 40
38
+ samples_to_generate = 12
39
+ guidance_scale = 5
40
+ sample_interval_share = 20 # samples/save per epoch
41
+
42
+ # Папки для сохранения результатов
43
+ generated_folder = "samples"
44
+ os.makedirs(generated_folder, exist_ok=True)
45
+
46
+ # Настройка seed для воспроизводимости
47
+ current_date = datetime.now()
48
+ seed = int(current_date.strftime("%Y%m%d"))
49
+ fixed_seed = True
50
+ if fixed_seed:
51
+ torch.manual_seed(seed)
52
+ np.random.seed(seed)
53
+ random.seed(seed)
54
+ if torch.cuda.is_available():
55
+ torch.cuda.manual_seed_all(seed)
56
+
57
+ # --------------------------- Параметры LoRA ---------------------------
58
+ # pip install peft
59
+ lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
60
+ lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
61
+ lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
62
+
63
+ print("init")
64
+ # Включение Flash Attention 2/SDPA
65
+ torch.backends.cuda.enable_flash_sdp(True)
66
+ # --------------------------- Инициализация Accelerator --------------------
67
+ dtype = torch.bfloat16
68
+ accelerator = Accelerator(mixed_precision="bf16")
69
+ device = accelerator.device
70
+ gen = torch.Generator(device=device)
71
+ gen.manual_seed(seed)
72
+
73
+ # --------------------------- Инициализация WandB ---------------------------
74
+ if use_wandb and accelerator.is_main_process:
75
+ wandb.init(project=project+lora_name, config={
76
+ "batch_size": batch_size,
77
+ "base_learning_rate": base_learning_rate,
78
+ "num_epochs": num_epochs,
79
+ "n_diffusion_steps": n_diffusion_steps,
80
+ "samples_to_generate": samples_to_generate,
81
+ "dtype": str(dtype)
82
+ })
83
+
84
+ # --------------------------- Загрузка датасета ---------------------------
85
+ class ResolutionBatchSampler(Sampler):
86
+ """Сэмплер, который группирует примеры по одинаковым размерам"""
87
+ def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
88
+ self.dataset = dataset
89
+ self.batch_size = batch_size
90
+ self.shuffle = shuffle
91
+ self.drop_last = drop_last
92
+
93
+ # Группируем примеры по размерам
94
+ self.size_groups = defaultdict(list)
95
+
96
+ try:
97
+ widths = dataset["width"]
98
+ heights = dataset["height"]
99
+ except KeyError:
100
+ widths = [0] * len(dataset)
101
+ heights = [0] * len(dataset)
102
+
103
+ for i, (w, h) in enumerate(zip(widths, heights)):
104
+ size = (w, h)
105
+ self.size_groups[size].append(i)
106
+
107
+ # Печатаем статистику по размерам
108
+ print(f"Найдено {len(self.size_groups)} уникальных размеров:")
109
+ for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
110
+ width, height = size
111
+ print(f" {width}x{height}: {len(indices)} примеров")
112
+
113
+ # Формируем батчи
114
+ self.reset()
115
+
116
+ def reset(self):
117
+ """Сбрасывает и перемешивает индексы"""
118
+ self.batches = []
119
+
120
+ for size, indices in self.size_groups.items():
121
+ if self.shuffle:
122
+ indices_copy = indices.copy()
123
+ random.shuffle(indices_copy)
124
+ else:
125
+ indices_copy = indices
126
+
127
+ # Разбиваем на батчи
128
+ for i in range(0, len(indices_copy), self.batch_size):
129
+ batch_indices = indices_copy[i:i + self.batch_size]
130
+
131
+ # Пропускаем неполные батчи если drop_last=True
132
+ if self.drop_last and len(batch_indices) < self.batch_size:
133
+ continue
134
+
135
+ self.batches.append(batch_indices)
136
+
137
+ # Перемешиваем батчи между собой
138
+ if self.shuffle:
139
+ random.shuffle(self.batches)
140
+
141
+ def __iter__(self):
142
+ self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи
143
+ return iter(self.batches)
144
+
145
+ def __len__(self):
146
+ return len(self.batches)
147
+
148
+ # Функция для выборки фиксированных семплов по размерам
149
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
150
+ """Выбирает фиксированные семплы для каждого уникального разрешения"""
151
+ # Группируем по размерам
152
+ size_groups = defaultdict(list)
153
+ try:
154
+ widths = dataset["width"]
155
+ heights = dataset["height"]
156
+ except KeyError:
157
+ widths = [0] * len(dataset)
158
+ heights = [0] * len(dataset)
159
+ for i, (w, h) in enumerate(zip(widths, heights)):
160
+ size = (w, h)
161
+ size_groups[size].append(i)
162
+
163
+ # Выбираем фиксированные примеры из каждой группы
164
+ fixed_samples = {}
165
+ for size, indices in size_groups.items():
166
+ # Определяем сколько семплов брать из этой группы
167
+ n_samples = min(samples_per_group, len(indices))
168
+ if len(size_groups)==1:
169
+ n_samples = samples_to_generate
170
+ if n_samples == 0:
171
+ continue
172
+
173
+ # Выбираем случайные индексы
174
+ sample_indices = random.sample(indices, n_samples)
175
+ samples_data = [dataset[idx] for idx in sample_indices]
176
+
177
+ # Собираем данные
178
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device)
179
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device)
180
+ texts = [item["text"] for item in samples_data]
181
+
182
+ # Сохраняем для этого размера
183
+ fixed_samples[size] = (latents, embeddings, texts)
184
+
185
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
186
+ return fixed_samples
187
+
188
+ if limit > 0:
189
+ dataset = load_from_disk(save_path).select(range(limit))
190
+ else:
191
+ dataset = load_from_disk(save_path)
192
+
193
+
194
+ def collate_fn(batch):
195
+ # Преобразуем список в тензоры и перемещаем на девайс
196
+ latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
197
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
198
+ return latents, embeddings
199
+
200
+ # Используем наш ResolutionBatchSampler
201
+ batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
202
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
203
+
204
+ print("Total samples",len(dataloader))
205
+ dataloader = accelerator.prepare(dataloader)
206
+
207
+ # --------------------------- Загрузка моделей ---------------------------
208
+ # VAE загружается на CPU для экономии GPU-памяти
209
+ vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
210
+
211
+ # DDPMScheduler с V_Prediction и Zero-SNR
212
+ scheduler = DDPMScheduler(
213
+ num_train_timesteps=1000, # Полный график шагов для обучения
214
+ prediction_type="v_prediction", # V-Prediction
215
+ rescale_betas_zero_snr=True, # Включение Zero-SNR
216
+ timestep_spacing="leading", # Добавляем улучшенное распределение шагов
217
+ #steps_offset=1 # Избегаем проблем с нулевым timestep
218
+ )
219
+
220
+ # Инициализация переменных для возобновления обучения
221
+ start_epoch = 0
222
+ global_step = 0
223
+
224
+ # Расчёт общего количества шагов
225
+ total_training_steps = (len(dataloader) * num_epochs)
226
+ # Get the world size
227
+ world_size = accelerator.state.num_processes
228
+ print(f"World Size: {world_size}")
229
+
230
+ # Опция загрузки модели из последнего чекпоинта (если существует)
231
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
232
+ if os.path.isdir(latest_checkpoint):
233
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
234
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
235
+ unet.enable_gradient_checkpointing()
236
+ unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
237
+ try:
238
+ unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
239
+ print("SDPA включен через set_attn_processor.")
240
+ except Exception as e:
241
+ print(f"Ошибка при включении SDPA: {e}")
242
+ print("Попытка использовать enable_xformers_memory_efficient_attention.")
243
+ unet.set_use_memory_efficient_attention_xformers(True)
244
+ if project == "sdxxs":
245
+ # Замораживаем все параметры модели
246
+ for param in unet.parameters():
247
+ param.requires_grad = False
248
+
249
+ # Список параметров, которые вы хотите тренировать
250
+ target_params = [
251
+ "down_blocks.3.downsamplers.0.conv.bias",
252
+ "down_blocks.3.downsamplers.0.conv.weight",
253
+ "down_blocks.4.",
254
+ "mid_block.attentions.0.",
255
+ "up_blocks.0"
256
+ ]
257
+
258
+ # Размораживаем только целевые параметры
259
+ for name, param in unet.named_parameters():
260
+ for target in target_params:
261
+ if name.startswith(target):
262
+ param.requires_grad = True
263
+ break
264
+
265
+ # Определяем параметры для оптимизации
266
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
267
+ lora_params_count = sum(p.numel() for p in trainable_params)
268
+ print(f"Количество обучаемых параметров (как бля LoRA): {lora_params_count:,}")
269
+
270
+
271
+ if lora_name:
272
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
273
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
274
+ from peft.tuners.lora import LoraModel
275
+ import os
276
+ # 1. Замораживаем все параметры UNet
277
+ unet.requires_grad_(False)
278
+ print("Параметры базового UNet заморожены.")
279
+
280
+ # 2. Создаем конфигурацию LoRA
281
+ lora_config = LoraConfig(
282
+ r=lora_rank,
283
+ lora_alpha=lora_alpha,
284
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
285
+ )
286
+ unet.add_adapter(lora_config)
287
+
288
+ # 3. Оборачиваем UNet в PEFT-модель
289
+ from peft import get_peft_model
290
+
291
+ peft_unet = get_peft_model(unet, lora_config)
292
+
293
+ # 4. Получаем параметры для оптимизации
294
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
295
+
296
+
297
+ # 5. Выводим информацию о количестве параметров
298
+ if accelerator.is_main_process:
299
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
300
+ total_params_count = sum(p.numel() for p in unet.parameters())
301
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
302
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
303
+
304
+ # 6. Путь для сохранения
305
+ lora_save_path = os.path.join("lora", lora_name)
306
+ os.makedirs(lora_save_path, exist_ok=True)
307
+
308
+ # 7. Функция для сохранения
309
+ def save_lora_checkpoint(model):
310
+ if accelerator.is_main_process:
311
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
312
+ from peft.utils.save_and_load import get_peft_model_state_dict
313
+ # Получаем state_dict только LoRA
314
+ lora_state_dict = get_peft_model_state_dict(model)
315
+
316
+ # Сохраняем веса
317
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
318
+
319
+ # Сохраняем конфиг
320
+ model.peft_config["default"].save_pretrained(lora_save_path)
321
+ # SDXL must be compatible
322
+ from diffusers import StableDiffusionXLPipeline
323
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
324
+
325
+ # --------------------------- Оптимизатор ---------------------------
326
+ # Определяем параметры для оптимизации
327
+ if lora_name or project=="sdxxs":
328
+ # Если используется LoRA, оптимизируем только параметры LoRA
329
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
330
+ else:
331
+ # Иначе оптимизируем все параметры
332
+ trainable_params = list(unet.parameters())
333
+
334
+ # [1] Создаем словарь оптимизаторов (fused backward)
335
+ optimizer_dict = {
336
+ p: bnb.optim.AdamW8bit(
337
+ [p], # Каждый параметр получает свой оптимизатор
338
+ lr=base_learning_rate,
339
+ betas=(0.9, 0.999),
340
+ weight_decay=1e-5,
341
+ eps=1e-8
342
+ ) for p in trainable_params
343
+ }
344
+
345
+ # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
346
+ def optimizer_hook(param):
347
+ optimizer_dict[param].step()
348
+ optimizer_dict[param].zero_grad(set_to_none=True)
349
+
350
+ # [3] Регистрируем hook для trainable параметров модели
351
+ for param in trainable_params:
352
+ param.register_post_accumulate_grad_hook(optimizer_hook)
353
+
354
+ # Подготовка через Accelerator
355
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
356
+
357
+ # --------------------------- Фиксированные семплы для генерации ---------------------------
358
+ # Примеры фиксированных семплов по размерам
359
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
360
+
361
+
362
+ @torch.no_grad()
363
+ def generate_and_save_samples(fixed_samples,step):
364
+ """
365
+ Генерирует семплы для каждого из разрешений и сохраняет их.
366
+
367
+ Args:
368
+ step: Текущий шаг обучения
369
+ fixed_samples: Словарь, где ключи - размеры (width, height),
370
+ а значения - кортежи (latents, embeddings)
371
+ """
372
+ try:
373
+ original_model = accelerator.unwrap_model(unet)
374
+ # Перемещаем VAE на device для семплирования
375
+ vae.to(accelerator.device, dtype=dtype)
376
+
377
+ # Устанавливаем количество diffusion шагов
378
+ scheduler.set_timesteps(n_diffusion_steps)
379
+
380
+ all_generated_images = []
381
+ size_info = [] # Для хранения информации о размере для каждого изображения
382
+ all_captions = []
383
+
384
+ # Проходим по всем группам размеров
385
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items():
386
+ width, height = size
387
+ size_info.append(f"{width}x{height}")
388
+ #print(f"Генерация {sample_latents.shape[0]} изображений размером {width}x{height}")
389
+
390
+ # Инициализируем латенты случайным шумом для этой группы
391
+ noise = torch.randn(
392
+ sample_latents.shape,
393
+ generator=gen,
394
+ device=sample_latents.device,
395
+ dtype=sample_latents.dtype
396
+ )
397
+
398
+ # Начинаем с шума
399
+ current_latents = noise.clone()
400
+
401
+ # Подготовка текстовых эмбеддингов для guidance
402
+ if guidance_scale > 0:
403
+ empty_embeddings = torch.zeros_like(sample_text_embeddings)
404
+ text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
405
+ else:
406
+ text_embeddings = sample_text_embeddings
407
+
408
+ # Генерация изображений
409
+ for t in scheduler.timesteps:
410
+ # Подготовка входных данных для UNet
411
+ if guidance_scale > 0:
412
+ latent_model_input = torch.cat([current_latents] * 2)
413
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
414
+ else:
415
+ latent_model_input = scheduler.scale_model_input(current_latents, t)
416
+
417
+ # Предсказание шума
418
+ noise_pred = original_model(latent_model_input, t, text_embeddings).sample
419
+
420
+ # Применение guidance scale
421
+ if guidance_scale > 0:
422
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
423
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
424
+
425
+ # Обновление латентов
426
+ current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
427
+
428
+ # Декодирование через VAE
429
+ latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
430
+ latent = latent.to(accelerator.device, dtype=dtype)
431
+ decoded = vae.decode(latent).sample
432
+
433
+ # Преобразуем тензоры в PIL-изображения и сохраняем
434
+ for img_idx, img_tensor in enumerate(decoded):
435
+ img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
436
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
437
+ # Определяем максимальные ширину и высоту
438
+ max_width = max(size[0] for size in fixed_samples.keys())
439
+ max_height = max(size[1] for size in fixed_samples.keys())
440
+ max_width = max(255,max_width)
441
+ max_height = max(255,max_height)
442
+
443
+ # Добавляем padding, чтобы изображение стало размером max_width x max_height
444
+ padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
445
+
446
+ all_generated_images.append(padded_img)
447
+
448
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
449
+ all_captions.append(caption_text)
450
+
451
+ # Сохраняем с информацией о размере в имени файла
452
+ save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
453
+ pil_img.save(save_path, "JPEG", quality=96)
454
+
455
+ # Отправляем изображения на WandB с информацией о размере
456
+ if use_wandb and accelerator.is_main_process:
457
+ wandb_images = [
458
+ wandb.Image(img, caption=f"{all_captions[i]}")
459
+ for i, img in enumerate(all_generated_images)
460
+ ]
461
+ wandb.log({"generated_images": wandb_images, "global_step": step})
462
+
463
+ finally:
464
+ # Гарантированное перемещение VAE обратно на CPU
465
+ vae.to("cpu")
466
+ if original_model is not None:
467
+ del original_model
468
+ # Очистка всех тензоров
469
+ for var in list(locals().keys()):
470
+ if isinstance(locals()[var], torch.Tensor):
471
+ del locals()[var]
472
+ torch.cuda.empty_cache()
473
+ gc.collect()
474
+
475
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
476
+ if accelerator.is_main_process:
477
+ if save_model:
478
+ print("Генерация сэмплов до старта обучения...")
479
+ generate_and_save_samples(fixed_samples,0)
480
+
481
+ # Модифицируем функцию сохранения модели для поддержки LoRA
482
+ def save_checkpoint(unet):
483
+ if accelerator.is_main_process:
484
+ if lora_name:
485
+ # Сохраняем только LoRA адаптеры
486
+ save_lora_checkpoint(unet)
487
+ else:
488
+ # Сохраняем полную модель
489
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
490
+
491
+ # --------------------------- Тренировочный цикл ---------------------------
492
+ # Для логирования среднего лосса каждые % эпохи
493
+ if accelerator.is_main_process:
494
+ print(f"Total steps per GPU: {total_training_steps}")
495
+ print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
496
+
497
+ epoch_loss_points = []
498
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
499
+
500
+ # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
501
+ steps_per_epoch = len(dataloader)
502
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
503
+
504
+ # Начинаем с указанной эпохи (полезно при возобновлении)
505
+ for epoch in range(start_epoch, start_epoch + num_epochs):
506
+ batch_losses = []
507
+ unet.train()
508
+
509
+ for step, (latents, embeddings) in enumerate(dataloader):
510
+ with accelerator.accumulate(unet):
511
+ if save_model == False and step == 3 :
512
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
513
+ print(f"Шаг {step}: {used_gb:.2f} GB")
514
+ # Forward pass
515
+ noise = torch.randn_like(latents)
516
+
517
+ timesteps = torch.randint(
518
+ 1, # Начинаем с 1, не с 0
519
+ scheduler.config.num_train_timesteps,
520
+ (latents.shape[0],),
521
+ device=device
522
+ ).long()
523
+
524
+ # Добавляем шум к латентам
525
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
526
+
527
+ # Получаем предсказание шума - кастим в bf16
528
+ noise_pred = unet(noisy_latents, timesteps, embeddings).sample.to(dtype=torch.bfloat16)
529
+
530
+ # Используем целевое значение v_prediction
531
+ target = scheduler.get_velocity(latents, noise, timesteps)
532
+
533
+ # Считаем лосс
534
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
535
+
536
+ # Делаем backward через Accelerator
537
+ accelerator.backward(loss)
538
+
539
+ # Увеличиваем счетчик глобальных шагов
540
+ global_step += 1
541
+
542
+ # Обновляем прогресс-бар
543
+ progress_bar.update(1)
544
+
545
+ # Логируем метрики
546
+ if accelerator.is_main_process:
547
+ current_lr = base_learning_rate
548
+ batch_losses.append(loss.detach().item())
549
+
550
+ # Логируем в Wandb
551
+ if use_wandb:
552
+ wandb.log({
553
+ "loss": loss.detach().item(),
554
+ "learning_rate": current_lr,
555
+ "epoch": epoch,
556
+ "global_step": global_step
557
+ })
558
+
559
+ # Генерируем сэмплы с заданным интервалом
560
+ if global_step % sample_interval == 0:
561
+ if save_model:
562
+ save_checkpoint(unet)
563
+
564
+ generate_and_save_samples(fixed_samples,global_step)
565
+
566
+ # Выводим текущий лосс
567
+ avg_loss = np.mean(batch_losses[-sample_interval:])
568
+ #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
569
+ if use_wandb:
570
+ wandb.log({"intermediate_loss": avg_loss})
571
+
572
+
573
+ # По окончании эпохи
574
+ if accelerator.is_main_process:
575
+ avg_epoch_loss = np.mean(batch_losses)
576
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
577
+ if use_wandb:
578
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
579
+
580
+ # Завершение обучения - сохраняем финальную модель
581
+ if accelerator.is_main_process:
582
+ print("Обучение завершено! Сохраняем финальную модель...")
583
+ # Сохраняем основную модель
584
+ #if save_model:
585
+ save_checkpoint(unet)
586
+ print("Готово!")