sdxxs
Browse files- samples/sdxxs_448x768_0.jpg +3 -0
- samples/sdxxs_512x768_0.jpg +3 -0
- samples/sdxxs_576x768_0.jpg +3 -0
- samples/sdxxs_640x768_0.jpg +3 -0
- samples/sdxxs_704x768_0.jpg +3 -0
- samples/sdxxs_768x448_0.jpg +3 -0
- samples/sdxxs_768x512_0.jpg +3 -0
- samples/sdxxs_768x576_0.jpg +3 -0
- samples/sdxxs_768x640_0.jpg +3 -0
- samples/sdxxs_768x704_0.jpg +3 -0
- samples/sdxxs_768x768_0.jpg +3 -0
- sdxxs/config.json +3 -0
- sdxxs/diffusion_pytorch_model.safetensors +3 -0
- src/sdxs_create.ipynb +2 -2
- src/sdxs_sdxxs_transfer.ipynb +2 -2
- train_sdxxs.py +586 -0
samples/sdxxs_448x768_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_512x768_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_576x768_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_640x768_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_704x768_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x448_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x512_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x576_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x640_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x704_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/sdxxs_768x768_0.jpg
ADDED
![]() |
Git LFS Details
|
sdxxs/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f0b371d4886bb1aef3cfc1186b247bc5c6ad0a0d741bd41a3d68e78a88814e82
|
3 |
+
size 1965
|
sdxxs/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da4b7b0e9675b368ab31ce4c30bb590101df73262b466d0781b464ed68bc83ba
|
3 |
+
size 7010929912
|
src/sdxs_create.ipynb
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:107ff7729ffd04b3e2fd6c4305f74f910d7e7e23141ba71836d055a497fffebd
|
3 |
+
size 9787
|
src/sdxs_sdxxs_transfer.ipynb
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57db74be42dbf73bc551cf86b4302dd2717555280e26325e599ca89f51b4916e
|
3 |
+
size 168192
|
train_sdxxs.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from torch.utils.data import DataLoader, Sampler
|
7 |
+
from collections import defaultdict
|
8 |
+
from torch.optim.lr_scheduler import LambdaLR
|
9 |
+
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from datasets import load_from_disk
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image,ImageOps
|
14 |
+
import wandb
|
15 |
+
import random
|
16 |
+
import gc
|
17 |
+
from accelerate.state import DistributedType
|
18 |
+
from torch.distributed import broadcast_object_list
|
19 |
+
from torch.utils.checkpoint import checkpoint
|
20 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
21 |
+
from datetime import datetime
|
22 |
+
import bitsandbytes as bnb
|
23 |
+
|
24 |
+
# --------------------------- Параметры ---------------------------
|
25 |
+
save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
|
26 |
+
batch_size = 60 #30 #26 #45 #11 #45 #555 #35 #7
|
27 |
+
base_learning_rate = 2.5e-5 #4e-6 #2e-5 #4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
|
28 |
+
min_learning_rate = 2.5e-5 #2e-5
|
29 |
+
num_epochs = 1 #2 #36 #18
|
30 |
+
project = "sdxxs"
|
31 |
+
use_wandb = True
|
32 |
+
save_model = True
|
33 |
+
limit = 0 #200000 #0
|
34 |
+
checkpoints_folder = ""
|
35 |
+
|
36 |
+
# Параметры для диффузии
|
37 |
+
n_diffusion_steps = 40
|
38 |
+
samples_to_generate = 12
|
39 |
+
guidance_scale = 5
|
40 |
+
sample_interval_share = 20 # samples/save per epoch
|
41 |
+
|
42 |
+
# Папки для сохранения результатов
|
43 |
+
generated_folder = "samples"
|
44 |
+
os.makedirs(generated_folder, exist_ok=True)
|
45 |
+
|
46 |
+
# Настройка seed для воспроизводимости
|
47 |
+
current_date = datetime.now()
|
48 |
+
seed = int(current_date.strftime("%Y%m%d"))
|
49 |
+
fixed_seed = True
|
50 |
+
if fixed_seed:
|
51 |
+
torch.manual_seed(seed)
|
52 |
+
np.random.seed(seed)
|
53 |
+
random.seed(seed)
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
torch.cuda.manual_seed_all(seed)
|
56 |
+
|
57 |
+
# --------------------------- Параметры LoRA ---------------------------
|
58 |
+
# pip install peft
|
59 |
+
lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
|
60 |
+
lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
|
61 |
+
lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
|
62 |
+
|
63 |
+
print("init")
|
64 |
+
# Включение Flash Attention 2/SDPA
|
65 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
66 |
+
# --------------------------- Инициализация Accelerator --------------------
|
67 |
+
dtype = torch.bfloat16
|
68 |
+
accelerator = Accelerator(mixed_precision="bf16")
|
69 |
+
device = accelerator.device
|
70 |
+
gen = torch.Generator(device=device)
|
71 |
+
gen.manual_seed(seed)
|
72 |
+
|
73 |
+
# --------------------------- Инициализация WandB ---------------------------
|
74 |
+
if use_wandb and accelerator.is_main_process:
|
75 |
+
wandb.init(project=project+lora_name, config={
|
76 |
+
"batch_size": batch_size,
|
77 |
+
"base_learning_rate": base_learning_rate,
|
78 |
+
"num_epochs": num_epochs,
|
79 |
+
"n_diffusion_steps": n_diffusion_steps,
|
80 |
+
"samples_to_generate": samples_to_generate,
|
81 |
+
"dtype": str(dtype)
|
82 |
+
})
|
83 |
+
|
84 |
+
# --------------------------- Загрузка датасета ---------------------------
|
85 |
+
class ResolutionBatchSampler(Sampler):
|
86 |
+
"""Сэмплер, который группирует примеры по одинаковым размерам"""
|
87 |
+
def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
|
88 |
+
self.dataset = dataset
|
89 |
+
self.batch_size = batch_size
|
90 |
+
self.shuffle = shuffle
|
91 |
+
self.drop_last = drop_last
|
92 |
+
|
93 |
+
# Группируем примеры по размерам
|
94 |
+
self.size_groups = defaultdict(list)
|
95 |
+
|
96 |
+
try:
|
97 |
+
widths = dataset["width"]
|
98 |
+
heights = dataset["height"]
|
99 |
+
except KeyError:
|
100 |
+
widths = [0] * len(dataset)
|
101 |
+
heights = [0] * len(dataset)
|
102 |
+
|
103 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
104 |
+
size = (w, h)
|
105 |
+
self.size_groups[size].append(i)
|
106 |
+
|
107 |
+
# Печатаем статистику по размерам
|
108 |
+
print(f"Найдено {len(self.size_groups)} уникальных размеров:")
|
109 |
+
for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
|
110 |
+
width, height = size
|
111 |
+
print(f" {width}x{height}: {len(indices)} примеров")
|
112 |
+
|
113 |
+
# Формируем батчи
|
114 |
+
self.reset()
|
115 |
+
|
116 |
+
def reset(self):
|
117 |
+
"""Сбрасывает и перемешивает индексы"""
|
118 |
+
self.batches = []
|
119 |
+
|
120 |
+
for size, indices in self.size_groups.items():
|
121 |
+
if self.shuffle:
|
122 |
+
indices_copy = indices.copy()
|
123 |
+
random.shuffle(indices_copy)
|
124 |
+
else:
|
125 |
+
indices_copy = indices
|
126 |
+
|
127 |
+
# Разбиваем на батчи
|
128 |
+
for i in range(0, len(indices_copy), self.batch_size):
|
129 |
+
batch_indices = indices_copy[i:i + self.batch_size]
|
130 |
+
|
131 |
+
# Пропускаем неполные батчи если drop_last=True
|
132 |
+
if self.drop_last and len(batch_indices) < self.batch_size:
|
133 |
+
continue
|
134 |
+
|
135 |
+
self.batches.append(batch_indices)
|
136 |
+
|
137 |
+
# Перемешиваем батчи между собой
|
138 |
+
if self.shuffle:
|
139 |
+
random.shuffle(self.batches)
|
140 |
+
|
141 |
+
def __iter__(self):
|
142 |
+
self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи
|
143 |
+
return iter(self.batches)
|
144 |
+
|
145 |
+
def __len__(self):
|
146 |
+
return len(self.batches)
|
147 |
+
|
148 |
+
# Функция для выборки фиксированных семплов по размерам
|
149 |
+
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
150 |
+
"""Выбирает фиксированные семплы для каждого уникального разрешения"""
|
151 |
+
# Группируем по размерам
|
152 |
+
size_groups = defaultdict(list)
|
153 |
+
try:
|
154 |
+
widths = dataset["width"]
|
155 |
+
heights = dataset["height"]
|
156 |
+
except KeyError:
|
157 |
+
widths = [0] * len(dataset)
|
158 |
+
heights = [0] * len(dataset)
|
159 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
160 |
+
size = (w, h)
|
161 |
+
size_groups[size].append(i)
|
162 |
+
|
163 |
+
# Выбираем фиксированные примеры из каждой группы
|
164 |
+
fixed_samples = {}
|
165 |
+
for size, indices in size_groups.items():
|
166 |
+
# Определяем сколько семплов брать из этой группы
|
167 |
+
n_samples = min(samples_per_group, len(indices))
|
168 |
+
if len(size_groups)==1:
|
169 |
+
n_samples = samples_to_generate
|
170 |
+
if n_samples == 0:
|
171 |
+
continue
|
172 |
+
|
173 |
+
# Выбираем случайные индексы
|
174 |
+
sample_indices = random.sample(indices, n_samples)
|
175 |
+
samples_data = [dataset[idx] for idx in sample_indices]
|
176 |
+
|
177 |
+
# Собираем данные
|
178 |
+
latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device)
|
179 |
+
embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device)
|
180 |
+
texts = [item["text"] for item in samples_data]
|
181 |
+
|
182 |
+
# Сохраняем для этого размера
|
183 |
+
fixed_samples[size] = (latents, embeddings, texts)
|
184 |
+
|
185 |
+
print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
|
186 |
+
return fixed_samples
|
187 |
+
|
188 |
+
if limit > 0:
|
189 |
+
dataset = load_from_disk(save_path).select(range(limit))
|
190 |
+
else:
|
191 |
+
dataset = load_from_disk(save_path)
|
192 |
+
|
193 |
+
|
194 |
+
def collate_fn(batch):
|
195 |
+
# Преобразуем список в тензоры и перемещаем на девайс
|
196 |
+
latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
|
197 |
+
embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
|
198 |
+
return latents, embeddings
|
199 |
+
|
200 |
+
# Используем наш ResolutionBatchSampler
|
201 |
+
batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
|
202 |
+
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
|
203 |
+
|
204 |
+
print("Total samples",len(dataloader))
|
205 |
+
dataloader = accelerator.prepare(dataloader)
|
206 |
+
|
207 |
+
# --------------------------- Загрузка моделей ---------------------------
|
208 |
+
# VAE загружается на CPU для экономии GPU-памяти
|
209 |
+
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
|
210 |
+
|
211 |
+
# DDPMScheduler с V_Prediction и Zero-SNR
|
212 |
+
scheduler = DDPMScheduler(
|
213 |
+
num_train_timesteps=1000, # Полный график шагов для обучения
|
214 |
+
prediction_type="v_prediction", # V-Prediction
|
215 |
+
rescale_betas_zero_snr=True, # Включение Zero-SNR
|
216 |
+
timestep_spacing="leading", # Добавляем улучшенное распределение шагов
|
217 |
+
#steps_offset=1 # Избегаем проблем с нулевым timestep
|
218 |
+
)
|
219 |
+
|
220 |
+
# Инициализация переменных для возобновления обучения
|
221 |
+
start_epoch = 0
|
222 |
+
global_step = 0
|
223 |
+
|
224 |
+
# Расчёт общего количества шагов
|
225 |
+
total_training_steps = (len(dataloader) * num_epochs)
|
226 |
+
# Get the world size
|
227 |
+
world_size = accelerator.state.num_processes
|
228 |
+
print(f"World Size: {world_size}")
|
229 |
+
|
230 |
+
# Опция загрузки модели из последнего чекпоинта (если существует)
|
231 |
+
latest_checkpoint = os.path.join(checkpoints_folder, project)
|
232 |
+
if os.path.isdir(latest_checkpoint):
|
233 |
+
print("Загружаем UNet из чекпоинта:", latest_checkpoint)
|
234 |
+
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
|
235 |
+
unet.enable_gradient_checkpointing()
|
236 |
+
unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
|
237 |
+
try:
|
238 |
+
unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
|
239 |
+
print("SDPA включен через set_attn_processor.")
|
240 |
+
except Exception as e:
|
241 |
+
print(f"Ошибка при включении SDPA: {e}")
|
242 |
+
print("Попытка использовать enable_xformers_memory_efficient_attention.")
|
243 |
+
unet.set_use_memory_efficient_attention_xformers(True)
|
244 |
+
if project == "sdxxs":
|
245 |
+
# Замораживаем все параметры модели
|
246 |
+
for param in unet.parameters():
|
247 |
+
param.requires_grad = False
|
248 |
+
|
249 |
+
# Список параметров, которые вы хотите тренировать
|
250 |
+
target_params = [
|
251 |
+
"down_blocks.3.downsamplers.0.conv.bias",
|
252 |
+
"down_blocks.3.downsamplers.0.conv.weight",
|
253 |
+
"down_blocks.4.",
|
254 |
+
"mid_block.attentions.0.",
|
255 |
+
"up_blocks.0"
|
256 |
+
]
|
257 |
+
|
258 |
+
# Размораживаем только целевые параметры
|
259 |
+
for name, param in unet.named_parameters():
|
260 |
+
for target in target_params:
|
261 |
+
if name.startswith(target):
|
262 |
+
param.requires_grad = True
|
263 |
+
break
|
264 |
+
|
265 |
+
# Определяем параметры для оптимизации
|
266 |
+
trainable_params = [p for p in unet.parameters() if p.requires_grad]
|
267 |
+
lora_params_count = sum(p.numel() for p in trainable_params)
|
268 |
+
print(f"Количество обучаемых параметров (как бля LoRA): {lora_params_count:,}")
|
269 |
+
|
270 |
+
|
271 |
+
if lora_name:
|
272 |
+
print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
|
273 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
274 |
+
from peft.tuners.lora import LoraModel
|
275 |
+
import os
|
276 |
+
# 1. Замораживаем все параметры UNet
|
277 |
+
unet.requires_grad_(False)
|
278 |
+
print("Параметры базового UNet заморожены.")
|
279 |
+
|
280 |
+
# 2. Создаем конфигурацию LoRA
|
281 |
+
lora_config = LoraConfig(
|
282 |
+
r=lora_rank,
|
283 |
+
lora_alpha=lora_alpha,
|
284 |
+
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
285 |
+
)
|
286 |
+
unet.add_adapter(lora_config)
|
287 |
+
|
288 |
+
# 3. Оборачиваем UNet в PEFT-модель
|
289 |
+
from peft import get_peft_model
|
290 |
+
|
291 |
+
peft_unet = get_peft_model(unet, lora_config)
|
292 |
+
|
293 |
+
# 4. Получаем параметры для оптимизации
|
294 |
+
params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
|
295 |
+
|
296 |
+
|
297 |
+
# 5. Выводим информацию о количестве параметров
|
298 |
+
if accelerator.is_main_process:
|
299 |
+
lora_params_count = sum(p.numel() for p in params_to_optimize)
|
300 |
+
total_params_count = sum(p.numel() for p in unet.parameters())
|
301 |
+
print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
|
302 |
+
print(f"Общее количество параметров UNet: {total_params_count:,}")
|
303 |
+
|
304 |
+
# 6. Путь для сохранения
|
305 |
+
lora_save_path = os.path.join("lora", lora_name)
|
306 |
+
os.makedirs(lora_save_path, exist_ok=True)
|
307 |
+
|
308 |
+
# 7. Функция для сохранения
|
309 |
+
def save_lora_checkpoint(model):
|
310 |
+
if accelerator.is_main_process:
|
311 |
+
print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
|
312 |
+
from peft.utils.save_and_load import get_peft_model_state_dict
|
313 |
+
# Получаем state_dict только LoRA
|
314 |
+
lora_state_dict = get_peft_model_state_dict(model)
|
315 |
+
|
316 |
+
# Сохраняем веса
|
317 |
+
torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
|
318 |
+
|
319 |
+
# Сохраняем конфиг
|
320 |
+
model.peft_config["default"].save_pretrained(lora_save_path)
|
321 |
+
# SDXL must be compatible
|
322 |
+
from diffusers import StableDiffusionXLPipeline
|
323 |
+
StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
|
324 |
+
|
325 |
+
# --------------------------- Оптимизатор ---------------------------
|
326 |
+
# Определяем параметры для оптимизации
|
327 |
+
if lora_name or project=="sdxxs":
|
328 |
+
# Если используется LoRA, оптимизируем только параметры LoRA
|
329 |
+
trainable_params = [p for p in unet.parameters() if p.requires_grad]
|
330 |
+
else:
|
331 |
+
# Иначе оптимизируем все параметры
|
332 |
+
trainable_params = list(unet.parameters())
|
333 |
+
|
334 |
+
# [1] Создаем словарь оптимизаторов (fused backward)
|
335 |
+
optimizer_dict = {
|
336 |
+
p: bnb.optim.AdamW8bit(
|
337 |
+
[p], # Каждый параметр получает свой оптимизатор
|
338 |
+
lr=base_learning_rate,
|
339 |
+
betas=(0.9, 0.999),
|
340 |
+
weight_decay=1e-5,
|
341 |
+
eps=1e-8
|
342 |
+
) for p in trainable_params
|
343 |
+
}
|
344 |
+
|
345 |
+
# [2] Определяем hook для применения оптимизатора сразу после накопления градиента
|
346 |
+
def optimizer_hook(param):
|
347 |
+
optimizer_dict[param].step()
|
348 |
+
optimizer_dict[param].zero_grad(set_to_none=True)
|
349 |
+
|
350 |
+
# [3] Регистрируем hook для trainable параметров модели
|
351 |
+
for param in trainable_params:
|
352 |
+
param.register_post_accumulate_grad_hook(optimizer_hook)
|
353 |
+
|
354 |
+
# Подготовка через Accelerator
|
355 |
+
unet, optimizer = accelerator.prepare(unet, optimizer_dict)
|
356 |
+
|
357 |
+
# --------------------------- Фиксированные семплы для генерации ---------------------------
|
358 |
+
# Примеры фиксированных семплов по размерам
|
359 |
+
fixed_samples = get_fixed_samples_by_resolution(dataset)
|
360 |
+
|
361 |
+
|
362 |
+
@torch.no_grad()
|
363 |
+
def generate_and_save_samples(fixed_samples,step):
|
364 |
+
"""
|
365 |
+
Генерирует семплы для каждого из разрешений и сохраняет их.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
step: Текущий шаг обучения
|
369 |
+
fixed_samples: Словарь, где ключи - размеры (width, height),
|
370 |
+
а значения - кортежи (latents, embeddings)
|
371 |
+
"""
|
372 |
+
try:
|
373 |
+
original_model = accelerator.unwrap_model(unet)
|
374 |
+
# Перемещаем VAE на device для семплирования
|
375 |
+
vae.to(accelerator.device, dtype=dtype)
|
376 |
+
|
377 |
+
# Устанавливаем количество diffusion шагов
|
378 |
+
scheduler.set_timesteps(n_diffusion_steps)
|
379 |
+
|
380 |
+
all_generated_images = []
|
381 |
+
size_info = [] # Для хранения информации о размере для каждого изображения
|
382 |
+
all_captions = []
|
383 |
+
|
384 |
+
# Проходим по всем группам размеров
|
385 |
+
for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items():
|
386 |
+
width, height = size
|
387 |
+
size_info.append(f"{width}x{height}")
|
388 |
+
#print(f"Генерация {sample_latents.shape[0]} изображений размером {width}x{height}")
|
389 |
+
|
390 |
+
# Инициализируем латенты случайным шумом для этой группы
|
391 |
+
noise = torch.randn(
|
392 |
+
sample_latents.shape,
|
393 |
+
generator=gen,
|
394 |
+
device=sample_latents.device,
|
395 |
+
dtype=sample_latents.dtype
|
396 |
+
)
|
397 |
+
|
398 |
+
# Начинаем с шума
|
399 |
+
current_latents = noise.clone()
|
400 |
+
|
401 |
+
# Подготовка текстовых эмбеддингов для guidance
|
402 |
+
if guidance_scale > 0:
|
403 |
+
empty_embeddings = torch.zeros_like(sample_text_embeddings)
|
404 |
+
text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
|
405 |
+
else:
|
406 |
+
text_embeddings = sample_text_embeddings
|
407 |
+
|
408 |
+
# Генерация изображений
|
409 |
+
for t in scheduler.timesteps:
|
410 |
+
# Подготовка входных данных для UNet
|
411 |
+
if guidance_scale > 0:
|
412 |
+
latent_model_input = torch.cat([current_latents] * 2)
|
413 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
414 |
+
else:
|
415 |
+
latent_model_input = scheduler.scale_model_input(current_latents, t)
|
416 |
+
|
417 |
+
# Предсказание шума
|
418 |
+
noise_pred = original_model(latent_model_input, t, text_embeddings).sample
|
419 |
+
|
420 |
+
# Применение guidance scale
|
421 |
+
if guidance_scale > 0:
|
422 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
423 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
424 |
+
|
425 |
+
# Обновление латентов
|
426 |
+
current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
|
427 |
+
|
428 |
+
# Декодирование через VAE
|
429 |
+
latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
|
430 |
+
latent = latent.to(accelerator.device, dtype=dtype)
|
431 |
+
decoded = vae.decode(latent).sample
|
432 |
+
|
433 |
+
# Преобразуем тензоры в PIL-изображения и сохраняем
|
434 |
+
for img_idx, img_tensor in enumerate(decoded):
|
435 |
+
img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
|
436 |
+
pil_img = Image.fromarray((img * 255).astype("uint8"))
|
437 |
+
# Определяем максимальные ширину и высоту
|
438 |
+
max_width = max(size[0] for size in fixed_samples.keys())
|
439 |
+
max_height = max(size[1] for size in fixed_samples.keys())
|
440 |
+
max_width = max(255,max_width)
|
441 |
+
max_height = max(255,max_height)
|
442 |
+
|
443 |
+
# Добавляем padding, чтобы изображение стало размером max_width x max_height
|
444 |
+
padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
|
445 |
+
|
446 |
+
all_generated_images.append(padded_img)
|
447 |
+
|
448 |
+
caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
|
449 |
+
all_captions.append(caption_text)
|
450 |
+
|
451 |
+
# Сохраняем с информацией о размере в имени файла
|
452 |
+
save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
|
453 |
+
pil_img.save(save_path, "JPEG", quality=96)
|
454 |
+
|
455 |
+
# Отправляем изображения на WandB с информацией о размере
|
456 |
+
if use_wandb and accelerator.is_main_process:
|
457 |
+
wandb_images = [
|
458 |
+
wandb.Image(img, caption=f"{all_captions[i]}")
|
459 |
+
for i, img in enumerate(all_generated_images)
|
460 |
+
]
|
461 |
+
wandb.log({"generated_images": wandb_images, "global_step": step})
|
462 |
+
|
463 |
+
finally:
|
464 |
+
# Гарантированное перемещение VAE обратно на CPU
|
465 |
+
vae.to("cpu")
|
466 |
+
if original_model is not None:
|
467 |
+
del original_model
|
468 |
+
# Очистка всех тензоров
|
469 |
+
for var in list(locals().keys()):
|
470 |
+
if isinstance(locals()[var], torch.Tensor):
|
471 |
+
del locals()[var]
|
472 |
+
torch.cuda.empty_cache()
|
473 |
+
gc.collect()
|
474 |
+
|
475 |
+
# --------------------------- Генерация сэмплов перед обучением ---------------------------
|
476 |
+
if accelerator.is_main_process:
|
477 |
+
if save_model:
|
478 |
+
print("Генерация сэмплов до старта обучения...")
|
479 |
+
generate_and_save_samples(fixed_samples,0)
|
480 |
+
|
481 |
+
# Модифицируем функцию сохранения модели для поддержки LoRA
|
482 |
+
def save_checkpoint(unet):
|
483 |
+
if accelerator.is_main_process:
|
484 |
+
if lora_name:
|
485 |
+
# Сохраняем только LoRA адаптеры
|
486 |
+
save_lora_checkpoint(unet)
|
487 |
+
else:
|
488 |
+
# Сохраняем полную модель
|
489 |
+
accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
|
490 |
+
|
491 |
+
# --------------------------- Тренировочный цикл ---------------------------
|
492 |
+
# Для логирования среднего лосса каждые % эпохи
|
493 |
+
if accelerator.is_main_process:
|
494 |
+
print(f"Total steps per GPU: {total_training_steps}")
|
495 |
+
print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
|
496 |
+
|
497 |
+
epoch_loss_points = []
|
498 |
+
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
|
499 |
+
|
500 |
+
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
|
501 |
+
steps_per_epoch = len(dataloader)
|
502 |
+
sample_interval = max(1, steps_per_epoch // sample_interval_share)
|
503 |
+
|
504 |
+
# Начинаем с указанной эпохи (полезно при возобновлении)
|
505 |
+
for epoch in range(start_epoch, start_epoch + num_epochs):
|
506 |
+
batch_losses = []
|
507 |
+
unet.train()
|
508 |
+
|
509 |
+
for step, (latents, embeddings) in enumerate(dataloader):
|
510 |
+
with accelerator.accumulate(unet):
|
511 |
+
if save_model == False and step == 3 :
|
512 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
513 |
+
print(f"Шаг {step}: {used_gb:.2f} GB")
|
514 |
+
# Forward pass
|
515 |
+
noise = torch.randn_like(latents)
|
516 |
+
|
517 |
+
timesteps = torch.randint(
|
518 |
+
1, # Начинаем с 1, не с 0
|
519 |
+
scheduler.config.num_train_timesteps,
|
520 |
+
(latents.shape[0],),
|
521 |
+
device=device
|
522 |
+
).long()
|
523 |
+
|
524 |
+
# Добавляем шум к латентам
|
525 |
+
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
|
526 |
+
|
527 |
+
# Получаем предсказание шума - кастим в bf16
|
528 |
+
noise_pred = unet(noisy_latents, timesteps, embeddings).sample.to(dtype=torch.bfloat16)
|
529 |
+
|
530 |
+
# Используем целевое значение v_prediction
|
531 |
+
target = scheduler.get_velocity(latents, noise, timesteps)
|
532 |
+
|
533 |
+
# Считаем лосс
|
534 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
|
535 |
+
|
536 |
+
# Делаем backward через Accelerator
|
537 |
+
accelerator.backward(loss)
|
538 |
+
|
539 |
+
# Увеличиваем счетчик глобальных шагов
|
540 |
+
global_step += 1
|
541 |
+
|
542 |
+
# Обновляем прогресс-бар
|
543 |
+
progress_bar.update(1)
|
544 |
+
|
545 |
+
# Логируем метрики
|
546 |
+
if accelerator.is_main_process:
|
547 |
+
current_lr = base_learning_rate
|
548 |
+
batch_losses.append(loss.detach().item())
|
549 |
+
|
550 |
+
# Логируем в Wandb
|
551 |
+
if use_wandb:
|
552 |
+
wandb.log({
|
553 |
+
"loss": loss.detach().item(),
|
554 |
+
"learning_rate": current_lr,
|
555 |
+
"epoch": epoch,
|
556 |
+
"global_step": global_step
|
557 |
+
})
|
558 |
+
|
559 |
+
# Генерируем сэмплы с заданным интервалом
|
560 |
+
if global_step % sample_interval == 0:
|
561 |
+
if save_model:
|
562 |
+
save_checkpoint(unet)
|
563 |
+
|
564 |
+
generate_and_save_samples(fixed_samples,global_step)
|
565 |
+
|
566 |
+
# Выводим текущий лосс
|
567 |
+
avg_loss = np.mean(batch_losses[-sample_interval:])
|
568 |
+
#print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
|
569 |
+
if use_wandb:
|
570 |
+
wandb.log({"intermediate_loss": avg_loss})
|
571 |
+
|
572 |
+
|
573 |
+
# По окончании эпохи
|
574 |
+
if accelerator.is_main_process:
|
575 |
+
avg_epoch_loss = np.mean(batch_losses)
|
576 |
+
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
|
577 |
+
if use_wandb:
|
578 |
+
wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
|
579 |
+
|
580 |
+
# Завершение обучения - сохраняем финальную модель
|
581 |
+
if accelerator.is_main_process:
|
582 |
+
print("Обучение завершено! Сохраняем финальную модель...")
|
583 |
+
# Сохраняем основную модель
|
584 |
+
#if save_model:
|
585 |
+
save_checkpoint(unet)
|
586 |
+
print("Готово!")
|