recoilme commited on
Commit
1442092
·
1 Parent(s): 10fc91c
samples/sdxxxs_448x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 820a4db8cdfd6217d613cd1f91d8d7a10d4a63dffa5c8742320f1526cdb55c1c
  • Pointer size: 130 Bytes
  • Size of remote file: 55 kB

Git LFS Details

  • SHA256: 0475c14f8cadd3387edfcf15526bc28b9c09e54c070eae5789e5f54417d35b0a
  • Pointer size: 130 Bytes
  • Size of remote file: 78.3 kB
samples/sdxxxs_512x768_0.jpg CHANGED

Git LFS Details

  • SHA256: f04ee9011311e4076821f93ad8c3e488f7188f2bbe407175e935dcf8f406571b
  • Pointer size: 130 Bytes
  • Size of remote file: 65 kB

Git LFS Details

  • SHA256: de74b9b003e7ae8c452782877f5d9e7bd23e8f8f2ef24683eff1cb4a0cfe086e
  • Pointer size: 130 Bytes
  • Size of remote file: 63.6 kB
samples/sdxxxs_576x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 863a3d95e09f75a703bf6c5753da35075f244ebd60a2c942756a35901cf67ff9
  • Pointer size: 130 Bytes
  • Size of remote file: 64.8 kB

Git LFS Details

  • SHA256: 4aaa6d4b2f5c63d4201ed77190a4898786677845c9e71b0e5221a7910d52f5b7
  • Pointer size: 130 Bytes
  • Size of remote file: 62.6 kB
samples/sdxxxs_640x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 98cad135fba01de43d7115a537c685fa53bbb653caac6386228987180066c951
  • Pointer size: 130 Bytes
  • Size of remote file: 63.2 kB

Git LFS Details

  • SHA256: 63fa635d9e802cf9394d0e82e56dd51c5db1dc25b92bc6cb5ca28706a5a9b1e6
  • Pointer size: 130 Bytes
  • Size of remote file: 61 kB
samples/sdxxxs_704x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 48a6090bf545811657bda0717d1b10ac8603c2dd4f353f5b168774cf67e09fdc
  • Pointer size: 130 Bytes
  • Size of remote file: 52.6 kB

Git LFS Details

  • SHA256: 126cbbc4ab7562e0dc731ead503554c28a96c3e15a13b2389599e98bce8599cd
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
samples/sdxxxs_768x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 2077ed7e842dea644df53a631970d40a9881a9001170a470efa47f907cf913e5
  • Pointer size: 130 Bytes
  • Size of remote file: 44.1 kB

Git LFS Details

  • SHA256: 28e4c716f345066d1bd4b838c83d09b2b3a1ed4d5fb10908b5a43cb8f7aed0a5
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
samples/sdxxxs_768x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 6dfc8cd62089e54604867e7d412701d43a59ff5d8590eab43544ab220d8bdb68
  • Pointer size: 130 Bytes
  • Size of remote file: 72.9 kB

Git LFS Details

  • SHA256: 451e366f93d70ac091a84334a6ebcd7df617c9f23074f1fe53058c96d6606d41
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
samples/sdxxxs_768x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 7b94b1260ca5ac4ddb17e360e43ceec9ed97a8bc9055206e2145abb6d3dadde7
  • Pointer size: 130 Bytes
  • Size of remote file: 68.8 kB

Git LFS Details

  • SHA256: 7ffca6be049284b0098f94b103a3f1a43a4f28888a9b898dd32ab019b912390e
  • Pointer size: 130 Bytes
  • Size of remote file: 96 kB
samples/sdxxxs_768x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 8528ffcb3d4e4eee7619456e5f85606f8e9a05e867aa53a6f253be6617311633
  • Pointer size: 130 Bytes
  • Size of remote file: 84.2 kB

Git LFS Details

  • SHA256: c67d01bc874c1e084021ece4e7db1841bb5f0d30847113438d7cfad7756d48f7
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/sdxxxs_768x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 1e016376d41dbee394dc5941df6d59c872e170c840cc2594ef4d5969abc222c4
  • Pointer size: 130 Bytes
  • Size of remote file: 74.1 kB

