Spaces:
Configuration error
Configuration error
Upload files
Browse files- .gitattributes +2 -0
- .gitignore +6 -0
- README.md +139 -13
- contents/alpha_scale.gif +3 -0
- contents/alpha_scale.mp4 +3 -0
- contents/disney_lora.jpg +0 -0
- contents/pop_art.jpg +0 -0
- lora_diffusion/__init__.py +1 -0
- lora_diffusion/cli_lora_add.py +49 -0
- lora_diffusion/lora.py +166 -0
- lora_disney.pt +3 -0
- lora_illust.pt +3 -0
- lora_pop.pt +3 -0
- requirements.txt +3 -3
- run_lora_db.sh +17 -0
- scripts/make_alpha_gifs.ipynb +0 -0
- scripts/run_inference.ipynb +0 -0
- setup.py +25 -0
- train_lora_dreambooth.py +964 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
contents/alpha_scale.gif filter=lfs diff=lfs merge=lfs -text
|
36 |
+
contents/alpha_scale.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_*
|
2 |
+
output_*
|
3 |
+
__pycache__
|
4 |
+
*.pyc
|
5 |
+
__test*
|
6 |
+
merged_lora*
|
README.md
CHANGED
@@ -1,13 +1,139 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning
|
2 |
+
|
3 |
+
<!-- #region -->
|
4 |
+
<p align="center">
|
5 |
+
<img src="contents/alpha_scale.gif">
|
6 |
+
</p>
|
7 |
+
<!-- #endregion -->
|
8 |
+
|
9 |
+
> Using LORA to fine tune on illustration dataset : $W = W_0 + \alpha \Delta W$, where $\alpha$ is the merging ratio. Above gif is scaling alpha from 0 to 1. Setting alpha to 0 is same as using the original model, and setting alpha to 1 is same as using the fully fine-tuned model.
|
10 |
+
|
11 |
+
<!-- #region -->
|
12 |
+
<p align="center">
|
13 |
+
<img src="contents/disney_lora.jpg">
|
14 |
+
</p>
|
15 |
+
<!-- #endregion -->
|
16 |
+
|
17 |
+
> "style of sks, baby lion", with disney-style LORA model.
|
18 |
+
|
19 |
+
<!-- #region -->
|
20 |
+
<p align="center">
|
21 |
+
<img src="contents/pop_art.jpg">
|
22 |
+
</p>
|
23 |
+
<!-- #endregion -->
|
24 |
+
|
25 |
+
> "style of sks, superman", with pop-art style LORA model.
|
26 |
+
|
27 |
+
## Main Features
|
28 |
+
|
29 |
+
- Fine-tune Stable diffusion models twice as faster than dreambooth method, by Low-rank Adaptation
|
30 |
+
- Get insanely small end result, easy to share and download.
|
31 |
+
- Easy to use, compatible with diffusers
|
32 |
+
- Sometimes even better performance than full fine-tuning (but left as future work for extensive comparisons)
|
33 |
+
- Merge checkpoints by merging LORA
|
34 |
+
|
35 |
+
# Lengthy Introduction
|
36 |
+
|
37 |
+
Thanks to the generous work of Stability AI and Huggingface, so many people have enjoyed fine-tuning stable diffusion models to fit their needs and generate higher fidelity images. **However, the fine-tuning process is very slow, and it is not easy to find a good balance between the number of steps and the quality of the results.**
|
38 |
+
|
39 |
+
Also, the final results (fully fined-tuned model) is very large. Some people instead works with textual-inversion as an alternative for this. But clearly this is suboptimal: textual inversion only creates a small word-embedding, and the final image is not as good as a fully fine-tuned model.
|
40 |
+
|
41 |
+
Well, what's the alternative? In the domain of LLM, researchers have developed Efficient fine-tuning methods. LORA, especially, tackles the very problem the community currently has: end users with Open-sourced stable-diffusion model want to try various other fine-tuned model that is created by the community, but the model is too large to download and use. LORA instead attempts to fine-tune the "residual" of the model instead of the entire model: i.e., train the $\Delta W$ instead of $W$.
|
42 |
+
|
43 |
+
$$
|
44 |
+
W' = W + \Delta W
|
45 |
+
$$
|
46 |
+
|
47 |
+
Where we can further decompose $\Delta W$ into low-rank matrices : $\Delta W = A B^T $, where $A, \in \mathbb{R}^{n \times d}, B \in \mathbb{R}^{m \times d}, d << n$.
|
48 |
+
This is the key idea of LORA. We can then fine-tune $A$ and $B$ instead of $W$. In the end, you get an insanely small model as $A$ and $B$ are much smaller than $W$.
|
49 |
+
|
50 |
+
Also, not all of the parameters need tuning: they found that often, $Q, K, V, O$ (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.
|
51 |
+
|
52 |
+
Enough of the lengthy introduction, let's get to the code.
|
53 |
+
|
54 |
+
# Installation
|
55 |
+
|
56 |
+
```bash
|
57 |
+
pip install git+https://github.com/cloneofsimo/lora.git
|
58 |
+
```
|
59 |
+
|
60 |
+
# Getting Started
|
61 |
+
|
62 |
+
## Fine-tuning Stable diffusion with LORA.
|
63 |
+
|
64 |
+
Basic usage is as follows: prepare sets of $A, B$ matrices in an unet model, and fine-tune them.
|
65 |
+
|
66 |
+
```python
|
67 |
+
from lora_diffusion import inject_trainable_lora, extract_lora_up_downs
|
68 |
+
|
69 |
+
...
|
70 |
+
|
71 |
+
unet = UNet2DConditionModel.from_pretrained(
|
72 |
+
pretrained_model_name_or_path,
|
73 |
+
subfolder="unet",
|
74 |
+
)
|
75 |
+
unet.requires_grad_(False)
|
76 |
+
unet_lora_params, train_names = inject_trainable_lora(unet) # This will
|
77 |
+
# turn off all of the gradients of unet, except for the trainable LORA params.
|
78 |
+
optimizer = optim.Adam(
|
79 |
+
itertools.chain(*unet_lora_params, text_encoder.parameters()), lr=1e-4
|
80 |
+
)
|
81 |
+
```
|
82 |
+
|
83 |
+
An example of this can be found in `train_lora_dreambooth.py`. Run this example with
|
84 |
+
|
85 |
+
```bash
|
86 |
+
run_lora_db.sh
|
87 |
+
```
|
88 |
+
|
89 |
+
## Loading, merging, and interpolating trained LORAs.
|
90 |
+
|
91 |
+
We've seen that people have been merging different checkpoints with different ratios, and this seems to be very useful to the community. LORA is extremely easy to merge.
|
92 |
+
|
93 |
+
By the nature of LORA, one can interpolate between different fine-tuned models by adding different $A, B$ matrices.
|
94 |
+
|
95 |
+
Currently, LORA cli has two options : merge unet with LORA, or merge LORA with LORA.
|
96 |
+
|
97 |
+
### Merging unet with LORA
|
98 |
+
|
99 |
+
```bash
|
100 |
+
$ lora_add --path_1 PATH_TO_DIFFUSER_FORMAT_MODEL --path_2 PATH_TO_LORA.PT --mode upl --alpha 1.0 --output_path OUTPUT_PATH
|
101 |
+
```
|
102 |
+
|
103 |
+
`path_1` can be both local path or huggingface model name. When adding LORA to unet, alpha is the constant as below:
|
104 |
+
|
105 |
+
$$
|
106 |
+
W' = W + \alpha \Delta W
|
107 |
+
$$
|
108 |
+
|
109 |
+
So, set alpha to 1.0 to fully add LORA. If the LORA seems to have too much effect (i.e., overfitted), set alpha to lower value. If the LORA seems to have too little effect, set alpha to higher than 1.0. You can tune these values to your needs.
|
110 |
+
|
111 |
+
**Example**
|
112 |
+
|
113 |
+
```bash
|
114 |
+
$ lora_add --path_1 stabilityai/stable-diffusion-2-base --path_2 lora_illust.pt --mode upl --alpha 1.0 --output_path merged_model
|
115 |
+
```
|
116 |
+
|
117 |
+
### Merging LORA with LORA
|
118 |
+
|
119 |
+
```bash
|
120 |
+
$ lora_add --path_1 PATH_TO_LORA.PT --path_2 PATH_TO_LORA.PT --mode lpl --alpha 0.5 --output_path OUTPUT_PATH.PT
|
121 |
+
```
|
122 |
+
|
123 |
+
alpha is the ratio of the first model to the second model. i.e.,
|
124 |
+
|
125 |
+
$$
|
126 |
+
\Delta W = (\alpha A_1 + (1 - \alpha) A_2) (B_1 + (1 - \alpha) B_2)^T
|
127 |
+
$$
|
128 |
+
|
129 |
+
Set alpha to 0.5 to get the average of the two models. Set alpha close to 1.0 to get more effect of the first model, and set alpha close to 0.0 to get more effect of the second model.
|
130 |
+
|
131 |
+
**Example**
|
132 |
+
|
133 |
+
```bash
|
134 |
+
$ lora_add --path_1 lora_illust.pt --path_2 lora_pop.pt --alpha 0.3 --output_path lora_merged.pt
|
135 |
+
```
|
136 |
+
|
137 |
+
### Making Inference with trained LORA
|
138 |
+
|
139 |
+
Checkout `scripts/run_inference.ipynb` for an example of how to make inference with LORA.
|
contents/alpha_scale.gif
ADDED
![]() |
Git LFS Details
|
contents/alpha_scale.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ad74f5f69d99bfcbeee1d4d2b3900ac1ca7ff83fba5ddf8269ffed8a56c9c6e
|
3 |
+
size 5247140
|
contents/disney_lora.jpg
ADDED
![]() |
contents/pop_art.jpg
ADDED
![]() |
lora_diffusion/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .lora import *
|
lora_diffusion/cli_lora_add.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Union, Dict
|
2 |
+
|
3 |
+
import fire
|
4 |
+
from diffusers import StableDiffusionPipeline
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from .lora import tune_lora_scale, weight_apply_lora
|
8 |
+
|
9 |
+
|
10 |
+
def add(
|
11 |
+
path_1: str,
|
12 |
+
path_2: str,
|
13 |
+
output_path: str = "./merged_lora.pt",
|
14 |
+
alpha: float = 0.5,
|
15 |
+
mode: Literal["lpl", "upl"] = "lpl",
|
16 |
+
):
|
17 |
+
if mode == "lpl":
|
18 |
+
out_list = []
|
19 |
+
l1 = torch.load(path_1)
|
20 |
+
l2 = torch.load(path_2)
|
21 |
+
|
22 |
+
l1pairs = zip(l1[::2], l1[1::2])
|
23 |
+
l2pairs = zip(l2[::2], l2[1::2])
|
24 |
+
|
25 |
+
for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
|
26 |
+
x1.data = alpha * x1.data + (1 - alpha) * x2.data
|
27 |
+
y1.data = alpha * y1.data + (1 - alpha) * y2.data
|
28 |
+
|
29 |
+
out_list.append(x1)
|
30 |
+
out_list.append(y1)
|
31 |
+
|
32 |
+
torch.save(out_list, output_path)
|
33 |
+
|
34 |
+
elif mode == "upl":
|
35 |
+
|
36 |
+
loaded_pipeline = StableDiffusionPipeline.from_pretrained(
|
37 |
+
path_1,
|
38 |
+
).to("cpu")
|
39 |
+
|
40 |
+
weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
|
41 |
+
|
42 |
+
if output_path.endswith(".pt"):
|
43 |
+
output_path = output_path[:-3]
|
44 |
+
|
45 |
+
loaded_pipeline.save_pretrained(output_path)
|
46 |
+
|
47 |
+
|
48 |
+
def main():
|
49 |
+
fire.Fire(add)
|
lora_diffusion/lora.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
class LoraInjectedLinear(nn.Module):
|
13 |
+
def __init__(self, in_features, out_features, bias=False):
|
14 |
+
super().__init__()
|
15 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
16 |
+
self.lora_down = nn.Linear(in_features, 4, bias=False)
|
17 |
+
self.lora_up = nn.Linear(4, out_features, bias=False)
|
18 |
+
self.scale = 1.0
|
19 |
+
|
20 |
+
nn.init.normal_(self.lora_down.weight, std=1 / 16)
|
21 |
+
nn.init.zeros_(self.lora_up.weight)
|
22 |
+
|
23 |
+
def forward(self, input):
|
24 |
+
return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
|
25 |
+
|
26 |
+
|
27 |
+
def inject_trainable_lora(
|
28 |
+
model: nn.Module, target_replace_module: List[str] = ["CrossAttention", "Attention"]
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
inject lora into model, and returns lora parameter groups.
|
32 |
+
"""
|
33 |
+
|
34 |
+
require_grad_params = []
|
35 |
+
names = []
|
36 |
+
|
37 |
+
for _module in model.modules():
|
38 |
+
if _module.__class__.__name__ in target_replace_module:
|
39 |
+
|
40 |
+
for name, _child_module in _module.named_modules():
|
41 |
+
if _child_module.__class__.__name__ == "Linear":
|
42 |
+
|
43 |
+
weight = _child_module.weight
|
44 |
+
bias = _child_module.bias
|
45 |
+
_tmp = LoraInjectedLinear(
|
46 |
+
_child_module.in_features,
|
47 |
+
_child_module.out_features,
|
48 |
+
_child_module.bias is not None,
|
49 |
+
)
|
50 |
+
_tmp.linear.weight = weight
|
51 |
+
if bias is not None:
|
52 |
+
_tmp.linear.bias = bias
|
53 |
+
|
54 |
+
# switch the module
|
55 |
+
_module._modules[name] = _tmp
|
56 |
+
|
57 |
+
require_grad_params.append(
|
58 |
+
_module._modules[name].lora_up.parameters()
|
59 |
+
)
|
60 |
+
require_grad_params.append(
|
61 |
+
_module._modules[name].lora_down.parameters()
|
62 |
+
)
|
63 |
+
|
64 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
65 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
66 |
+
names.append(name)
|
67 |
+
|
68 |
+
return require_grad_params, names
|
69 |
+
|
70 |
+
|
71 |
+
def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
|
72 |
+
|
73 |
+
loras = []
|
74 |
+
|
75 |
+
for _module in model.modules():
|
76 |
+
if _module.__class__.__name__ in target_replace_module:
|
77 |
+
for _child_module in _module.modules():
|
78 |
+
if _child_module.__class__.__name__ == "LoraInjectedLinear":
|
79 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
80 |
+
if len(loras) == 0:
|
81 |
+
raise ValueError("No lora injected.")
|
82 |
+
return loras
|
83 |
+
|
84 |
+
|
85 |
+
def save_lora_weight(model, path="./lora.pt"):
|
86 |
+
weights = []
|
87 |
+
for _up, _down in extract_lora_ups_down(model):
|
88 |
+
weights.append(_up.weight)
|
89 |
+
weights.append(_down.weight)
|
90 |
+
|
91 |
+
torch.save(weights, path)
|
92 |
+
|
93 |
+
|
94 |
+
def save_lora_as_json(model, path="./lora.json"):
|
95 |
+
weights = []
|
96 |
+
for _up, _down in extract_lora_ups_down(model):
|
97 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
98 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
99 |
+
|
100 |
+
import json
|
101 |
+
|
102 |
+
with open(path, "w") as f:
|
103 |
+
json.dump(weights, f)
|
104 |
+
|
105 |
+
|
106 |
+
def weight_apply_lora(
|
107 |
+
model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
|
108 |
+
):
|
109 |
+
|
110 |
+
for _module in model.modules():
|
111 |
+
if _module.__class__.__name__ in target_replace_module:
|
112 |
+
for _child_module in _module.modules():
|
113 |
+
if _child_module.__class__.__name__ == "Linear":
|
114 |
+
|
115 |
+
weight = _child_module.weight
|
116 |
+
|
117 |
+
up_weight = loras.pop(0).detach().to(weight.device)
|
118 |
+
down_weight = loras.pop(0).detach().to(weight.device)
|
119 |
+
|
120 |
+
# W <- W + U * D
|
121 |
+
weight = weight + alpha * (up_weight @ down_weight).type(
|
122 |
+
weight.dtype
|
123 |
+
)
|
124 |
+
_child_module.weight = nn.Parameter(weight)
|
125 |
+
|
126 |
+
|
127 |
+
def monkeypatch_lora(
|
128 |
+
model, loras, target_replace_module=["CrossAttention", "Attention"]
|
129 |
+
):
|
130 |
+
for _module in model.modules():
|
131 |
+
if _module.__class__.__name__ in target_replace_module:
|
132 |
+
for name, _child_module in _module.named_modules():
|
133 |
+
if _child_module.__class__.__name__ == "Linear":
|
134 |
+
|
135 |
+
weight = _child_module.weight
|
136 |
+
bias = _child_module.bias
|
137 |
+
_tmp = LoraInjectedLinear(
|
138 |
+
_child_module.in_features,
|
139 |
+
_child_module.out_features,
|
140 |
+
_child_module.bias is not None,
|
141 |
+
)
|
142 |
+
_tmp.linear.weight = weight
|
143 |
+
|
144 |
+
if bias is not None:
|
145 |
+
_tmp.linear.bias = bias
|
146 |
+
|
147 |
+
# switch the module
|
148 |
+
_module._modules[name] = _tmp
|
149 |
+
|
150 |
+
up_weight = loras.pop(0)
|
151 |
+
down_weight = loras.pop(0)
|
152 |
+
|
153 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
154 |
+
up_weight.type(weight.dtype)
|
155 |
+
)
|
156 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
157 |
+
down_weight.type(weight.dtype)
|
158 |
+
)
|
159 |
+
|
160 |
+
_module._modules[name].to(weight.device)
|
161 |
+
|
162 |
+
|
163 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
164 |
+
for _module in model.modules():
|
165 |
+
if _module.__class__.__name__ == "LoraInjectedLinear":
|
166 |
+
_module.scale = alpha
|
lora_disney.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:72f687f810b86bb8cc64d2ece59886e2e96d29e3f57f97340ee147d168b8a5fe
|
3 |
+
size 3397249
|
lora_illust.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f6acb0bc0cd5f96299be7839f89f58727e2666e58861e55866ea02125c97aba
|
3 |
+
size 3397249
|
lora_pop.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18a1565852a08cfcff63e90670286c9427e3958f57de9b84e3f8b2c9a3a14b6c
|
3 |
+
size 3397249
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
diffusers
|
2 |
transformers
|
3 |
-
|
4 |
-
|
|
|
1 |
+
diffusers>=0.9.0
|
2 |
transformers
|
3 |
+
scipy
|
4 |
+
ftfy
|
run_lora_db.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
|
2 |
+
export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
|
3 |
+
export INSTANCE_DIR="./data_example"
|
4 |
+
export OUTPUT_DIR="./output_example"
|
5 |
+
|
6 |
+
accelerate launch train_lora_dreambooth.py \
|
7 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
8 |
+
--instance_data_dir=$INSTANCE_DIR \
|
9 |
+
--output_dir=$OUTPUT_DIR \
|
10 |
+
--instance_prompt="style of sks" \
|
11 |
+
--resolution=512 \
|
12 |
+
--train_batch_size=1 \
|
13 |
+
--gradient_accumulation_steps=1 \
|
14 |
+
--learning_rate=1e-4 \
|
15 |
+
--lr_scheduler="constant" \
|
16 |
+
--lr_warmup_steps=0 \
|
17 |
+
--max_train_steps=30000
|
scripts/make_alpha_gifs.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
scripts/run_inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
setup.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import pkg_resources
|
4 |
+
from setuptools import find_packages, setup
|
5 |
+
|
6 |
+
setup(
|
7 |
+
name="lora_diffusion",
|
8 |
+
py_modules=["lora_diffusion"],
|
9 |
+
version="0.0.1",
|
10 |
+
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
|
11 |
+
author="Simo Ryu",
|
12 |
+
packages=find_packages(),
|
13 |
+
entry_points={
|
14 |
+
"console_scripts": [
|
15 |
+
"lora_add = lora_diffusion.cli_lora_add:main",
|
16 |
+
],
|
17 |
+
},
|
18 |
+
install_requires=[
|
19 |
+
str(r)
|
20 |
+
for r in pkg_resources.parse_requirements(
|
21 |
+
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
22 |
+
)
|
23 |
+
],
|
24 |
+
include_package_data=True,
|
25 |
+
)
|
train_lora_dreambooth.py
ADDED
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bootstrapped from:
|
2 |
+
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import hashlib
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint
|
15 |
+
|
16 |
+
|
17 |
+
from accelerate import Accelerator
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import set_seed
|
20 |
+
from diffusers import (
|
21 |
+
AutoencoderKL,
|
22 |
+
DDPMScheduler,
|
23 |
+
StableDiffusionPipeline,
|
24 |
+
UNet2DConditionModel,
|
25 |
+
)
|
26 |
+
from diffusers.optimization import get_scheduler
|
27 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
28 |
+
|
29 |
+
from tqdm.auto import tqdm
|
30 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
31 |
+
|
32 |
+
from lora_diffusion import (
|
33 |
+
inject_trainable_lora,
|
34 |
+
save_lora_weight,
|
35 |
+
extract_lora_ups_down,
|
36 |
+
)
|
37 |
+
|
38 |
+
from torch.utils.data import Dataset
|
39 |
+
from PIL import Image
|
40 |
+
from torchvision import transforms
|
41 |
+
|
42 |
+
from pathlib import Path
|
43 |
+
|
44 |
+
import random
|
45 |
+
import re
|
46 |
+
|
47 |
+
|
48 |
+
class DreamBoothDataset(Dataset):
|
49 |
+
"""
|
50 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
51 |
+
It pre-processes the images and the tokenizes prompts.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
instance_data_root,
|
57 |
+
instance_prompt,
|
58 |
+
tokenizer,
|
59 |
+
class_data_root=None,
|
60 |
+
class_prompt=None,
|
61 |
+
size=512,
|
62 |
+
center_crop=False,
|
63 |
+
):
|
64 |
+
self.size = size
|
65 |
+
self.center_crop = center_crop
|
66 |
+
self.tokenizer = tokenizer
|
67 |
+
|
68 |
+
self.instance_data_root = Path(instance_data_root)
|
69 |
+
if not self.instance_data_root.exists():
|
70 |
+
raise ValueError("Instance images root doesn't exists.")
|
71 |
+
|
72 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
73 |
+
self.num_instance_images = len(self.instance_images_path)
|
74 |
+
self.instance_prompt = instance_prompt
|
75 |
+
self._length = self.num_instance_images
|
76 |
+
|
77 |
+
if class_data_root is not None:
|
78 |
+
self.class_data_root = Path(class_data_root)
|
79 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
80 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
81 |
+
self.num_class_images = len(self.class_images_path)
|
82 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
83 |
+
self.class_prompt = class_prompt
|
84 |
+
else:
|
85 |
+
self.class_data_root = None
|
86 |
+
|
87 |
+
self.image_transforms = transforms.Compose(
|
88 |
+
[
|
89 |
+
transforms.Resize(
|
90 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
91 |
+
),
|
92 |
+
transforms.CenterCrop(size)
|
93 |
+
if center_crop
|
94 |
+
else transforms.RandomCrop(size),
|
95 |
+
transforms.ToTensor(),
|
96 |
+
transforms.Normalize([0.5], [0.5]),
|
97 |
+
]
|
98 |
+
)
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return self._length
|
102 |
+
|
103 |
+
def __getitem__(self, index):
|
104 |
+
example = {}
|
105 |
+
instance_image = Image.open(
|
106 |
+
self.instance_images_path[index % self.num_instance_images]
|
107 |
+
)
|
108 |
+
if not instance_image.mode == "RGB":
|
109 |
+
instance_image = instance_image.convert("RGB")
|
110 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
111 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
112 |
+
self.instance_prompt,
|
113 |
+
padding="do_not_pad",
|
114 |
+
truncation=True,
|
115 |
+
max_length=self.tokenizer.model_max_length,
|
116 |
+
).input_ids
|
117 |
+
|
118 |
+
if self.class_data_root:
|
119 |
+
class_image = Image.open(
|
120 |
+
self.class_images_path[index % self.num_class_images]
|
121 |
+
)
|
122 |
+
if not class_image.mode == "RGB":
|
123 |
+
class_image = class_image.convert("RGB")
|
124 |
+
example["class_images"] = self.image_transforms(class_image)
|
125 |
+
example["class_prompt_ids"] = self.tokenizer(
|
126 |
+
self.class_prompt,
|
127 |
+
padding="do_not_pad",
|
128 |
+
truncation=True,
|
129 |
+
max_length=self.tokenizer.model_max_length,
|
130 |
+
).input_ids
|
131 |
+
|
132 |
+
return example
|
133 |
+
|
134 |
+
|
135 |
+
class DreamBoothLabled(Dataset):
|
136 |
+
"""
|
137 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
138 |
+
It pre-processes the images and the tokenizes prompts.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
instance_data_root,
|
144 |
+
instance_prompt,
|
145 |
+
tokenizer,
|
146 |
+
class_data_root=None,
|
147 |
+
class_prompt=None,
|
148 |
+
size=512,
|
149 |
+
center_crop=False,
|
150 |
+
):
|
151 |
+
self.size = size
|
152 |
+
self.center_crop = center_crop
|
153 |
+
self.tokenizer = tokenizer
|
154 |
+
|
155 |
+
self.instance_data_root = Path(instance_data_root)
|
156 |
+
if not self.instance_data_root.exists():
|
157 |
+
raise ValueError("Instance images root doesn't exists.")
|
158 |
+
|
159 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
160 |
+
self.num_instance_images = len(self.instance_images_path)
|
161 |
+
self.instance_prompt = instance_prompt
|
162 |
+
self._length = self.num_instance_images
|
163 |
+
|
164 |
+
if class_data_root is not None:
|
165 |
+
self.class_data_root = Path(class_data_root)
|
166 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
167 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
168 |
+
self.num_class_images = len(self.class_images_path)
|
169 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
170 |
+
self.class_prompt = class_prompt
|
171 |
+
else:
|
172 |
+
self.class_data_root = None
|
173 |
+
|
174 |
+
self.image_transforms = transforms.Compose(
|
175 |
+
[
|
176 |
+
transforms.Resize(
|
177 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
178 |
+
),
|
179 |
+
transforms.CenterCrop(size)
|
180 |
+
if center_crop
|
181 |
+
else transforms.RandomCrop(size),
|
182 |
+
transforms.ToTensor(),
|
183 |
+
transforms.Normalize([0.5], [0.5]),
|
184 |
+
]
|
185 |
+
)
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return self._length
|
189 |
+
|
190 |
+
def __getitem__(self, index):
|
191 |
+
example = {}
|
192 |
+
instance_image = Image.open(
|
193 |
+
self.instance_images_path[index % self.num_instance_images]
|
194 |
+
)
|
195 |
+
|
196 |
+
instance_prompt = (
|
197 |
+
str(self.instance_images_path[index % self.num_instance_images])
|
198 |
+
.split("/")[-1]
|
199 |
+
.split(".")[0]
|
200 |
+
.replace("-", " ")
|
201 |
+
)
|
202 |
+
# remove numbers in prompt
|
203 |
+
instance_prompt = re.sub(r"\d+", "", instance_prompt)
|
204 |
+
# print(instance_prompt)
|
205 |
+
|
206 |
+
_svg = random.choice(["svg", "flat color", "vector illustration", "sks"])
|
207 |
+
instance_prompt = f"{instance_prompt}, style of {_svg}"
|
208 |
+
|
209 |
+
if not instance_image.mode == "RGB":
|
210 |
+
instance_image = instance_image.convert("RGB")
|
211 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
212 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
213 |
+
instance_prompt,
|
214 |
+
padding="do_not_pad",
|
215 |
+
truncation=True,
|
216 |
+
max_length=self.tokenizer.model_max_length,
|
217 |
+
).input_ids
|
218 |
+
|
219 |
+
if self.class_data_root:
|
220 |
+
class_image = Image.open(
|
221 |
+
self.class_images_path[index % self.num_class_images]
|
222 |
+
)
|
223 |
+
if not class_image.mode == "RGB":
|
224 |
+
class_image = class_image.convert("RGB")
|
225 |
+
example["class_images"] = self.image_transforms(class_image)
|
226 |
+
example["class_prompt_ids"] = self.tokenizer(
|
227 |
+
self.class_prompt,
|
228 |
+
padding="do_not_pad",
|
229 |
+
truncation=True,
|
230 |
+
max_length=self.tokenizer.model_max_length,
|
231 |
+
).input_ids
|
232 |
+
|
233 |
+
return example
|
234 |
+
|
235 |
+
|
236 |
+
class PromptDataset(Dataset):
|
237 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
238 |
+
|
239 |
+
def __init__(self, prompt, num_samples):
|
240 |
+
self.prompt = prompt
|
241 |
+
self.num_samples = num_samples
|
242 |
+
|
243 |
+
def __len__(self):
|
244 |
+
return self.num_samples
|
245 |
+
|
246 |
+
def __getitem__(self, index):
|
247 |
+
example = {}
|
248 |
+
example["prompt"] = self.prompt
|
249 |
+
example["index"] = index
|
250 |
+
return example
|
251 |
+
|
252 |
+
|
253 |
+
logger = get_logger(__name__)
|
254 |
+
|
255 |
+
|
256 |
+
def parse_args(input_args=None):
|
257 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
258 |
+
parser.add_argument(
|
259 |
+
"--pretrained_model_name_or_path",
|
260 |
+
type=str,
|
261 |
+
default=None,
|
262 |
+
required=True,
|
263 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
264 |
+
)
|
265 |
+
parser.add_argument(
|
266 |
+
"--revision",
|
267 |
+
type=str,
|
268 |
+
default=None,
|
269 |
+
required=False,
|
270 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--tokenizer_name",
|
274 |
+
type=str,
|
275 |
+
default=None,
|
276 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
277 |
+
)
|
278 |
+
parser.add_argument(
|
279 |
+
"--instance_data_dir",
|
280 |
+
type=str,
|
281 |
+
default=None,
|
282 |
+
required=True,
|
283 |
+
help="A folder containing the training data of instance images.",
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--class_data_dir",
|
287 |
+
type=str,
|
288 |
+
default=None,
|
289 |
+
required=False,
|
290 |
+
help="A folder containing the training data of class images.",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--instance_prompt",
|
294 |
+
type=str,
|
295 |
+
default=None,
|
296 |
+
required=True,
|
297 |
+
help="The prompt with identifier specifying the instance",
|
298 |
+
)
|
299 |
+
parser.add_argument(
|
300 |
+
"--class_prompt",
|
301 |
+
type=str,
|
302 |
+
default=None,
|
303 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--with_prior_preservation",
|
307 |
+
default=False,
|
308 |
+
action="store_true",
|
309 |
+
help="Flag to add prior preservation loss.",
|
310 |
+
)
|
311 |
+
parser.add_argument(
|
312 |
+
"--prior_loss_weight",
|
313 |
+
type=float,
|
314 |
+
default=1.0,
|
315 |
+
help="The weight of prior preservation loss.",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--num_class_images",
|
319 |
+
type=int,
|
320 |
+
default=100,
|
321 |
+
help=(
|
322 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
323 |
+
" sampled with class_prompt."
|
324 |
+
),
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--output_dir",
|
328 |
+
type=str,
|
329 |
+
default="text-inversion-model",
|
330 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--resolution",
|
337 |
+
type=int,
|
338 |
+
default=512,
|
339 |
+
help=(
|
340 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
341 |
+
" resolution"
|
342 |
+
),
|
343 |
+
)
|
344 |
+
parser.add_argument(
|
345 |
+
"--center_crop",
|
346 |
+
action="store_true",
|
347 |
+
help="Whether to center crop images before resizing to resolution",
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--train_text_encoder",
|
351 |
+
action="store_true",
|
352 |
+
help="Whether to train the text encoder",
|
353 |
+
)
|
354 |
+
parser.add_argument(
|
355 |
+
"--train_batch_size",
|
356 |
+
type=int,
|
357 |
+
default=4,
|
358 |
+
help="Batch size (per device) for the training dataloader.",
|
359 |
+
)
|
360 |
+
parser.add_argument(
|
361 |
+
"--sample_batch_size",
|
362 |
+
type=int,
|
363 |
+
default=4,
|
364 |
+
help="Batch size (per device) for sampling images.",
|
365 |
+
)
|
366 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
367 |
+
parser.add_argument(
|
368 |
+
"--max_train_steps",
|
369 |
+
type=int,
|
370 |
+
default=None,
|
371 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
372 |
+
)
|
373 |
+
parser.add_argument(
|
374 |
+
"--save_steps",
|
375 |
+
type=int,
|
376 |
+
default=500,
|
377 |
+
help="Save checkpoint every X updates steps.",
|
378 |
+
)
|
379 |
+
parser.add_argument(
|
380 |
+
"--gradient_accumulation_steps",
|
381 |
+
type=int,
|
382 |
+
default=1,
|
383 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
384 |
+
)
|
385 |
+
parser.add_argument(
|
386 |
+
"--gradient_checkpointing",
|
387 |
+
action="store_true",
|
388 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
389 |
+
)
|
390 |
+
parser.add_argument(
|
391 |
+
"--learning_rate",
|
392 |
+
type=float,
|
393 |
+
default=5e-6,
|
394 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
395 |
+
)
|
396 |
+
parser.add_argument(
|
397 |
+
"--scale_lr",
|
398 |
+
action="store_true",
|
399 |
+
default=False,
|
400 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
401 |
+
)
|
402 |
+
parser.add_argument(
|
403 |
+
"--lr_scheduler",
|
404 |
+
type=str,
|
405 |
+
default="constant",
|
406 |
+
help=(
|
407 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
408 |
+
' "constant", "constant_with_warmup"]'
|
409 |
+
),
|
410 |
+
)
|
411 |
+
parser.add_argument(
|
412 |
+
"--lr_warmup_steps",
|
413 |
+
type=int,
|
414 |
+
default=500,
|
415 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--use_8bit_adam",
|
419 |
+
action="store_true",
|
420 |
+
help="Whether or not to use 8-bit Adam from bitsandbytes.",
|
421 |
+
)
|
422 |
+
parser.add_argument(
|
423 |
+
"--adam_beta1",
|
424 |
+
type=float,
|
425 |
+
default=0.9,
|
426 |
+
help="The beta1 parameter for the Adam optimizer.",
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
"--adam_beta2",
|
430 |
+
type=float,
|
431 |
+
default=0.999,
|
432 |
+
help="The beta2 parameter for the Adam optimizer.",
|
433 |
+
)
|
434 |
+
parser.add_argument(
|
435 |
+
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--adam_epsilon",
|
439 |
+
type=float,
|
440 |
+
default=1e-08,
|
441 |
+
help="Epsilon value for the Adam optimizer",
|
442 |
+
)
|
443 |
+
parser.add_argument(
|
444 |
+
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
|
445 |
+
)
|
446 |
+
parser.add_argument(
|
447 |
+
"--push_to_hub",
|
448 |
+
action="store_true",
|
449 |
+
help="Whether or not to push the model to the Hub.",
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
"--hub_token",
|
453 |
+
type=str,
|
454 |
+
default=None,
|
455 |
+
help="The token to use to push to the Model Hub.",
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--hub_model_id",
|
459 |
+
type=str,
|
460 |
+
default=None,
|
461 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--logging_dir",
|
465 |
+
type=str,
|
466 |
+
default="logs",
|
467 |
+
help=(
|
468 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
469 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
470 |
+
),
|
471 |
+
)
|
472 |
+
parser.add_argument(
|
473 |
+
"--mixed_precision",
|
474 |
+
type=str,
|
475 |
+
default=None,
|
476 |
+
choices=["no", "fp16", "bf16"],
|
477 |
+
help=(
|
478 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
479 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
480 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
481 |
+
),
|
482 |
+
)
|
483 |
+
parser.add_argument(
|
484 |
+
"--local_rank",
|
485 |
+
type=int,
|
486 |
+
default=-1,
|
487 |
+
help="For distributed training: local_rank",
|
488 |
+
)
|
489 |
+
|
490 |
+
if input_args is not None:
|
491 |
+
args = parser.parse_args(input_args)
|
492 |
+
else:
|
493 |
+
args = parser.parse_args()
|
494 |
+
|
495 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
496 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
497 |
+
args.local_rank = env_local_rank
|
498 |
+
|
499 |
+
if args.with_prior_preservation:
|
500 |
+
if args.class_data_dir is None:
|
501 |
+
raise ValueError("You must specify a data directory for class images.")
|
502 |
+
if args.class_prompt is None:
|
503 |
+
raise ValueError("You must specify prompt for class images.")
|
504 |
+
else:
|
505 |
+
if args.class_data_dir is not None:
|
506 |
+
logger.warning(
|
507 |
+
"You need not use --class_data_dir without --with_prior_preservation."
|
508 |
+
)
|
509 |
+
if args.class_prompt is not None:
|
510 |
+
logger.warning(
|
511 |
+
"You need not use --class_prompt without --with_prior_preservation."
|
512 |
+
)
|
513 |
+
|
514 |
+
return args
|
515 |
+
|
516 |
+
|
517 |
+
def get_full_repo_name(
|
518 |
+
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
|
519 |
+
):
|
520 |
+
if token is None:
|
521 |
+
token = HfFolder.get_token()
|
522 |
+
if organization is None:
|
523 |
+
username = whoami(token)["name"]
|
524 |
+
return f"{username}/{model_id}"
|
525 |
+
else:
|
526 |
+
return f"{organization}/{model_id}"
|
527 |
+
|
528 |
+
|
529 |
+
def main(args):
|
530 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
531 |
+
|
532 |
+
accelerator = Accelerator(
|
533 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
534 |
+
mixed_precision=args.mixed_precision,
|
535 |
+
log_with="tensorboard",
|
536 |
+
logging_dir=logging_dir,
|
537 |
+
)
|
538 |
+
|
539 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
540 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
541 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
542 |
+
if (
|
543 |
+
args.train_text_encoder
|
544 |
+
and args.gradient_accumulation_steps > 1
|
545 |
+
and accelerator.num_processes > 1
|
546 |
+
):
|
547 |
+
raise ValueError(
|
548 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
549 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
550 |
+
)
|
551 |
+
|
552 |
+
if args.seed is not None:
|
553 |
+
set_seed(args.seed)
|
554 |
+
|
555 |
+
if args.with_prior_preservation:
|
556 |
+
class_images_dir = Path(args.class_data_dir)
|
557 |
+
if not class_images_dir.exists():
|
558 |
+
class_images_dir.mkdir(parents=True)
|
559 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
560 |
+
|
561 |
+
if cur_class_images < args.num_class_images:
|
562 |
+
torch_dtype = (
|
563 |
+
torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
564 |
+
)
|
565 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
566 |
+
args.pretrained_model_name_or_path,
|
567 |
+
torch_dtype=torch_dtype,
|
568 |
+
safety_checker=None,
|
569 |
+
revision=args.revision,
|
570 |
+
)
|
571 |
+
pipeline.set_progress_bar_config(disable=True)
|
572 |
+
|
573 |
+
num_new_images = args.num_class_images - cur_class_images
|
574 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
575 |
+
|
576 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
577 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
578 |
+
sample_dataset, batch_size=args.sample_batch_size
|
579 |
+
)
|
580 |
+
|
581 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
582 |
+
pipeline.to(accelerator.device)
|
583 |
+
|
584 |
+
for example in tqdm(
|
585 |
+
sample_dataloader,
|
586 |
+
desc="Generating class images",
|
587 |
+
disable=not accelerator.is_local_main_process,
|
588 |
+
):
|
589 |
+
images = pipeline(example["prompt"]).images
|
590 |
+
|
591 |
+
for i, image in enumerate(images):
|
592 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
593 |
+
image_filename = (
|
594 |
+
class_images_dir
|
595 |
+
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
596 |
+
)
|
597 |
+
image.save(image_filename)
|
598 |
+
|
599 |
+
del pipeline
|
600 |
+
if torch.cuda.is_available():
|
601 |
+
torch.cuda.empty_cache()
|
602 |
+
|
603 |
+
# Handle the repository creation
|
604 |
+
if accelerator.is_main_process:
|
605 |
+
if args.push_to_hub:
|
606 |
+
if args.hub_model_id is None:
|
607 |
+
repo_name = get_full_repo_name(
|
608 |
+
Path(args.output_dir).name, token=args.hub_token
|
609 |
+
)
|
610 |
+
else:
|
611 |
+
repo_name = args.hub_model_id
|
612 |
+
repo = Repository(args.output_dir, clone_from=repo_name)
|
613 |
+
|
614 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
615 |
+
if "step_*" not in gitignore:
|
616 |
+
gitignore.write("step_*\n")
|
617 |
+
if "epoch_*" not in gitignore:
|
618 |
+
gitignore.write("epoch_*\n")
|
619 |
+
elif args.output_dir is not None:
|
620 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
621 |
+
|
622 |
+
# Load the tokenizer
|
623 |
+
if args.tokenizer_name:
|
624 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
625 |
+
args.tokenizer_name,
|
626 |
+
revision=args.revision,
|
627 |
+
)
|
628 |
+
elif args.pretrained_model_name_or_path:
|
629 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
630 |
+
args.pretrained_model_name_or_path,
|
631 |
+
subfolder="tokenizer",
|
632 |
+
revision=args.revision,
|
633 |
+
)
|
634 |
+
|
635 |
+
# Load models and create wrapper for stable diffusion
|
636 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
637 |
+
args.pretrained_model_name_or_path,
|
638 |
+
subfolder="text_encoder",
|
639 |
+
revision=args.revision,
|
640 |
+
)
|
641 |
+
vae = AutoencoderKL.from_pretrained(
|
642 |
+
args.pretrained_model_name_or_path,
|
643 |
+
subfolder="vae",
|
644 |
+
revision=args.revision,
|
645 |
+
)
|
646 |
+
unet = UNet2DConditionModel.from_pretrained(
|
647 |
+
args.pretrained_model_name_or_path,
|
648 |
+
subfolder="unet",
|
649 |
+
revision=args.revision,
|
650 |
+
)
|
651 |
+
unet.requires_grad_(False)
|
652 |
+
unet_lora_params, train_names = inject_trainable_lora(unet)
|
653 |
+
|
654 |
+
for _up, _down in extract_lora_ups_down(unet):
|
655 |
+
print(_up.weight)
|
656 |
+
print(_down.weight)
|
657 |
+
break
|
658 |
+
|
659 |
+
vae.requires_grad_(False)
|
660 |
+
if not args.train_text_encoder:
|
661 |
+
text_encoder.requires_grad_(False)
|
662 |
+
|
663 |
+
if args.gradient_checkpointing:
|
664 |
+
unet.enable_gradient_checkpointing()
|
665 |
+
if args.train_text_encoder:
|
666 |
+
text_encoder.gradient_checkpointing_enable()
|
667 |
+
|
668 |
+
if args.scale_lr:
|
669 |
+
args.learning_rate = (
|
670 |
+
args.learning_rate
|
671 |
+
* args.gradient_accumulation_steps
|
672 |
+
* args.train_batch_size
|
673 |
+
* accelerator.num_processes
|
674 |
+
)
|
675 |
+
|
676 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
677 |
+
if args.use_8bit_adam:
|
678 |
+
try:
|
679 |
+
import bitsandbytes as bnb
|
680 |
+
except ImportError:
|
681 |
+
raise ImportError(
|
682 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
683 |
+
)
|
684 |
+
|
685 |
+
optimizer_class = bnb.optim.AdamW8bit
|
686 |
+
else:
|
687 |
+
optimizer_class = torch.optim.AdamW
|
688 |
+
|
689 |
+
params_to_optimize = (
|
690 |
+
itertools.chain(*unet_lora_params, text_encoder.parameters())
|
691 |
+
if args.train_text_encoder
|
692 |
+
else itertools.chain(*unet_lora_params)
|
693 |
+
)
|
694 |
+
optimizer = optimizer_class(
|
695 |
+
params_to_optimize,
|
696 |
+
lr=args.learning_rate,
|
697 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
698 |
+
weight_decay=args.adam_weight_decay,
|
699 |
+
eps=args.adam_epsilon,
|
700 |
+
)
|
701 |
+
|
702 |
+
noise_scheduler = DDPMScheduler.from_config(
|
703 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
704 |
+
)
|
705 |
+
|
706 |
+
train_dataset = DreamBoothDataset(
|
707 |
+
instance_data_root=args.instance_data_dir,
|
708 |
+
instance_prompt=args.instance_prompt,
|
709 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
710 |
+
class_prompt=args.class_prompt,
|
711 |
+
tokenizer=tokenizer,
|
712 |
+
size=args.resolution,
|
713 |
+
center_crop=args.center_crop,
|
714 |
+
)
|
715 |
+
|
716 |
+
def collate_fn(examples):
|
717 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
718 |
+
pixel_values = [example["instance_images"] for example in examples]
|
719 |
+
|
720 |
+
# Concat class and instance examples for prior preservation.
|
721 |
+
# We do this to avoid doing two forward passes.
|
722 |
+
if args.with_prior_preservation:
|
723 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
724 |
+
pixel_values += [example["class_images"] for example in examples]
|
725 |
+
|
726 |
+
pixel_values = torch.stack(pixel_values)
|
727 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
728 |
+
|
729 |
+
input_ids = tokenizer.pad(
|
730 |
+
{"input_ids": input_ids},
|
731 |
+
padding="max_length",
|
732 |
+
max_length=tokenizer.model_max_length,
|
733 |
+
return_tensors="pt",
|
734 |
+
).input_ids
|
735 |
+
|
736 |
+
batch = {
|
737 |
+
"input_ids": input_ids,
|
738 |
+
"pixel_values": pixel_values,
|
739 |
+
}
|
740 |
+
return batch
|
741 |
+
|
742 |
+
train_dataloader = torch.utils.data.DataLoader(
|
743 |
+
train_dataset,
|
744 |
+
batch_size=args.train_batch_size,
|
745 |
+
shuffle=True,
|
746 |
+
collate_fn=collate_fn,
|
747 |
+
num_workers=1,
|
748 |
+
)
|
749 |
+
|
750 |
+
# Scheduler and math around the number of training steps.
|
751 |
+
overrode_max_train_steps = False
|
752 |
+
num_update_steps_per_epoch = math.ceil(
|
753 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
754 |
+
)
|
755 |
+
if args.max_train_steps is None:
|
756 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
757 |
+
overrode_max_train_steps = True
|
758 |
+
|
759 |
+
lr_scheduler = get_scheduler(
|
760 |
+
args.lr_scheduler,
|
761 |
+
optimizer=optimizer,
|
762 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
763 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
764 |
+
)
|
765 |
+
|
766 |
+
if args.train_text_encoder:
|
767 |
+
(
|
768 |
+
unet,
|
769 |
+
text_encoder,
|
770 |
+
optimizer,
|
771 |
+
train_dataloader,
|
772 |
+
lr_scheduler,
|
773 |
+
) = accelerator.prepare(
|
774 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
775 |
+
)
|
776 |
+
else:
|
777 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
778 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
779 |
+
)
|
780 |
+
|
781 |
+
weight_dtype = torch.float32
|
782 |
+
if accelerator.mixed_precision == "fp16":
|
783 |
+
weight_dtype = torch.float16
|
784 |
+
elif accelerator.mixed_precision == "bf16":
|
785 |
+
weight_dtype = torch.bfloat16
|
786 |
+
|
787 |
+
# Move text_encode and vae to gpu.
|
788 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
789 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
790 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
791 |
+
if not args.train_text_encoder:
|
792 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
793 |
+
|
794 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
795 |
+
num_update_steps_per_epoch = math.ceil(
|
796 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
797 |
+
)
|
798 |
+
if overrode_max_train_steps:
|
799 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
800 |
+
# Afterwards we recalculate our number of training epochs
|
801 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
802 |
+
|
803 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
804 |
+
# The trackers initializes automatically on the main process.
|
805 |
+
if accelerator.is_main_process:
|
806 |
+
accelerator.init_trackers("dreambooth", config=vars(args))
|
807 |
+
|
808 |
+
# Train!
|
809 |
+
total_batch_size = (
|
810 |
+
args.train_batch_size
|
811 |
+
* accelerator.num_processes
|
812 |
+
* args.gradient_accumulation_steps
|
813 |
+
)
|
814 |
+
|
815 |
+
print("***** Running training *****")
|
816 |
+
print(f" Num examples = {len(train_dataset)}")
|
817 |
+
print(f" Num batches each epoch = {len(train_dataloader)}")
|
818 |
+
print(f" Num Epochs = {args.num_train_epochs}")
|
819 |
+
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
820 |
+
print(
|
821 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
822 |
+
)
|
823 |
+
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
824 |
+
print(f" Total optimization steps = {args.max_train_steps}")
|
825 |
+
# Only show the progress bar once on each machine.
|
826 |
+
progress_bar = tqdm(
|
827 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
828 |
+
)
|
829 |
+
progress_bar.set_description("Steps")
|
830 |
+
global_step = 0
|
831 |
+
|
832 |
+
for epoch in range(args.num_train_epochs):
|
833 |
+
unet.train()
|
834 |
+
if args.train_text_encoder:
|
835 |
+
text_encoder.train()
|
836 |
+
for step, batch in enumerate(train_dataloader):
|
837 |
+
|
838 |
+
# Convert images to latent space
|
839 |
+
latents = vae.encode(
|
840 |
+
batch["pixel_values"].to(dtype=weight_dtype)
|
841 |
+
).latent_dist.sample()
|
842 |
+
latents = latents * 0.18215
|
843 |
+
|
844 |
+
# Sample noise that we'll add to the latents
|
845 |
+
noise = torch.randn_like(latents)
|
846 |
+
bsz = latents.shape[0]
|
847 |
+
# Sample a random timestep for each image
|
848 |
+
timesteps = torch.randint(
|
849 |
+
0,
|
850 |
+
noise_scheduler.config.num_train_timesteps,
|
851 |
+
(bsz,),
|
852 |
+
device=latents.device,
|
853 |
+
)
|
854 |
+
timesteps = timesteps.long()
|
855 |
+
|
856 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
857 |
+
# (this is the forward diffusion process)
|
858 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
859 |
+
|
860 |
+
# Get the text embedding for conditioning
|
861 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
862 |
+
|
863 |
+
# Predict the noise residual
|
864 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
865 |
+
|
866 |
+
# Get the target for loss depending on the prediction type
|
867 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
868 |
+
target = noise
|
869 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
870 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
871 |
+
else:
|
872 |
+
raise ValueError(
|
873 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
874 |
+
)
|
875 |
+
|
876 |
+
if args.with_prior_preservation:
|
877 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
878 |
+
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
879 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
880 |
+
|
881 |
+
# Compute instance loss
|
882 |
+
loss = (
|
883 |
+
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
884 |
+
.mean([1, 2, 3])
|
885 |
+
.mean()
|
886 |
+
)
|
887 |
+
|
888 |
+
# Compute prior loss
|
889 |
+
prior_loss = F.mse_loss(
|
890 |
+
model_pred_prior.float(), target_prior.float(), reduction="mean"
|
891 |
+
)
|
892 |
+
|
893 |
+
# Add the prior loss to the instance loss.
|
894 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
895 |
+
else:
|
896 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
897 |
+
|
898 |
+
accelerator.backward(loss)
|
899 |
+
if accelerator.sync_gradients:
|
900 |
+
params_to_clip = (
|
901 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
902 |
+
if args.train_text_encoder
|
903 |
+
else unet.parameters()
|
904 |
+
)
|
905 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
906 |
+
optimizer.step()
|
907 |
+
lr_scheduler.step()
|
908 |
+
progress_bar.update(1)
|
909 |
+
optimizer.zero_grad()
|
910 |
+
|
911 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
912 |
+
if accelerator.sync_gradients:
|
913 |
+
|
914 |
+
global_step += 1
|
915 |
+
|
916 |
+
if global_step % args.save_steps == 0:
|
917 |
+
if accelerator.is_main_process:
|
918 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
919 |
+
args.pretrained_model_name_or_path,
|
920 |
+
unet=accelerator.unwrap_model(unet),
|
921 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
922 |
+
revision=args.revision,
|
923 |
+
)
|
924 |
+
|
925 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
926 |
+
|
927 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
928 |
+
progress_bar.set_postfix(**logs)
|
929 |
+
accelerator.log(logs, step=global_step)
|
930 |
+
|
931 |
+
if global_step >= args.max_train_steps:
|
932 |
+
break
|
933 |
+
|
934 |
+
accelerator.wait_for_everyone()
|
935 |
+
|
936 |
+
# Create the pipeline using using the trained modules and save it.
|
937 |
+
if accelerator.is_main_process:
|
938 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
939 |
+
args.pretrained_model_name_or_path,
|
940 |
+
unet=accelerator.unwrap_model(unet),
|
941 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
942 |
+
revision=args.revision,
|
943 |
+
)
|
944 |
+
|
945 |
+
print("\n\nLora TRAINING DONE!\n\n")
|
946 |
+
|
947 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
948 |
+
|
949 |
+
for _up, _down in extract_lora_ups_down(pipeline.unet):
|
950 |
+
print("First Layer's Up Weight is now : ", _up.weight)
|
951 |
+
print("First Layer's Down Weight is now : ", _down.weight)
|
952 |
+
break
|
953 |
+
|
954 |
+
if args.push_to_hub:
|
955 |
+
repo.push_to_hub(
|
956 |
+
commit_message="End of training", blocking=False, auto_lfs_prune=True
|
957 |
+
)
|
958 |
+
|
959 |
+
accelerator.end_training()
|
960 |
+
|
961 |
+
|
962 |
+
if __name__ == "__main__":
|
963 |
+
args = parse_args()
|
964 |
+
main(args)
|