576ep10
- samples/sdxs_320x576_0.jpg +2 -2
- samples/sdxs_384x576_0.jpg +2 -2
- samples/sdxs_448x576_0.jpg +2 -2
- samples/sdxs_512x576_0.jpg +2 -2
- samples/sdxs_576x320_0.jpg +2 -2
- samples/sdxs_576x384_0.jpg +2 -2
- samples/sdxs_576x448_0.jpg +2 -2
- samples/sdxs_576x512_0.jpg +2 -2
- samples/sdxs_576x576_0.jpg +2 -2
- scheduler/scheduler_config.json +3 -19
- sdxs/diffusion_pytorch_model.safetensors +1 -1
- text_encoder/config.json +3 -28
- text_projector/config.json +3 -1
- tokenizer/special_tokens_map.json +3 -51
- tokenizer/tokenizer_config.json +3 -55
- train.py_ +6 -6
- train_lora.py +489 -0
- unet/config.json +3 -78
- unet/diffusion_pytorch_model.fp16.safetensors +1 -1
- vae/config.json +3 -38
samples/sdxs_320x576_0.jpg
CHANGED
samples/sdxs_384x576_0.jpg
CHANGED
samples/sdxs_448x576_0.jpg
CHANGED
samples/sdxs_512x576_0.jpg
CHANGED
samples/sdxs_576x320_0.jpg
CHANGED
samples/sdxs_576x384_0.jpg
CHANGED
samples/sdxs_576x448_0.jpg
CHANGED
samples/sdxs_576x512_0.jpg
CHANGED
samples/sdxs_576x576_0.jpg
CHANGED
scheduler/scheduler_config.json
CHANGED
@@ -1,19 +1,3 @@
-{
-  "_class_name": "DDPMScheduler",
-  "_diffusers_version": "…",
-  "beta_end": 0.02,
-  "beta_schedule": "linear",
-  "beta_start": 0.0001,
-  "clip_sample": true,
-  "clip_sample_range": 1.0,
-  "dynamic_thresholding_ratio": 0.995,
-  "num_train_timesteps": 1000,
-  "prediction_type": "v_prediction",
-  "rescale_betas_zero_snr": true,
-  "sample_max_value": 1.0,
-  "steps_offset": 1,
-  "thresholding": false,
-  "timestep_spacing": "leading",
-  "trained_betas": null,
-  "variance_type": "fixed_small"
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:b855b6a6febe6a49dddb616e8f2445ad87530066b19520b0bd9abb75e3312a94
+size 506
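The settings removed above are the same ones the new train_lora.py builds in code; a minimal sketch, assuming the diffusers DDPMScheduler API:

    from diffusers import DDPMScheduler

    # Same settings as the removed scheduler_config.json: linear betas,
    # v-prediction, Zero-SNR rescaling, "leading" timestep spacing.
    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="v_prediction",
        rescale_betas_zero_snr=True,
        timestep_spacing="leading",
        steps_offset=1,
    )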
sdxs/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:4bf4e3f25670722ebd75f6238db3c041d1933ed5611edc5a0795ea0f4674958e
 size 4529095968
text_encoder/config.json
CHANGED
@@ -1,28 +1,3 @@
-{
-  "_name_or_path": "…",
-  "architectures": [
-    "XLMRobertaModel"
-  ],
-  "attention_probs_dropout_prob": 0.1,
-  "bos_token_id": 0,
-  "classifier_dropout": null,
-  "eos_token_id": 2,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 1024,
-  "initializer_range": 0.02,
-  "intermediate_size": 4096,
-  "layer_norm_eps": 1e-05,
-  "max_position_embeddings": 514,
-  "model_type": "xlm-roberta",
-  "num_attention_heads": 16,
-  "num_hidden_layers": 24,
-  "output_past": true,
-  "pad_token_id": 1,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float16",
-  "transformers_version": "4.48.3",
-  "type_vocab_size": 1,
-  "use_cache": true,
-  "vocab_size": 250002
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:87131a858ee394af6afae023f733cdebc36eda2ccbed27c36bc887cfae427392
+size 721
text_projector/config.json
CHANGED
@@ -1 +1,3 @@
-…
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae2f211593cd2cc736bf8617bcb0a5e6abd4db0265170de82ae03b7a6664feda
+size 83
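The projector config itself is now behind an LFS pointer, but the surrounding configs hint at its role: the text encoder above reports hidden_size 1024 while the UNet below expects cross_attention_dim 1152. A hypothetical sketch of such a bridge (the module shape is an assumption, not the repo's actual definition):

    import torch

    # Assumed: a linear projection from XLM-RoBERTa's 1024-dim hidden states
    # to the UNet's 1152-dim cross-attention space.
    text_projector = torch.nn.Linear(1024, 1152)
    tokens = torch.randn(1, 77, 1024)    # dummy encoder output
    print(text_projector(tokens).shape)  # torch.Size([1, 77, 1152])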
tokenizer/special_tokens_map.json
CHANGED
@@ -1,51 +1,3 @@
-{
-  "bos_token": {
-    "content": "<s>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "cls_token": {
-    "content": "<s>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "eos_token": {
-    "content": "</s>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "mask_token": {
-    "content": "<mask>",
-    "lstrip": true,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "pad_token": {
-    "content": "<pad>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "sep_token": {
-    "content": "</s>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "unk_token": {
-    "content": "<unk>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  }
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c785abebea9ae3257b61681b4e6fd8365ceafde980c21970d001e834cf10835
+size 964
tokenizer/tokenizer_config.json
CHANGED
@@ -1,55 +1,3 @@
-{
-  "added_tokens_decoder": {
-    "0": {
-      "content": "<s>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "1": {
-      "content": "<pad>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "2": {
-      "content": "</s>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "3": {
-      "content": "<unk>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "250001": {
-      "content": "<mask>",
-      "lstrip": true,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    }
-  },
-  "bos_token": "<s>",
-  "clean_up_tokenization_spaces": false,
-  "cls_token": "<s>",
-  "eos_token": "</s>",
-  "extra_special_tokens": {},
-  "mask_token": "<mask>",
-  "model_max_length": 512,
-  "pad_token": "<pad>",
-  "sep_token": "</s>",
-  "tokenizer_class": "XLMRobertaTokenizer",
-  "unk_token": "<unk>"
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccf223ba3d5b3cc7fa6c3bf451f3bb40557a5c92b0aa33f63d17802ff1a96fd9
+size 1178
train.py_
CHANGED
@@ -23,7 +23,7 @@ from datetime import datetime
 # --------------------------- Parameters ---------------------------
 save_path = "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
 batch_size = 45 #11 #45 #555 #35 #7
-base_learning_rate = 1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
+base_learning_rate = 1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
 min_learning_rate = 2.5e-5 #2e-5
 num_epochs = 4 #2 #36 #18
 project = "sdxs"
@@ -48,11 +48,11 @@ os.makedirs(generated_folder, exist_ok=True)
 # Seed setup for reproducibility
 current_date = datetime.now()
 seed = int(current_date.strftime("%Y%m%d"))
-torch.manual_seed(seed)
-np.random.seed(seed)
-random.seed(seed)
-if torch.cuda.is_available():
-    torch.cuda.manual_seed_all(seed)
+#torch.manual_seed(seed)
+#np.random.seed(seed)
+#random.seed(seed)
+#if torch.cuda.is_available():
+#    torch.cuda.manual_seed_all(seed)
 
 print("init")
 # Enable Flash Attention 2 / SDPA
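The commented-out global seeding is replaced in train_lora.py by an explicit torch.Generator, so sample-generation noise stays reproducible while dataloader shuffling is left unseeded. A minimal sketch of that pattern:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device)
    gen.manual_seed(20250407)  # hypothetical date-derived seed, as in the script
    # Only noise drawn through `gen` is deterministic; the global RNG is untouched.
    noise = torch.randn((1, 16, 72, 72), generator=gen, device=device)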
train_lora.py
ADDED
@@ -0,0 +1,489 @@
+import os
+import math
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from torch.utils.data import DataLoader, Sampler
+from collections import defaultdict
+from torch.optim.lr_scheduler import LambdaLR
+from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
+from accelerate import Accelerator
+from datasets import load_from_disk
+from tqdm import tqdm
+from PIL import Image, ImageOps
+import wandb
+import random
+import gc
+from accelerate.state import DistributedType
+from torch.distributed import broadcast_object_list
+from torch.utils.checkpoint import checkpoint
+from diffusers.models.attention_processor import AttnProcessor2_0
+from datetime import datetime
+
+# --------------------------- Parameters ---------------------------
+save_path = "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
+batch_size = 45 #11 #45 #555 #35 #7
+base_learning_rate = 1e-6 #9e-7 #1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
+min_learning_rate = 2.5e-5 #2e-5
+num_epochs = 4 #2 #36 #18
+project = "sdxs"
+use_wandb = True
+save_model = True
+limit = 0 #200000 #0
+checkpoints_folder = ""
+use_lr_decay = False  # disable decay
+
+# Diffusion parameters
+n_diffusion_steps = 40
+samples_to_generate = 12
+guidance_scale = 5
+sample_interval_share = 20
+
+# Folders for saving results
+generated_folder = "samples"
+os.makedirs(generated_folder, exist_ok=True)
+
+# Seed setup for reproducibility
+current_date = datetime.now()
+seed = int(current_date.strftime("%Y%m%d"))
+#torch.manual_seed(seed)
+#np.random.seed(seed)
+#random.seed(seed)
+#if torch.cuda.is_available():
+#    torch.cuda.manual_seed_all(seed)
+
+print("init")
+# Enable Flash Attention 2 / SDPA
+torch.backends.cuda.enable_flash_sdp(True)
+# --------------------------- Accelerator init --------------------
+dtype = torch.bfloat16
+accelerator = Accelerator(mixed_precision="bf16")
+device = accelerator.device
+gen = torch.Generator(device=device)
+gen.manual_seed(seed)
+
+# --------------------------- WandB init ---------------------------
+if use_wandb and accelerator.is_main_process:
+    wandb.init(project=project, config={
+        "batch_size": batch_size,
+        "base_learning_rate": base_learning_rate,
+        "num_epochs": num_epochs,
+        "n_diffusion_steps": n_diffusion_steps,
+        "samples_to_generate": samples_to_generate,
+        "dtype": str(dtype)
+    })
+
+# --------------------------- Dataset loading ---------------------------
+class ResolutionBatchSampler(Sampler):
+    """Sampler that groups examples of identical size into batches."""
+    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+
+        # Group examples by size
+        self.size_groups = defaultdict(list)
+
+        try:
+            widths = dataset["width"]
+            heights = dataset["height"]
+        except KeyError:
+            widths = [0] * len(dataset)
+            heights = [0] * len(dataset)
+
+        for i, (w, h) in enumerate(zip(widths, heights)):
+            size = (w, h)
+            self.size_groups[size].append(i)
+
+        # Print size statistics
+        print(f"Found {len(self.size_groups)} unique sizes:")
+        for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
+            width, height = size
+            print(f"  {width}x{height}: {len(indices)} examples")
+
+        # Build the batches
+        self.reset()
+
+    def reset(self):
+        """Resets and reshuffles the indices."""
+        self.batches = []
+
+        for size, indices in self.size_groups.items():
+            if self.shuffle:
+                indices_copy = indices.copy()
+                random.shuffle(indices_copy)
+            else:
+                indices_copy = indices
+
+            # Split into batches
+            for i in range(0, len(indices_copy), self.batch_size):
+                batch_indices = indices_copy[i:i + self.batch_size]
+
+                # Skip incomplete batches if drop_last=True
+                if self.drop_last and len(batch_indices) < self.batch_size:
+                    continue
+
+                self.batches.append(batch_indices)
+
+        # Shuffle batches of different sizes among each other
+        if self.shuffle:
+            random.shuffle(self.batches)
+
+    def __iter__(self):
+        self.reset()  # Reset and reshuffle at the start of each epoch
+        return iter(self.batches)
+
+    def __len__(self):
+        return len(self.batches)
+
+# Function that picks fixed samples per resolution
+def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
+    """Picks fixed samples for each unique resolution."""
+    # Group by size
+    size_groups = defaultdict(list)
+    try:
+        widths = dataset["width"]
+        heights = dataset["height"]
+    except KeyError:
+        widths = [0] * len(dataset)
+        heights = [0] * len(dataset)
+    for i, (w, h) in enumerate(zip(widths, heights)):
+        size = (w, h)
+        size_groups[size].append(i)
+
+    # Pick fixed examples from each group
+    fixed_samples = {}
+    for size, indices in size_groups.items():
+        # Decide how many samples to take from this group
+        n_samples = min(samples_per_group, len(indices))
+        if len(size_groups) == 1:
+            n_samples = samples_to_generate
+        if n_samples == 0:
+            continue
+
+        # Pick random indices
+        sample_indices = random.sample(indices, n_samples)
+        samples_data = [dataset[idx] for idx in sample_indices]
+
+        # Collect the data
+        latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device)
+        embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device)
+        texts = [item["text"] for item in samples_data]
+
+        # Store for this size
+        fixed_samples[size] = (latents, embeddings, texts)
+
+    print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
+    return fixed_samples
+
+if limit > 0:
+    dataset = load_from_disk(save_path).select(range(limit))
+else:
+    dataset = load_from_disk(save_path)
+
+
+def collate_fn(batch):
+    # Convert the list to tensors and move them to the device
+    latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
+    embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
+    return latents, embeddings
+
+# Use our ResolutionBatchSampler
+batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
+dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
+
+print("Total samples", len(dataloader))
+dataloader = accelerator.prepare(dataloader)
+
+# --------------------------- Model loading ---------------------------
+# The VAE is loaded on CPU to save GPU memory
+vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)
+
+# DDPMScheduler with V_Prediction and Zero-SNR
+scheduler = DDPMScheduler(
+    num_train_timesteps=1000,        # Full timestep schedule for training
+    prediction_type="v_prediction",  # V-Prediction
+    rescale_betas_zero_snr=True,     # Enable Zero-SNR
+    timestep_spacing="leading",      # Improved timestep distribution
+    steps_offset=1                   # Avoid issues with the zero timestep
+)
+
+# Variables for resuming training
+start_epoch = 0
+global_step = 0
+
+# Total step count
+total_training_steps = (len(dataloader) * num_epochs)
+# Get the world size
+world_size = accelerator.state.num_processes
+print(f"World Size: {world_size}")
+
+# Option to load the model from the latest checkpoint (if it exists)
+latest_checkpoint = os.path.join(checkpoints_folder, project)
+if os.path.isdir(latest_checkpoint):
+    print("Loading UNet from checkpoint:", latest_checkpoint)
+    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
+    unet.enable_gradient_checkpointing()
+    unet.set_use_memory_efficient_attention_xformers(False)  # disable xformers
+    try:
+        unet.set_attn_processor(AttnProcessor2_0())  # Use the standard AttnProcessor
+        print("SDPA enabled via set_attn_processor.")
+    except Exception as e:
+        print(f"Error enabling SDPA: {e}")
+        print("Trying enable_xformers_memory_efficient_attention instead.")
+        unet.set_use_memory_efficient_attention_xformers(True)
+
+# --------------------------- Optimizer and custom LR scheduler ---------------------------
+# pip install bitsandbytes
+import bitsandbytes as bnb
+
+# [1] Create a dict of per-parameter optimizers (fused backward)
+optimizer_dict = {
+    p: bnb.optim.AdamW8bit(
+        [p],  # Each parameter gets its own optimizer
+        lr=base_learning_rate,
+        betas=(0.9, 0.999),
+        weight_decay=1e-5,
+        eps=1e-8
+    ) for p in unet.parameters()
+}
+
+# [2] Define a hook that applies the optimizer right after a gradient is accumulated
+def optimizer_hook(param):
+    optimizer_dict[param].step()
+    optimizer_dict[param].zero_grad(set_to_none=True)
+
+# [3] Register the hook for all model parameters
+for param in unet.parameters():
+    param.register_post_accumulate_grad_hook(optimizer_hook)
+
+# Prepare via Accelerator
+unet, optimizer = accelerator.prepare(unet, optimizer_dict)
+
+# --------------------------- Fixed samples for generation ---------------------------
+# Fixed sample examples per size
+fixed_samples = get_fixed_samples_by_resolution(dataset)
+
+@torch.no_grad()
+def generate_and_save_samples(fixed_samples, step):
+    """
+    Generates samples for each resolution and saves them.
+
+    Args:
+        fixed_samples: Dict whose keys are sizes (width, height)
+                       and whose values are tuples (latents, embeddings, texts)
+        step: Current training step
+    """
+    original_model = None  # ensure the name exists for the finally block
+    try:
+        original_model = accelerator.unwrap_model(unet)
+        # Move the VAE to the device for sampling
+        vae.to(accelerator.device, dtype=dtype)
+
+        # Set the number of diffusion steps
+        scheduler.set_timesteps(n_diffusion_steps)
+
+        all_generated_images = []
+        size_info = []  # Holds the size info for each image
+        all_captions = []
+
+        # Iterate over all size groups
+        for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items():
+            width, height = size
+            size_info.append(f"{width}x{height}")
+            #print(f"Generating {sample_latents.shape[0]} images of size {width}x{height}")
+
+            # Initialize the latents with random noise for this group
+            noise = torch.randn(
+                sample_latents.shape,
+                generator=gen,
+                device=sample_latents.device,
+                dtype=sample_latents.dtype
+            )
+
+            # Start from noise
+            current_latents = noise.clone()
+
+            # Prepare the text embeddings for guidance
+            if guidance_scale > 0:
+                empty_embeddings = torch.zeros_like(sample_text_embeddings)
+                text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
+            else:
+                text_embeddings = sample_text_embeddings
+
+            # Image generation
+            for t in scheduler.timesteps:
+                # Prepare the UNet inputs
+                if guidance_scale > 0:
+                    latent_model_input = torch.cat([current_latents] * 2)
+                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+                else:
+                    latent_model_input = scheduler.scale_model_input(current_latents, t)
+
+                # Noise prediction
+                noise_pred = original_model(latent_model_input, t, text_embeddings).sample
+
+                # Apply the guidance scale
+                if guidance_scale > 0:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # Update the latents
+                current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
+
+            # Decode through the VAE
+            latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
+            latent = latent.to(accelerator.device, dtype=dtype)
+            decoded = vae.decode(latent).sample
+
+            # Convert the tensors to PIL images and save them
+            for img_idx, img_tensor in enumerate(decoded):
+                img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
+                pil_img = Image.fromarray((img * 255).astype("uint8"))
+                # Determine the maximum width and height
+                max_width = max(size[0] for size in fixed_samples.keys())
+                max_height = max(size[1] for size in fixed_samples.keys())
+                max_width = max(255, max_width)
+                max_height = max(255, max_height)
+
+                # Pad the image up to max_width x max_height
+                padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
+
+                all_generated_images.append(padded_img)
+
+                caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
+                all_captions.append(caption_text)
+
+                # Save with the size info in the file name
+                save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
+                pil_img.save(save_path, "JPEG", quality=96)
+
+        # Send the images to WandB with the size info
+        if use_wandb and accelerator.is_main_process:
+            wandb_images = [
+                wandb.Image(img, caption=f"{all_captions[i]}")
+                for i, img in enumerate(all_generated_images)
+            ]
+            wandb.log({"generated_images": wandb_images, "global_step": step})
+
+    finally:
+        # Guaranteed move of the VAE back to CPU
+        vae.to("cpu")
+        if original_model is not None:
+            del original_model
+        # Clean up all tensors
+        for var in list(locals().keys()):
+            if isinstance(locals()[var], torch.Tensor):
+                del locals()[var]
+        torch.cuda.empty_cache()
+        gc.collect()
+
+# --------------------------- Sample generation before training ---------------------------
+if accelerator.is_main_process:
+    if save_model:
+        print("Generating samples before training starts...")
+        generate_and_save_samples(fixed_samples, 0)
+
+# --------------------------- Training loop ---------------------------
+# For logging the average loss every fraction of an epoch
+if accelerator.is_main_process:
+    print(f"Total steps per GPU: {total_training_steps}")
+    print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}")
+
+epoch_loss_points = []
+progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
+
+# Sampling/logging interval within an epoch (every 1/sample_interval_share of the epoch)
+steps_per_epoch = len(dataloader)
+sample_interval = max(1, steps_per_epoch // sample_interval_share)
+
+# Start from the given epoch (useful when resuming)
+for epoch in range(start_epoch, start_epoch + num_epochs):
+    batch_losses = []
+    unet.train()
+
+    for step, (latents, embeddings) in enumerate(dataloader):
+        with accelerator.accumulate(unet):
+            if save_model == False and step == 3:
+                used_gb = torch.cuda.max_memory_allocated() / 1024**3
+                print(f"Step {step}: {used_gb:.2f} GB")
+            # Forward pass
+            noise = torch.randn_like(latents)
+
+            timesteps = torch.randint(
+                1,  # Start from 1, not 0
+                scheduler.config.num_train_timesteps,
+                (latents.shape[0],),
+                device=device
+            ).long()
+
+            # Add noise to the latents
+            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+
+            # Get the model prediction and cast it to bf16
+            noise_pred = unet(noisy_latents, timesteps, embeddings).sample.to(dtype=torch.bfloat16)
+
+            # Use the v_prediction target (velocity)
+            target = scheduler.get_velocity(latents, noise, timesteps)
+
+            # Compute the loss
+            loss = torch.nn.functional.mse_loss(noise_pred, target)
+
+            # Backward via Accelerator
+            accelerator.backward(loss)
+
+            # Increment the global step counter
+            global_step += 1
+
+            # Update the progress bar
+            progress_bar.update(1)
+
+            # Log metrics
+            if accelerator.is_main_process:
+                current_lr = base_learning_rate
+                batch_losses.append(loss.detach().item())
+
+                # Log to Wandb
+                if use_wandb:
+                    wandb.log({
+                        "loss": loss.detach().item(),
+                        "learning_rate": current_lr,
+                        "epoch": epoch,
+                        #"grad_norm": grad_norm.item(),
+                        "global_step": global_step
+                    })
+
+                # Generate samples at the given interval
+                if global_step % sample_interval == 0:
+                    if save_model:
+                        accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
+
+                    generate_and_save_samples(fixed_samples, global_step)
+
+                    # Report the current loss
+                    avg_loss = np.mean(batch_losses[-sample_interval:])
+                    #print(f"Epoch {epoch}, step {global_step}, average loss: {avg_loss:.6f}, LR: {current_lr:.8f}")
+                    if use_wandb:
+                        wandb.log({"intermediate_loss": avg_loss})
+
+
+    # At the end of the epoch
+    #accelerator.wait_for_everyone()
+    # Save a checkpoint at the end of each epoch
+    if accelerator.is_main_process:
+
+        # Save the UNet separately for ease of use
+        #if save_model:
+        #    accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
+        avg_epoch_loss = np.mean(batch_losses)
+        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
+        if use_wandb:
+            wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
+
+# Training finished - save the final model
+if accelerator.is_main_process:
+    print("Training finished! Saving the final model...")
+    # Save the main model
+    if save_model:
+        accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
+print("Done!")
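The optimizer setup above is the fused-backward pattern: one 8-bit AdamW per parameter, stepped from register_post_accumulate_grad_hook so each gradient is applied and freed during backward instead of being held for a global optimizer.step(). A self-contained sketch of the same idea, with plain torch.optim.AdamW standing in for bitsandbytes:

    import torch

    model = torch.nn.Linear(8, 8)
    # One optimizer per parameter, stepped as soon as that gradient is ready.
    opts = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

    def hook(p):
        opts[p].step()
        opts[p].zero_grad(set_to_none=True)

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(hook)  # PyTorch >= 2.1

    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()  # parameters update inside backward; no optimizer.step() call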
unet/config.json
CHANGED
@@ -1,78 +1,3 @@
-{
-  "_class_name": "UNet2DConditionModel",
-  "_diffusers_version": "…",
-  "_name_or_path": "sdxs",
-  "act_fn": "silu",
-  "addition_embed_type": null,
-  "addition_embed_type_num_heads": 48,
-  "addition_time_embed_dim": null,
-  "attention_head_dim": 48,
-  "attention_type": "default",
-  "block_out_channels": [
-    384,
-    576,
-    768,
-    960
-  ],
-  "center_input_sample": false,
-  "class_embed_type": null,
-  "class_embeddings_concat": false,
-  "conv_in_kernel": 3,
-  "conv_out_kernel": 3,
-  "cross_attention_dim": 1152,
-  "cross_attention_norm": null,
-  "down_block_types": [
-    "CrossAttnDownBlock2D",
-    "CrossAttnDownBlock2D",
-    "CrossAttnDownBlock2D",
-    "CrossAttnDownBlock2D"
-  ],
-  "downsample_padding": 1,
-  "dropout": 0.1,
-  "dual_cross_attention": false,
-  "encoder_hid_dim": null,
-  "encoder_hid_dim_type": null,
-  "flip_sin_to_cos": true,
-  "freq_shift": 0,
-  "in_channels": 16,
-  "layers_per_block": [
-    2,
-    2,
-    2,
-    2
-  ],
-  "mid_block_only_cross_attention": null,
-  "mid_block_scale_factor": 1.0,
-  "mid_block_type": "UNetMidBlock2DCrossAttn",
-  "norm_eps": 1e-05,
-  "norm_num_groups": 16,
-  "num_attention_heads": null,
-  "num_class_embeds": null,
-  "only_cross_attention": false,
-  "out_channels": 16,
-  "projection_class_embeddings_input_dim": null,
-  "resnet_out_scale_factor": 1.0,
-  "resnet_skip_time_act": false,
-  "resnet_time_scale_shift": "default",
-  "reverse_transformer_layers_per_block": null,
-  "sample_size": 64,
-  "time_cond_proj_dim": null,
-  "time_embedding_act_fn": null,
-  "time_embedding_dim": null,
-  "time_embedding_type": "positional",
-  "timestep_post_act": null,
-  "transformer_layers_per_block": [
-    4,
-    6,
-    8,
-    10
-  ],
-  "up_block_types": [
-    "CrossAttnUpBlock2D",
-    "CrossAttnUpBlock2D",
-    "CrossAttnUpBlock2D",
-    "CrossAttnUpBlock2D"
-  ],
-  "upcast_attention": false,
-  "use_linear_projection": true
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2273cd4cc67fba9bb3ad740df06b6a87901105e88b40d9b18fa4cdc8809314e
+size 1895
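The JSON removed above fully specifies the network, so the same UNet can be rebuilt in code; a sketch using the corresponding UNet2DConditionModel arguments (parameter count printed for orientation):

    from diffusers import UNet2DConditionModel

    # Values taken from the removed unet/config.json.
    unet = UNet2DConditionModel(
        sample_size=64,
        in_channels=16,
        out_channels=16,
        block_out_channels=(384, 576, 768, 960),
        down_block_types=("CrossAttnDownBlock2D",) * 4,
        up_block_types=("CrossAttnUpBlock2D",) * 4,
        layers_per_block=2,
        transformer_layers_per_block=(4, 6, 8, 10),
        cross_attention_dim=1152,
        attention_head_dim=48,
        norm_num_groups=16,
        dropout=0.1,
        use_linear_projection=True,
    )
    print(f"{sum(p.numel() for p in unet.parameters()) / 1e9:.2f}B parameters")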
unet/diffusion_pytorch_model.fp16.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:4bf4e3f25670722ebd75f6238db3c041d1933ed5611edc5a0795ea0f4674958e
 size 4529095968
vae/config.json
CHANGED
@@ -1,38 +1,3 @@
-{
-  "_class_name": "AutoencoderKL",
-  "_diffusers_version": "…",
-  "_name_or_path": "/home/recoilme/sdxs576/vae",
-  "act_fn": "silu",
-  "block_out_channels": [
-    128,
-    256,
-    512,
-    512
-  ],
-  "down_block_types": [
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D"
-  ],
-  "force_upcast": false,
-  "in_channels": 3,
-  "latent_channels": 16,
-  "latents_mean": null,
-  "latents_std": null,
-  "layers_per_block": 2,
-  "mid_block_add_attention": false,
-  "norm_num_groups": 32,
-  "out_channels": 3,
-  "sample_size": 1024,
-  "scaling_factor": 0.18215,
-  "shift_factor": 0,
-  "up_block_types": [
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D"
-  ],
-  "use_post_quant_conv": true,
-  "use_quant_conv": true
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:992be219657d2876df23100f817d060c9f9aa497358ad30d7282103d96823d1f
+size 819
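For reference, the scaling_factor / shift_factor above are exactly what train_lora.py applies before decoding; a tiny worked sketch (the 72x72 latent shape assumes a 576px image through this 8x-downsampling, 16-channel VAE):

    import torch

    scaling_factor, shift_factor = 0.18215, 0.0   # values from this config
    current_latents = torch.randn(1, 16, 72, 72)  # 576 / 8 = 72
    # Same normalization as train_lora.py performs before vae.decode():
    latent = (current_latents / scaling_factor) + shift_factor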