Git LFS Details

  • SHA256: 7dd4a83c5ae09cd9f94f0284e01f3b3931e7e7154e8085967551d6707920848d
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
samples/sdxxxs_768x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 358de396775de87602706995e83ede7701fa639751cf635d9cfdf9f25edc84a4
  • Pointer size: 130 Bytes
  • Size of remote file: 78.4 kB

Git LFS Details

  • SHA256: 91fb6f18a34a35b307057ddd356fdeddaa0f8bfcb39d57b29cf0df7402f723dc
  • Pointer size: 130 Bytes
  • Size of remote file: 30.2 kB
sdxxxs/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e24077b8bda68b40af6c7f9b1dc6ba3b014304ae89d3b37edba979a215f075f0
3
- size 1968
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a97b0612232047e9d605d3c99a50331a9f99d32692659e688ee7fd6193e3cec
3
+ size 1953
sdxxxs/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:78c7e27f075695a183dd1499b26ff1873e68272e6f0e94b432e69d52bb5bb031
3
- size 8273806320
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e867b3605ee4f67d844f948904175c4201be69d432f0f1be766a255d1c18abc5
3
+ size 8273810624
src/sdxs_create.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f618b69e4b5f1c912d01130da67d577f873928d160488531c0529eef17818c12
3
- size 9790
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc6987bbccaaedf529bc2d6b55bae0f28e5161fe96dcd48970aa92fb92a55664
3
+ size 9793
src/sdxs_sdxxs_transfer.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f140659eb65f1541979432e5d977aeae6a9ade394985971e63b2337765895916
3
- size 235007
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7c11b012028e6ed060f3068d08adeb229bdad4b6b8cd2b4e0d50346cafcf02f
3
+ size 235161
train.py_ → train.old.py RENAMED
File without changes
train.py CHANGED
@@ -23,13 +23,13 @@ import bitsandbytes as bnb
23
 
24
  # --------------------------- Параметры ---------------------------
25
  save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
- batch_size = 30 #26 #45 #11 #45 #555 #35 #7
27
  base_learning_rate = 4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
  min_learning_rate = 2.5e-5 #2e-5
29
  num_epochs = 1 #2 #36 #18
30
- project = "sdxs"
31
- use_wandb = True
32
- save_model = True
33
  limit = 0 #200000 #0
34
  checkpoints_folder = ""
35
 
@@ -37,7 +37,7 @@ checkpoints_folder = ""
37
  n_diffusion_steps = 40
38
  samples_to_generate = 12
39
  guidance_scale = 5
40
- sample_interval_share = 25 # samples/save per epoch
41
 
42
  # Папки для сохранения результатов
43
  generated_folder = "samples"
@@ -214,7 +214,7 @@ scheduler = DDPMScheduler(
214
  prediction_type="v_prediction", # V-Prediction
215
  rescale_betas_zero_snr=True, # Включение Zero-SNR
216
  timestep_spacing="leading", # Добавляем улучшенное распределение шагов
217
- steps_offset=1 # Избегаем проблем с нулевым timestep
218
  )
219
 
220
  # Инициализация переменных для возобновления обучения
 
23
 
24
  # --------------------------- Параметры ---------------------------
25
  save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
+ batch_size = 50 #26 #45 #11 #45 #555 #35 #7
27
  base_learning_rate = 4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
  min_learning_rate = 2.5e-5 #2e-5
29
  num_epochs = 1 #2 #36 #18
30
+ project = "sdxxxs"
31
+ use_wandb = False
32
+ save_model = False
33
  limit = 0 #200000 #0
34
  checkpoints_folder = ""
35
 
 
37
  n_diffusion_steps = 40
38
  samples_to_generate = 12
39
  guidance_scale = 5
40
+ sample_interval_share = 50 # samples/save per epoch
41
 
42
  # Папки для сохранения результатов
43
  generated_folder = "samples"
 
214
  prediction_type="v_prediction", # V-Prediction
215
  rescale_betas_zero_snr=True, # Включение Zero-SNR
216
  timestep_spacing="leading", # Добавляем улучшенное распределение шагов
217
+ #steps_offset=1 # Избегаем проблем с нулевым timestep
218
  )
219
 
220
  # Инициализация переменных для возобновления обучения
train_flow.py CHANGED
@@ -222,8 +222,13 @@ base_learning_rate = 4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-
222
  min_learning_rate = 2.5e-5 #2e-5
223
  num_epochs = 1 #2 #36 #18
224
  project = "sdxs"
 
225
  use_wandb = True
226
  save_model = True
 
 
 
 
227
  limit = 0 #200000 #0
228
  checkpoints_folder = ""
229
 
@@ -414,7 +419,11 @@ vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dt
414
 
415
  # Flow Matching
416
  scheduler = FlowMatchingEulerScheduler(
 
417
  num_train_timesteps=1000,
 
 
 
418
  )
419
 
420
  # Инициализация переменных для возобновления обучения
 
222
  min_learning_rate = 2.5e-5 #2e-5
223
  num_epochs = 1 #2 #36 #18
224
  project = "sdxs"
225
+ <<<<<<< HEAD
226
  use_wandb = True
227
  save_model = True
228
+ =======
229
+ use_wandb = False
230
+ save_model = False
231
+ >>>>>>> d0c94e4 (sdxxxs)
232
  limit = 0 #200000 #0
233
  checkpoints_folder = ""
234
 
 
419
 
420
  # Flow Matching
421
  scheduler = FlowMatchingEulerScheduler(
422
+ <<<<<<< HEAD
423
  num_train_timesteps=1000,
424
+ =======
425
+ # num_train_timesteps=1000,
426
+ >>>>>>> d0c94e4 (sdxxxs)
427
  )
428
 
429
  # Инициализация переменных для возобновления обучения
train_lora.py DELETED
@@ -1,558 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from collections import defaultdict
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
10
- from accelerate import Accelerator
11
- from datasets import load_from_disk
12
- from tqdm import tqdm
13
- from PIL import Image,ImageOps
14
- import wandb
15
- import random
16
- import gc
17
- from accelerate.state import DistributedType
18
- from torch.distributed import broadcast_object_list
19
- from torch.utils.checkpoint import checkpoint
20
- from diffusers.models.attention_processor import AttnProcessor2_0
21
- from datetime import datetime
22
- import bitsandbytes as bnb
23
-
24
- # --------------------------- Параметры ---------------------------
25
- save_path = "/home/recoilme/nusha2_768" #"datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
- batch_size = 1 #45 #11 #45 #555 #35 #7
27
- base_learning_rate = 9e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
- min_learning_rate = 2.5e-5 #2e-5
29
- num_epochs = 20 #2 #36 #18
30
- project = "sdxs"
31
- use_wandb = True
32
- save_model = True
33
- limit = 0 #200000 #0
34
- checkpoints_folder = ""
35
-
36
- # Параметры для диффузии
37
- n_diffusion_steps = 40
38
- samples_to_generate = 12
39
- guidance_scale = 5
40
- sample_interval_share = 1#20 # samples per epoch
41
-
42
- # Папки для сохранения результатов
43
- generated_folder = "samples"
44
- os.makedirs(generated_folder, exist_ok=True)
45
-
46
- # Настройка seed для воспроизводимости
47
- current_date = datetime.now()
48
- seed = int(current_date.strftime("%Y%m%d"))
49
- #torch.manual_seed(seed)
50
- #np.random.seed(seed)
51
- #random.seed(seed)
52
- #if torch.cuda.is_available():
53
- # torch.cuda.manual_seed_all(seed)
54
-
55
- # --------------------------- Параметры LoRA ---------------------------
56
- lora_name = "nusha" # Имя для сохранения/загрузки LoRA адаптеров
57
- lora_rank = 8 # Ранг LoRA (чем меньше, тем компактнее модель)
58
- lora_alpha = 8 # Альфа параметр LoRA, определяющий масштаб
59
-
60
- print("init")
61
- # Включение Flash Attention 2/SDPA
62
- torch.backends.cuda.enable_flash_sdp(True)
63
- # --------------------------- Инициализация Accelerator --------------------
64
- dtype = torch.bfloat16
65
- accelerator = Accelerator(mixed_precision="bf16")
66
- device = accelerator.device
67
- gen = torch.Generator(device=device)
68
- gen.manual_seed(seed)
69
-
70
- # --------------------------- Инициализация WandB ---------------------------
71
- if use_wandb and accelerator.is_main_process:
72
- wandb.init(project=project+lora_name, config={
73
- "batch_size": batch_size,
74
- "base_learning_rate": base_learning_rate,
75
- "num_epochs": num_epochs,
76
- "n_diffusion_steps": n_diffusion_steps,
77
- "samples_to_generate": samples_to_generate,
78
- "dtype": str(dtype)
79
- })
80
-
81
- # --------------------------- Загрузка датасета ---------------------------
82
- class ResolutionBatchSampler(Sampler):
83
- """Сэмплер, который группирует примеры по одинаковым размерам"""
84
- def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
85
- self.dataset = dataset
86
- self.batch_size = batch_size
87
- self.shuffle = shuffle
88
- self.drop_last = drop_last
89
-
90
- # Группируем примеры по размерам
91
- self.size_groups = defaultdict(list)
92
-
93
- try:
94
- widths = dataset["width"]
95
- heights = dataset["height"]
96
- except KeyError:
97
- widths = [0] * len(dataset)
98
- heights = [0] * len(dataset)
99
-
100
- for i, (w, h) in enumerate(zip(widths, heights)):
101
- size = (w, h)
102
- self.size_groups[size].append(i)
103
-
104
- # Печатаем статистику по размерам
105
- print(f"Найдено {len(self.size_groups)} уникальных размеров:")
106
- for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
107
- width, height = size
108
- print(f" {width}x{height}: {len(indices)} примеров")
109
-
110
- # Формируем батчи
111
- self.reset()
112
-
113
- def reset(self):
114
- """Сбрасывает и перемешивает индексы"""
115
- self.batches = []
116
-
117
- for size, indices in self.size_groups.items():
118
- if self.shuffle:
119
- indices_copy = indices.copy()
120
- random.shuffle(indices_copy)
121
- else:
122
- indices_copy = indices
123
-
124
- # Разбиваем на батчи
125
- for i in range(0, len(indices_copy), self.batch_size):
126
- batch_indices = indices_copy[i:i + self.batch_size]
127
-
128
- # Пропускаем неполные батчи если drop_last=True
129
- if self.drop_last and len(batch_indices) < self.batch_size:
130
- continue
131
-
132
- self.batches.append(batch_indices)
133
-
134
- # Перемешиваем батчи между собой
135
- if self.shuffle:
136
- random.shuffle(self.batches)
137
-
138
- def __iter__(self):
139
- self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи
140
- return iter(self.batches)
141
-
142
- def __len__(self):
143
- return len(self.batches)
144
-
145
- # Функция для выборки фиксированных семплов по размерам
146
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
147
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
148
- # Группируем по размерам
149
- size_groups = defaultdict(list)
150
- try:
151
- widths = dataset["width"]
152
- heights = dataset["height"]
153
- except KeyError:
154
- widths = [0] * len(dataset)
155
- heights = [0] * len(dataset)
156
- for i, (w, h) in enumerate(zip(widths, heights)):
157
- size = (w, h)
158
- size_groups[size].append(i)
159
-
160
- # Выбираем фиксированные примеры из каждой группы
161
- fixed_samples = {}
162
- for size, indices in size_groups.items():
163
- # Определяем сколько семплов брать из этой группы
164
- n_samples = min(samples_per_group, len(indices))
165
- if len(size_groups)==1:
166
- n_samples = samples_to_generate
167
- if n_samples == 0:
168
- continue
169
-
170
- # Выбираем случайные индексы
171
- sample_indices = random.sample(indices, n_samples)
172
- samples_data = [dataset[idx] for idx in sample_indices]
173
-
174
- # Собираем данные
175
- latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device)
176
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device)
177
- texts = [item["text"] for item in samples_data]
178
-
179
- # Сохраняем для этого размера
180
- fixed_samples[size] = (latents, embeddings, texts)
181
-
182
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
183
- return fixed_samples
184
-
185
- if limit > 0:
186
- dataset = load_from_disk(save_path).select(range(limit))
187
- else:
188
- dataset = load_from_disk(save_path)
189
-
190
-
191
- def collate_fn(batch):
192
- # Преобразуем список в тензоры и перемещаем на девайс
193
- latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
194
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
195
- return latents, embeddings
196
-
197
- # Используем наш ResolutionBatchSampler
198
- batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
199
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
200
-
201
- print("Total samples",len(dataloader))
202
- dataloader = accelerator.prepare(dataloader)
203
-
204
- # --------------------------- Загрузка моделей ---------------------------
205
- # VAE загружается на CPU для экономии GPU-памяти
206
- vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
207
-
208
- # DDPMScheduler с V_Prediction и Zero-SNR
209
- scheduler = DDPMScheduler(
210
- num_train_timesteps=1000, # Полный график шагов для обучения
211
- prediction_type="v_prediction", # V-Prediction
212
- rescale_betas_zero_snr=True, # Включение Zero-SNR
213
- timestep_spacing="leading", # Добавляем улучшенное распределение шагов
214
- steps_offset=1 # Избегаем проблем с нулевым timestep
215
- )
216
-
217
- # Инициализация переменных для возобновления обучения
218
- start_epoch = 0
219
- global_step = 0
220
-
221
- # Расчёт общего количества шагов
222
- total_training_steps = (len(dataloader) * num_epochs)
223
- # Get the world size
224
- world_size = accelerator.state.num_processes
225
- print(f"World Size: {world_size}")
226
-
227
- # Опция загрузки модели из последнего чекпоинта (если существует)
228
- latest_checkpoint = os.path.join(checkpoints_folder, project)
229
- if os.path.isdir(latest_checkpoint):
230
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
231
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
232
- unet.enable_gradient_checkpointing()
233
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
234
- try:
235
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
236
- print("SDPA включен через set_attn_processor.")
237
- except Exception as e:
238
- print(f"Ошибка при включении SDPA: {e}")
239
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
240
- unet.set_use_memory_efficient_attention_xformers(True)
241
-
242
- if lora_name:
243
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
244
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
245
- from peft.tuners.lora import LoraModel
246
- import os
247
- # 1. Замораживаем все параметры UNet
248
- unet.requires_grad_(False)
249
- print("Параметры базового UNet заморожены.")
250
-
251
- # 2. Создаем конфигурацию LoRA
252
- lora_config = LoraConfig(
253
- r=lora_rank,
254
- lora_alpha=lora_alpha,
255
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
256
- lora_dropout=0.0,
257
- bias="none",
258
- task_type="FEATURE_EXTRACTION", # ✅ VALID
259
- )
260
-
261
- # 3. Оборачиваем UNet в PEFT-модель
262
- from peft import get_peft_model
263
-
264
- peft_unet = get_peft_model(unet, lora_config)
265
-
266
- # 4. Получаем параметры для оптимизации
267
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
268
-
269
- # 5. Выводим информацию о количестве параметров
270
- if accelerator.is_main_process:
271
- lora_params_count = sum(p.numel() for p in params_to_optimize)
272
- total_params_count = sum(p.numel() for p in unet.parameters())
273
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
274
- print(f"Общее количество параметров UNet: {total_params_count:,}")
275
- if total_params_count > 0:
276
- print(f"Доля обучаемых параметров (LoRA): {lora_params_count / total_params_count:.4%}")
277
-
278
- # 6. Путь для сохранения
279
- lora_save_path = os.path.join("lora", lora_name)
280
- os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
281
-
282
- # 7. Функция для сохранения
283
- def save_lora_checkpoint(model):
284
- if accelerator.is_main_process:
285
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
286
- from peft.utils.save_and_load import get_peft_model_state_dict
287
- # Получаем state_dict только LoRA
288
- lora_state_dict = get_peft_model_state_dict(model)
289
-
290
- # Сохраняем веса
291
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
292
-
293
- # Сохраняем конфиг
294
- model.peft_config["default"].save_pretrained(lora_save_path)
295
-
296
- # --------------------------- Оптимизатор ---------------------------
297
- # Определяем параметры для оптимизации
298
- if lora_name:
299
- # Если используется LoRA, оптимизируем только параметры LoRA
300
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
301
- print(f"Обучаем только {len(trainable_params)} LoRA параметров вместо полного UNet")
302
- else:
303
- # Иначе оптимизируем все параметры
304
- trainable_params = list(unet.parameters())
305
-
306
- # [1] Создаем словарь оптимизаторов (fused backward)
307
- optimizer_dict = {
308
- p: bnb.optim.AdamW8bit(
309
- [p], # Каждый параметр получает свой оптимизатор
310
- lr=base_learning_rate,
311
- betas=(0.9, 0.999),
312
- weight_decay=1e-5,
313
- eps=1e-8
314
- ) for p in trainable_params
315
- }
316
-
317
- # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
318
- def optimizer_hook(param):
319
- optimizer_dict[param].step()
320
- optimizer_dict[param].zero_grad(set_to_none=True)
321
-
322
- # [3] Регистрируем hook для trainable параметров модели
323
- for param in trainable_params:
324
- param.register_post_accumulate_grad_hook(optimizer_hook)
325
-
326
- # Подготовка через Accelerator
327
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
328
-
329
- # --------------------------- Фиксированные семплы для генерации ---------------------------
330
- # Примеры фиксированных семплов по размерам
331
- fixed_samples = get_fixed_samples_by_resolution(dataset)
332
-
333
-
334
- @torch.no_grad()
335
- def generate_and_save_samples(fixed_samples,step):
336
- """
337
- Генерирует семплы для каждого из разрешений и сохраняет их.
338
-
339
- Args:
340
- step: Текущий шаг обучения
341
- fixed_samples: Словарь, где ключи - размеры (width, height),
342
- а значения - кортежи (latents, embeddings)
343
- """
344
- try:
345
- original_model = accelerator.unwrap_model(unet)
346
- # Перемещаем VAE на device для семплирования
347
- vae.to(accelerator.device, dtype=dtype)
348
-
349
- # Устанавливаем количество diffusion шагов
350
- scheduler.set_timesteps(n_diffusion_steps)
351
-
352
- all_generated_images = []
353
- size_info = [] # Для хранения информации о размере для каждого изображения
354
- all_captions = []
355
-
356
- # Проходим по всем группам размеров
357
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items():
358
- width, height = size
359
- size_info.append(f"{width}x{height}")
360
- #print(f"Генерация {sample_latents.shape[0]} изображений размером {width}x{height}")
361
-
362
- # Инициализируем латенты случайным шумом для этой группы
363
- noise = torch.randn(
364
- sample_latents.shape,
365
- generator=gen,
366
- device=sample_latents.device,
367
- dtype=sample_latents.dtype
368
- )
369
-
370
- # Начинаем с шума
371
- current_latents = noise.clone()
372
-
373
- # Подготовка текстовых эмбеддингов для guidance
374
- if guidance_scale > 0:
375
- empty_embeddings = torch.zeros_like(sample_text_embeddings)
376
- text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
377
- else:
378
- text_embeddings = sample_text_embeddings
379
-
380
- # Генерация изображений
381
- for t in scheduler.timesteps:
382
- # Подготовка входных данных для UNet
383
- if guidance_scale > 0:
384
- latent_model_input = torch.cat([current_latents] * 2)
385
- latent_model_input = scheduler.scale_model_input(latent_model_input, t)
386
- else:
387
- latent_model_input = scheduler.scale_model_input(current_latents, t)
388
-
389
- # Предсказание шума
390
- noise_pred = original_model(latent_model_input, t, text_embeddings).sample
391
-
392
- # Применение guidance scale
393
- if guidance_scale > 0:
394
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
395
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
396
-
397
- # Обновление латентов
398
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
399
-
400
- # Декодирование через VAE
401
- latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
402
- latent = latent.to(accelerator.device, dtype=dtype)
403
- decoded = vae.decode(latent).sample
404
-
405
- # Преобразуем тензоры в PIL-изображения и сохраняем
406
- for img_idx, img_tensor in enumerate(decoded):
407
- img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
408
- pil_img = Image.fromarray((img * 255).astype("uint8"))
409
- # Определяем максимальные ширину и высоту
410
- max_width = max(size[0] for size in fixed_samples.keys())
411
- max_height = max(size[1] for size in fixed_samples.keys())
412
- max_width = max(255,max_width)
413
- max_height = max(255,max_height)
414
-
415
- # Добавляем padding, чтобы изображение стало размером max_width x max_height
416
- padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
417
-
418
- all_generated_images.append(padded_img)
419
-
420
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
421
- all_captions.append(caption_text)
422
-
423
- # Сохраняем с информацией о размере в имени файла
424
- save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
425
- pil_img.save(save_path, "JPEG", quality=96)
426
-
427
- # Отправляем изображения на WandB с информацией о размере
428
- if use_wandb and accelerator.is_main_process:
429
- wandb_images = [
430
- wandb.Image(img, caption=f"{all_captions[i]}")
431
- for i, img in enumerate(all_generated_images)
432
- ]
433
- wandb.log({"generated_images": wandb_images, "global_step": step})
434
-
435
- finally:
436
- # Гарантированное перемещение VAE обратно на CPU
437
- vae.to("cpu")
438
- if original_model is not None:
439
- del original_model
440
- # Очистка всех тензоров
441
- for var in list(locals().keys()):
442
- if isinstance(locals()[var], torch.Tensor):
443
- del locals()[var]
444
- torch.cuda.empty_cache()
445
- gc.collect()
446
-
447
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
448
- if accelerator.is_main_process:
449
- if save_model:
450
- print("Генерация сэмплов до старта обучения...")
451
- generate_and_save_samples(fixed_samples,0)
452
-
453
- # Модифицируем функцию сохранения модели для поддержки LoRA
454
- def save_checkpoint(unet):
455
- if accelerator.is_main_process:
456
- if lora_name:
457
- # Сохраняем только LoRA адаптеры
458
- save_lora_checkpoint(unet)
459
- else:
460
- # Сохраняем полную модель
461
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
462
-
463
- # --------------------------- Тренировочный цикл ---------------------------
464
- # Для логирования среднего лосса каждые % эпохи
465
- if accelerator.is_main_process:
466
- print(f"Total steps per GPU: {total_training_steps}")
467
- print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
468
-
469
- epoch_loss_points = []
470
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
471
-
472
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
473
- steps_per_epoch = len(dataloader)
474
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
475
-
476
- # Начинаем с указанной эпохи (полезно при возобновлении)
477
- for epoch in range(start_epoch, start_epoch + num_epochs):
478
- batch_losses = []
479
- unet.train()
480
-
481
- for step, (latents, embeddings) in enumerate(dataloader):
482
- with accelerator.accumulate(unet):
483
- if save_model == False and step == 3 :
484
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
485
- print(f"Шаг {step}: {used_gb:.2f} GB")
486
- # Forward pass
487
- noise = torch.randn_like(latents)
488
-
489
- timesteps = torch.randint(
490
- 1, # Начинаем с 1, не с 0
491
- scheduler.config.num_train_timesteps,
492
- (latents.shape[0],),
493
- device=device
494
- ).long()
495
-
496
- # Добавляем шум к латентам
497
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
498
-
499
- # Получаем предсказание шума - кастим в bf16
500
- noise_pred = unet(noisy_latents, timesteps, embeddings).sample.to(dtype=torch.bfloat16)
501
-
502
- # Используем целевое значение v_prediction
503
- target = scheduler.get_velocity(latents, noise, timesteps)
504
-
505
- # Считаем лосс
506
- loss = torch.nn.functional.mse_loss(noise_pred, target)
507
-
508
- # Делаем backward через Accelerator
509
- accelerator.backward(loss)
510
-
511
- # Увеличиваем счетчик глобальных шагов
512
- global_step += 1
513
-
514
- # Обновляем прогресс-бар
515
- progress_bar.update(1)
516
-
517
- # Логируем метрики
518
- if accelerator.is_main_process:
519
- current_lr = base_learning_rate
520
- batch_losses.append(loss.detach().item())
521
-
522
- # Логируем в Wandb
523
- if use_wandb:
524
- wandb.log({
525
- "loss": loss.detach().item(),
526
- "learning_rate": current_lr,
527
- "epoch": epoch,
528
- "global_step": global_step
529
- })
530
-
531
- # Генерируем сэмплы с заданным интервалом
532
- if global_step % sample_interval == 0:
533
- if save_model:
534
- save_checkpoint(unet)
535
-
536
- generate_and_save_samples(fixed_samples,global_step)
537
-
538
- # Выводим текущий лосс
539
- avg_loss = np.mean(batch_losses[-sample_interval:])
540
- #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
541
- if use_wandb:
542
- wandb.log({"intermediate_loss": avg_loss})
543
-
544
-
545
- # По окончании эпохи
546
- if accelerator.is_main_process:
547
- avg_epoch_loss = np.mean(batch_losses)
548
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
549
- if use_wandb:
550
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
551
-
552
- # Завершение обучения - сохраняем финальную модель
553
- if accelerator.is_main_process:
554
- print("Обучение завершено! Сохраняем финальную модель...")
555
- # Сохраняем основную модель
556
- if save_model:
557
- save_checkpoint(unet)
558
- print("Готово!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_sdxxxs.py CHANGED
@@ -23,10 +23,10 @@ import bitsandbytes as bnb
23
 
24
  # --------------------------- Параметры ---------------------------
25
  save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
- batch_size = 70 #30 #26 #45 #11 #45 #555 #35 #7
27
- base_learning_rate = 4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
  min_learning_rate = 2.5e-5 #2e-5
29
- num_epochs = 10 #2 #36 #18
30
  project = "sdxxxs"
31
  use_wandb = True
32
  save_model = True
@@ -37,7 +37,7 @@ checkpoints_folder = ""
37
  n_diffusion_steps = 40
38
  samples_to_generate = 12
39
  guidance_scale = 5
40
- sample_interval_share = 70 # samples/save per epoch
41
 
42
  # Папки для сохранения результатов
43
  generated_folder = "samples"
@@ -214,7 +214,7 @@ scheduler = DDPMScheduler(
214
  prediction_type="v_prediction", # V-Prediction
215
  rescale_betas_zero_snr=True, # Включение Zero-SNR
216
  timestep_spacing="leading", # Добавляем улучшенное распределение шагов
217
- steps_offset=1 # Избегаем проблем с нулевым timestep
218
  )
219
 
220
  # Инициализация переменных для возобновления обучения
 
23
 
24
  # --------------------------- Параметры ---------------------------
25
  save_path = "datasets/768" # "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
26
+ batch_size = 45 #30 #26 #45 #11 #45 #555 #35 #7
27
+ base_learning_rate = 2.5e-5 #4e-6 #2e-5 #4e-6 #9.5e-7 #9e-7 #2e-6 #1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
28
  min_learning_rate = 2.5e-5 #2e-5
29
+ num_epochs = 1 #2 #36 #18
30
  project = "sdxxxs"
31
  use_wandb = True
32
  save_model = True
 
37
  n_diffusion_steps = 40
38
  samples_to_generate = 12
39
  guidance_scale = 5
40
+ sample_interval_share = 20 # samples/save per epoch
41
 
42
  # Папки для сохранения результатов
43
  generated_folder = "samples"
 
214
  prediction_type="v_prediction", # V-Prediction
215
  rescale_betas_zero_snr=True, # Включение Zero-SNR
216
  timestep_spacing="leading", # Добавляем улучшенное распределение шагов
217
+ #steps_offset=1 # Избегаем проблем с нулевым timestep
218
  )
219
 
220
  # Инициализация переменных для возобновления обучения