Commit c41b22c
Parent(s): dacfebb
init

Files changed:
- .gitignore +2 -0
- Dockerfile +46 -0
- README.md +1 -2
- app.py +461 -0
- get_flash_attn.py +59 -0
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/__pycache__/attention.cpython-310.pyc +0 -0
- modules/__pycache__/autoencoder.cpython-310.pyc +0 -0
- modules/__pycache__/conditioner.cpython-310.pyc +0 -0
- modules/__pycache__/connector_edit.cpython-310.pyc +0 -0
- modules/__pycache__/layers.cpython-310.pyc +0 -0
- modules/__pycache__/model_edit.cpython-310.pyc +0 -0
- modules/attention.py +133 -0
- modules/autoencoder.py +326 -0
- modules/conditioner.py +216 -0
- modules/connector_edit.py +486 -0
- modules/layers.py +639 -0
- modules/model_edit.py +143 -0
- requirements.txt +8 -0
- sampling.py +47 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__/
*/__pycache__/
Dockerfile
ADDED
@@ -0,0 +1,46 @@
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive

ENV PYTHONUNBUFFERED=1

RUN apt-get update && apt-get install --no-install-recommends -y \
    build-essential \
    python3.10 \
    python3-pip \
    git \
    ffmpeg \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt
COPY ./get_flash_attn.py /code/get_flash_attn.py

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user
# Switch to the "user" user
USER user
# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    PYTHONPATH=$HOME/app \
    PYTHONUNBUFFERED=1 \
    GRADIO_ALLOW_FLAGGING=never \
    GRADIO_NUM_PORTS=1 \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_THEME=huggingface \
    SYSTEM=spaces

RUN pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
# ARG values cannot run shell commands, so compute the wheel name inside the RUN itself
RUN pip3 install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/$(python3 /code/get_flash_attn.py)

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Copy the current directory contents into the container at $HOME/app setting the owner to the user
COPY --chown=user . $HOME/app

CMD ["python3", "app.py"]
README.md
CHANGED
@@ -3,8 +3,7 @@ title: Test
 emoji: 🚀
 colorFrom: indigo
 colorTo: pink
-sdk:
-sdk_version: 5.26.0
+sdk: docker
 app_file: app.py
 pinned: false
 license: mit
app.py
ADDED
@@ -0,0 +1,461 @@
import argparse
import datetime
import json
import itertools
import math
import os
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import snapshot_download
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torchvision.transforms import functional as F
from tqdm import tqdm

import sampling
from modules.autoencoder import AutoEncoder
from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
from modules.model_edit import Step1XParams, Step1XEdit

print("TORCH_CUDA", torch.cuda.is_available())


def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
    if Path(ckpt_path).suffix == ".safetensors":
        state_dict = load_file(ckpt_path, device)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")

    missing, unexpected = model.load_state_dict(
        state_dict, strict=strict, assign=assign
    )
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    return model


def load_models(
    dit_path=None,
    ae_path=None,
    qwen2vl_model_path=None,
    device="cuda",
    max_length=256,
    dtype=torch.bfloat16,
):
    qwen2vl_encoder = Qwen2VLEmbedder(
        qwen2vl_model_path,
        device=device,
        max_length=max_length,
        dtype=dtype,
    )

    with torch.device("meta"):
        ae = AutoEncoder(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        )

        step1x_params = Step1XParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
        )
        dit = Step1XEdit(step1x_params)

    ae = load_state_dict(ae, ae_path)
    dit = load_state_dict(
        dit, dit_path
    )

    dit = dit.to(device=device, dtype=dtype)
    ae = ae.to(device=device, dtype=torch.float32)

    return ae, dit, qwen2vl_encoder


class ImageGenerator:
    def __init__(
        self,
        dit_path=None,
        ae_path=None,
        qwen2vl_model_path=None,
        device="cuda",
        max_length=640,
        dtype=torch.bfloat16,
    ) -> None:
        self.device = torch.device(device)
        self.ae, self.dit, self.llm_encoder = load_models(
            dit_path=dit_path,
            ae_path=ae_path,
            qwen2vl_model_path=qwen2vl_model_path,
            max_length=max_length,
            dtype=dtype,
        )

    def prepare(self, prompt, img, ref_image, ref_image_raw):
        bs, _, h, w = img.shape
        bs, _, ref_h, ref_w = ref_image.shape

        assert h == ref_h and w == ref_w

        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
        elif bs >= 1 and isinstance(prompt, str):
            prompt = [prompt] * bs

        img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
        if img.shape[0] == 1 and bs > 1:
            img = repeat(img, "1 ... -> bs ...", bs=bs)
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)

        img_ids = torch.zeros(h // 2, w // 2, 3)

        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)

        ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
        ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
        ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)

        if isinstance(prompt, str):
            prompt = [prompt]

        txt, mask = self.llm_encoder(prompt, ref_image_raw)

        txt_ids = torch.zeros(bs, txt.shape[1], 3)

        img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
        img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)

        return {
            "img": img,
            "mask": mask,
            "img_ids": img_ids.to(img.device),
            "llm_embedding": txt.to(img.device),
            "txt_ids": txt_ids.to(img.device),
        }

    @staticmethod
    def process_diff_norm(diff_norm, k):
        pow_result = torch.pow(diff_norm, k)

        result = torch.where(
            diff_norm > 1.0,
            pow_result,
            torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
        )
        return result

    def denoise(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        llm_embedding: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: list[float],
        cfg_guidance: float = 4.5,
        mask=None,
        show_progress=False,
        timesteps_truncate=1.0,
    ):
        if show_progress:
            pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
        else:
            pbar = itertools.pairwise(timesteps)
        for t_curr, t_prev in pbar:
            if img.shape[0] == 1 and cfg_guidance != -1:
                img = torch.cat([img, img], dim=0)
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )

            txt, vec = self.dit.connector(llm_embedding, t_vec, mask)

            pred = self.dit(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
            )

            if cfg_guidance != -1:
                cond, uncond = (
                    pred[0 : pred.shape[0] // 2, :],
                    pred[pred.shape[0] // 2 :, :],
                )
                if t_curr > timesteps_truncate:
                    diff = cond - uncond
                    diff_norm = torch.norm(diff, dim=(2), keepdim=True)
                    pred = uncond + cfg_guidance * (
                        cond - uncond
                    ) / self.process_diff_norm(diff_norm, k=0.4)
                else:
                    pred = uncond + cfg_guidance * (cond - uncond)
            tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
            img_input_length = img.shape[1] // 2
            img = torch.cat(
                [
                    tem_img[:, :img_input_length],
                    img[ : img.shape[0] // 2, img_input_length:],
                ], dim=1
            )

        return img[:, :img.shape[1] // 2]

    @staticmethod
    def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 16),
            w=math.ceil(width / 16),
            ph=2,
            pw=2,
        )

    @staticmethod
    def load_image(image):
        from PIL import Image

        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            image = image.unsqueeze(0)
            return image
        elif isinstance(image, Image.Image):
            image = F.to_tensor(image.convert("RGB"))
            image = image.unsqueeze(0)
            return image
        elif isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, str):
            image = F.to_tensor(Image.open(image).convert("RGB"))
            image = image.unsqueeze(0)
            return image
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

    def output_process_image(self, resize_img, image_size):
        res_image = resize_img.resize(image_size)
        return res_image

    def input_process_image(self, img, img_size=512):
        # Resize the input so its area is roughly img_size**2 while keeping the
        # aspect ratio, then round both sides down to multiples of 16.
        w, h = img.size
        r = w / h

        if w > h:
            w_new = math.ceil(math.sqrt(img_size * img_size * r))
            h_new = math.ceil(w_new / r)
        else:
            h_new = math.ceil(math.sqrt(img_size * img_size / r))
            w_new = math.ceil(h_new * r)
        h_new = math.ceil(h_new) // 16 * 16
        w_new = math.ceil(w_new) // 16 * 16

        img_resized = img.resize((w_new, h_new))
        return img_resized, img.size

    @torch.inference_mode()
    def generate_image(
        self,
        prompt,
        negative_prompt,
        ref_images,
        num_steps,
        cfg_guidance,
        seed,
        num_samples=1,
        init_image=None,
        image2image_strength=0.0,
        show_progress=False,
        size_level=512,
    ):
        assert num_samples == 1, "num_samples > 1 is not supported yet."
        ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)

        width, height = ref_images_raw.width, ref_images_raw.height

        ref_images_raw = self.load_image(ref_images_raw)
        ref_images_raw = ref_images_raw.to(self.device)
        ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)

        seed = int(seed)
        seed = torch.Generator(device="cpu").seed() if seed < 0 else seed

        t0 = time.perf_counter()

        if init_image is not None:
            init_image = self.load_image(init_image)
            init_image = init_image.to(self.device)
            init_image = torch.nn.functional.interpolate(init_image, (height, width))
            init_image = self.ae.encode(init_image.to(self.device) * 2 - 1)

        x = torch.randn(
            num_samples,
            16,
            height // 8,
            width // 8,
            device=self.device,
            dtype=torch.bfloat16,
            generator=torch.Generator(device=self.device).manual_seed(seed),
        )

        timesteps = sampling.get_schedule(
            num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
        )

        if init_image is not None:
            t_idx = int((1 - image2image_strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        x = torch.cat([x, x], dim=0)
        ref_images = torch.cat([ref_images, ref_images], dim=0)
        ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
        inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)

        x = self.denoise(
            **inputs,
            cfg_guidance=cfg_guidance,
            timesteps=timesteps,
            show_progress=show_progress,
            timesteps_truncate=1.0,
        )
        x = self.unpack(x.float(), height, width)
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            x = self.ae.decode(x)
        x = x.clamp(-1, 1)
        x = x.mul(0.5).add(0.5)

        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s.")
        images_list = []
        for img in x.float():
            images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
        return images_list


def prepare_infer_func():
    # model repository id on the Hugging Face Hub
    model_repo = "stepfun-ai/Step1X-Edit"
    # local download path
    model_path = "./model_weights"
    os.makedirs(model_path, exist_ok=True)

    # download the model (all files)
    snapshot_download(
        repo_id=model_repo,
        local_dir=model_path,
        local_dir_use_symlinks=False  # avoid symlinks
    )

    image_edit = ImageGenerator(
        ae_path=os.path.join(model_path, 'vae.safetensors'),
        dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
        qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
        max_length=640,
    )

    return image_edit.generate_image


def inference(infer_func, prompt, ref_images, seed, size_level):
    start_time = time.time()

    image = infer_func(
        prompt,
        negative_prompt="",
        ref_images=ref_images,
        num_samples=1,
        num_steps=28,
        cfg_guidance=6.0,
        seed=seed,
        show_progress=True,
        size_level=size_level,
    )[0]

    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    return image, seed


def create_demo():
    inference_func = prepare_infer_func()
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Step1X-Edit
            """
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Edit instruction",
                    value='Remove the person from the image.',
                )
                init_image = gr.Image(label="Input Image", type='pil')

                random_seed = gr.Number(label="Random Seed", value=-1, minimum=-1)

                size_level = gr.Number(label="size level (recommend 512, 768, 1024, min 512)", value=512, minimum=512)

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image", type='pil', image_mode='RGB')
                output_random_seed = gr.Textbox(label="Used Seed", lines=5)
        from functools import partial
        # Bind the loaded model as the first argument; Gradio supplies the remaining
        # arguments positionally in the order of the inputs list below.
        generate_btn.click(
            fn=partial(inference, inference_func),
            inputs=[
                prompt,
                init_image,
                random_seed,
                size_level,
            ],
            outputs=[output_image, output_random_seed],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
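For reference, a minimal sketch of driving the generator directly from Python, outside the Gradio UI. It assumes the weight download in prepare_infer_func succeeds, a CUDA GPU is available, and a hypothetical input file input.png exists; it is not part of the commit.

    from PIL import Image
    from app import prepare_infer_func

    generate = prepare_infer_func()          # downloads weights and builds ImageGenerator
    edited = generate(
        prompt="Remove the person from the image.",
        negative_prompt="",
        ref_images=Image.open("input.png"),  # hypothetical input image
        num_steps=28,
        cfg_guidance=6.0,
        seed=42,
        size_level=512,
        show_progress=True,
    )[0]
    edited.save("edited.png")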
get_flash_attn.py
ADDED
@@ -0,0 +1,59 @@
import platform
import sys

import torch


def get_cuda_version():
    if torch.cuda.is_available():
        cuda_version = torch.version.cuda
        return f"cu{cuda_version.replace('.', '')[:2]}"  # e.g. cu12
    return "cpu"


def get_torch_version():
    return f"torch{torch.__version__.split('+')[0]}"[:-2]  # e.g. torch2.3


def get_python_version():
    version = sys.version_info
    return f"cp{version.major}{version.minor}"  # e.g. cp310


def get_abi_flag():
    return "abiTRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "abiFALSE"


def get_platform():
    system = platform.system().lower()
    machine = platform.machine().lower()
    if system == "linux" and machine == "x86_64":
        return "linux_x86_64"
    elif system == "windows" and machine == "amd64":
        return "win_amd64"
    elif system == "darwin" and machine == "x86_64":
        return "macosx_x86_64"
    else:
        raise ValueError(f"Unsupported platform: {system}_{machine}")


def generate_flash_attn_filename(flash_attn_version="2.7.2.post1"):
    cuda_version = get_cuda_version()
    torch_version = get_torch_version()
    python_version = get_python_version()
    abi_flag = get_abi_flag()
    platform_tag = get_platform()

    filename = (
        f"flash_attn-{flash_attn_version}+{cuda_version}{torch_version}cxx11{abi_flag}-"
        f"{python_version}-{python_version}-{platform_tag}.whl"
    )
    return filename


if __name__ == "__main__":
    try:
        filename = generate_flash_attn_filename()
        print(filename)
    except Exception as e:
        print("Error generating filename:", e)
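As a sanity check, this is the wheel name the script computes in the image defined by the Dockerfile above. It is a sketch under stated assumptions: the cu121 build of torch 2.3.1 on Python 3.10 (which reports CXX11 ABI = False), and a CUDA device visible when the script runs; otherwise get_cuda_version returns "cpu" and the resulting download URL will not exist.

    from get_flash_attn import generate_flash_attn_filename

    print(generate_flash_attn_filename())
    # expected: flash_attn-2.7.2.post1+cu12torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl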
modules/__init__.py
ADDED
File without changes

modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (135 Bytes).

modules/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (3.13 kB).

modules/__pycache__/autoencoder.cpython-310.pyc
ADDED
Binary file (8.78 kB).

modules/__pycache__/conditioner.cpython-310.pyc
ADDED
Binary file (4.96 kB).

modules/__pycache__/connector_edit.cpython-310.pyc
ADDED
Binary file (11.8 kB).

modules/__pycache__/layers.cpython-310.pyc
ADDED
Binary file (19.4 kB).

modules/__pycache__/model_edit.cpython-310.pyc
ADDED
Binary file (4.22 kB).
modules/attention.py
ADDED
@@ -0,0 +1,133 @@
import math

import torch
import torch.nn.functional as F


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None

MEMORY_LAYOUT = {
    # "flash" mode:
    #   pre-process: input stays [batch_size, seq_len, num_heads, head_dim]
    #   post-process: shape unchanged
    "flash": (
        lambda x: x,  # keep shape
        lambda x: x,  # keep shape
    ),
    # "torch"/"vanilla" modes:
    #   pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
    #   post-process: swap back [B,A,S,D] -> [B,S,A,D]
    "torch": (
        lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
        lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """
    Compute QKV self-attention.

    Args:
        q (torch.Tensor): query tensor of shape [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): key tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        v (torch.Tensor): value tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        mode (str): attention backend, one of 'flash', 'torch', 'vanilla'
        drop_rate (float): dropout probability applied to the attention matrix
        attn_mask (torch.Tensor): attention mask; the expected shape depends on the mode
        causal (bool): whether to use causal attention (attend only to earlier positions)

    Returns:
        torch.Tensor: attention output of shape [batch_size, seq_len, num_heads * head_dim]
    """
    # pick the pre-/post-processing functions for the chosen backend
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # apply the layout transform
    q = pre_attn_layout(q)  # shape depends on the mode
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # use PyTorch's native scaled_dot_product_attention
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_func is not None, "flash_attn_func is not available"
        assert attn_mask is None, "attention masks are not supported in flash mode"
        x: torch.Tensor = flash_attn_func(
            q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
        )  # type: ignore
    elif mode == "vanilla":
        # manual attention implementation
        scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)

        b, a, s, _ = q.shape  # shape parameters
        s1 = k.size(2)  # key/value sequence length

        # initialise the attention bias
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)

        # causal mask
        if causal:
            assert attn_mask is None, "causal and explicit attention masks cannot be combined"
            # build a lower-triangular causal mask
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        # custom attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask  # allows ALiBi-style positional biases

        # attention matrix
        attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
        attn += attn_bias

        # softmax and dropout
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)

        # output
        x = attn @ v  # [B,A,S,D]
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # apply the post-processing transform to restore the original dimension order
    x = post_attn_layout(x)

    # merge the head dimension
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # [B,S,A*D]
    return out
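A quick shape check for the helper above (a sketch; the tensor sizes are illustrative, and the "torch" backend is used so it runs without flash-attn installed):

    import torch
    from modules.attention import attention

    # q/k/v in the [batch, seq_len, num_heads, head_dim] layout the function expects
    q = torch.randn(2, 16, 4, 32)
    k = torch.randn(2, 16, 4, 32)
    v = torch.randn(2, 16, 4, 32)

    out = attention(q, k, v, mode="torch")
    print(out.shape)  # torch.Size([2, 16, 128]) -- heads merged into the last dimension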
modules/autoencoder.py
ADDED
@@ -0,0 +1,326 @@
# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from einops import rearrange
from torch import Tensor, nn


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, *tuple(ch_mult))
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        scale_factor: float,
        shift_factor: float,
    ):
        super().__init__()
        self.encoder = Encoder(
            resolution=resolution,
            in_channels=in_channels,
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            z_channels=z_channels,
        )
        self.decoder = Decoder(
            resolution=resolution,
            in_channels=in_channels,
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            z_channels=z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
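A small round-trip sketch using the constructor arguments that app.py passes (illustrative only; the real weights come from vae.safetensors). With ch_mult=[1, 2, 4, 4] the encoder downsamples three times, so the latent is 1/8 of the input resolution with z_channels=16 channels:

    import torch
    from modules.autoencoder import AutoEncoder

    ae = AutoEncoder(
        resolution=256, in_channels=3, ch=128, out_ch=3,
        ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
        scale_factor=0.3611, shift_factor=0.1159,
    )
    x = torch.randn(1, 3, 512, 512)
    z = ae.encode(x)                    # -> torch.Size([1, 16, 64, 64])
    print(z.shape, ae.decode(z).shape)  # decode restores (1, 3, 512, 512)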
modules/conditioner.py
ADDED
@@ -0,0 +1,216 @@
import torch
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)
from torchvision.transforms import ToPILImage

to_pil = ToPILImage()

Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''


def split_string(s):
    # replace Chinese quotation marks with English ones
    s = s.replace("“", '"').replace("”", '"')  # use english quotes
    result = []
    # whether we are currently inside a quoted span
    in_quotes = False
    temp = ""

    # walk over every character together with its index
    for idx, char in enumerate(s):
        # a double quote after index 155 toggles the quoting state
        if char == '"' and idx > 155:
            # append the quote to the buffer
            temp += char
            # if we were outside quotes, flush the buffer to the result
            if not in_quotes:
                result.append(temp)
                # reset the buffer
                temp = ""

            # toggle the quoting state
            in_quotes = not in_quotes
            continue
        # inside a quoted span
        if in_quotes:
            # whitespace characters are kept as-is
            if char.isspace():
                pass  # have space token

            # wrap the character in Chinese quotes and append it to the result
            result.append("“" + char + "”")
        else:
            # outside quotes, keep buffering
            temp += char

    # flush any remaining buffered text
    if temp:
        result.append(temp)

    return result


class Qwen25VL_7b_Embedder(torch.nn.Module):
    def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super(Qwen25VL_7b_Embedder, self).__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
        ).to(torch.cuda.current_device())

        self.model.requires_grad_(False)
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
        )

        self.prefix = Qwen25VL_7b_PREFIX

    def forward(self, caption, ref_images):
        text_list = caption
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        hidden_states = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )
        input_ids_list = []
        attention_mask_list = []
        emb_list = []

        def split_string(s):
            # local variant that additionally normalises single quotes
            s = s.replace("“", '"').replace("”", '"').replace("'", '''"''')  # use english quotes
            result = []
            in_quotes = False
            temp = ""

            for idx, char in enumerate(s):
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""

                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    if char.isspace():
                        pass  # have space token

                    result.append("“" + char + "”")
                else:
                    temp += char

            if temp:
                result.append(temp)

            return result

        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):

            messages = [{"role": "user", "content": []}]

            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})

            messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})

            # then append the text
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

            old_inputs_ids = inputs.input_ids
            text_split_list = split_string(text)

            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    token_each = token_each[:, 1:-1]
                    token_list.append(token_each)
                else:
                    token_list.append(token_each)

            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")

            new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            emb = outputs["hidden_states"][-1]

            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
            ]

            masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                (min(self.max_length, emb.shape[1] - 217)),
                dtype=torch.long,
                device=torch.cuda.current_device(),
            )

        return embs, masks
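A sketch of how app.py drives this embedder (assumptions: a CUDA device, flash-attn installed, and the Qwen/Qwen2.5-VL-7B-Instruct weights reachable; the random tensor stands in for a real reference image with values in [0, 1]):

    import torch
    from modules.conditioner import Qwen25VL_7b_Embedder

    encoder = Qwen25VL_7b_Embedder("Qwen/Qwen2.5-VL-7B-Instruct", max_length=640)
    ref = torch.rand(1, 3, 512, 512)    # batch of one reference image, paired with one caption
    embs, masks = encoder(["Remove the person from the image."], ref)
    print(embs.shape, masks.shape)      # (1, 640, hidden_size) and (1, 640)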
modules/connector_edit.py
ADDED
@@ -0,0 +1,486 @@
from typing import Optional

import torch
import torch.nn
from einops import rearrange
from torch import nn

from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")


def get_activation_layer(act_type):
    """get activation layer

    Args:
        act_type (str): the activation type

    Returns:
        torch.nn.functional: the activation layer
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")


class IndividualTokenRefinerBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )

        if self.need_CA:
            self.cross_attnblock = CrossAttnBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x


class CrossAttnBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.norm1_2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_q = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.self_attn_kv = nn.Linear(
            hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
|
252 |
+
else nn.Identity()
|
253 |
+
)
|
254 |
+
self.self_attn_proj = nn.Linear(
|
255 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
256 |
+
)
|
257 |
+
|
258 |
+
self.norm2 = nn.LayerNorm(
|
259 |
+
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
260 |
+
)
|
261 |
+
act_layer = get_activation_layer(act_type)
|
262 |
+
|
263 |
+
self.adaLN_modulation = nn.Sequential(
|
264 |
+
act_layer(),
|
265 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
266 |
+
)
|
267 |
+
# Zero-initialize the modulation
|
268 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
269 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
270 |
+
|
271 |
+
def forward(
|
272 |
+
self,
|
273 |
+
x: torch.Tensor,
|
274 |
+
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
275 |
+
attn_mask: torch.Tensor = None,
|
276 |
+
y: torch.Tensor=None,
|
277 |
+
|
278 |
+
):
|
279 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
280 |
+
|
281 |
+
norm_x = self.norm1(x)
|
282 |
+
norm_y = self.norm1_2(y)
|
283 |
+
q = self.self_attn_q(norm_x)
|
284 |
+
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
|
285 |
+
kv = self.self_attn_kv(norm_y)
|
286 |
+
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
|
287 |
+
# Apply QK-Norm if needed
|
288 |
+
q = self.self_attn_q_norm(q).to(v)
|
289 |
+
k = self.self_attn_k_norm(k).to(v)
|
290 |
+
|
291 |
+
# Self-Attention
|
292 |
+
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
293 |
+
|
294 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
295 |
+
|
296 |
+
return x
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
class IndividualTokenRefiner(torch.nn.Module):
|
301 |
+
def __init__(
|
302 |
+
self,
|
303 |
+
hidden_size,
|
304 |
+
heads_num,
|
305 |
+
depth,
|
306 |
+
mlp_width_ratio: float = 4.0,
|
307 |
+
mlp_drop_rate: float = 0.0,
|
308 |
+
act_type: str = "silu",
|
309 |
+
qk_norm: bool = False,
|
310 |
+
qk_norm_type: str = "layer",
|
311 |
+
qkv_bias: bool = True,
|
312 |
+
need_CA:bool=False,
|
313 |
+
dtype: Optional[torch.dtype] = None,
|
314 |
+
device: Optional[torch.device] = None,
|
315 |
+
):
|
316 |
+
|
317 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
318 |
+
super().__init__()
|
319 |
+
self.need_CA = need_CA
|
320 |
+
self.blocks = nn.ModuleList(
|
321 |
+
[
|
322 |
+
IndividualTokenRefinerBlock(
|
323 |
+
hidden_size=hidden_size,
|
324 |
+
heads_num=heads_num,
|
325 |
+
mlp_width_ratio=mlp_width_ratio,
|
326 |
+
mlp_drop_rate=mlp_drop_rate,
|
327 |
+
act_type=act_type,
|
328 |
+
qk_norm=qk_norm,
|
329 |
+
qk_norm_type=qk_norm_type,
|
330 |
+
qkv_bias=qkv_bias,
|
331 |
+
need_CA=self.need_CA,
|
332 |
+
**factory_kwargs,
|
333 |
+
)
|
334 |
+
for _ in range(depth)
|
335 |
+
]
|
336 |
+
)
|
337 |
+
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
x: torch.Tensor,
|
342 |
+
c: torch.LongTensor,
|
343 |
+
mask: Optional[torch.Tensor] = None,
|
344 |
+
y:torch.Tensor=None,
|
345 |
+
):
|
346 |
+
self_attn_mask = None
|
347 |
+
if mask is not None:
|
348 |
+
batch_size = mask.shape[0]
|
349 |
+
seq_len = mask.shape[1]
|
350 |
+
mask = mask.to(x.device)
|
351 |
+
# batch_size x 1 x seq_len x seq_len
|
352 |
+
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
|
353 |
+
1, 1, seq_len, 1
|
354 |
+
)
|
355 |
+
# batch_size x 1 x seq_len x seq_len
|
356 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
357 |
+
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
|
358 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
359 |
+
# avoids self-attention weight being NaN for padding tokens
|
360 |
+
self_attn_mask[:, :, :, 0] = True
|
361 |
+
|
362 |
+
|
363 |
+
for block in self.blocks:
|
364 |
+
x = block(x, c, self_attn_mask,y)
|
365 |
+
|
366 |
+
return x
|
367 |
+
|
368 |
+
|
369 |
+
class SingleTokenRefiner(torch.nn.Module):
|
370 |
+
"""
|
371 |
+
A single token refiner block for llm text embedding refine.
|
372 |
+
"""
|
373 |
+
def __init__(
|
374 |
+
self,
|
375 |
+
in_channels,
|
376 |
+
hidden_size,
|
377 |
+
heads_num,
|
378 |
+
depth,
|
379 |
+
mlp_width_ratio: float = 4.0,
|
380 |
+
mlp_drop_rate: float = 0.0,
|
381 |
+
act_type: str = "silu",
|
382 |
+
qk_norm: bool = False,
|
383 |
+
qk_norm_type: str = "layer",
|
384 |
+
qkv_bias: bool = True,
|
385 |
+
need_CA:bool=False,
|
386 |
+
attn_mode: str = "torch",
|
387 |
+
dtype: Optional[torch.dtype] = None,
|
388 |
+
device: Optional[torch.device] = None,
|
389 |
+
):
|
390 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
391 |
+
super().__init__()
|
392 |
+
self.attn_mode = attn_mode
|
393 |
+
self.need_CA = need_CA
|
394 |
+
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
395 |
+
|
396 |
+
self.input_embedder = nn.Linear(
|
397 |
+
in_channels, hidden_size, bias=True, **factory_kwargs
|
398 |
+
)
|
399 |
+
if self.need_CA:
|
400 |
+
self.input_embedder_CA = nn.Linear(
|
401 |
+
in_channels, hidden_size, bias=True, **factory_kwargs
|
402 |
+
)
|
403 |
+
|
404 |
+
act_layer = get_activation_layer(act_type)
|
405 |
+
# Build timestep embedding layer
|
406 |
+
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
407 |
+
# Build context embedding layer
|
408 |
+
self.c_embedder = TextProjection(
|
409 |
+
in_channels, hidden_size, act_layer, **factory_kwargs
|
410 |
+
)
|
411 |
+
|
412 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
413 |
+
hidden_size=hidden_size,
|
414 |
+
heads_num=heads_num,
|
415 |
+
depth=depth,
|
416 |
+
mlp_width_ratio=mlp_width_ratio,
|
417 |
+
mlp_drop_rate=mlp_drop_rate,
|
418 |
+
act_type=act_type,
|
419 |
+
qk_norm=qk_norm,
|
420 |
+
qk_norm_type=qk_norm_type,
|
421 |
+
qkv_bias=qkv_bias,
|
422 |
+
need_CA=need_CA,
|
423 |
+
**factory_kwargs,
|
424 |
+
)
|
425 |
+
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
x: torch.Tensor,
|
429 |
+
t: torch.LongTensor,
|
430 |
+
mask: Optional[torch.LongTensor] = None,
|
431 |
+
y: torch.LongTensor=None,
|
432 |
+
):
|
433 |
+
timestep_aware_representations = self.t_embedder(t)
|
434 |
+
|
435 |
+
if mask is None:
|
436 |
+
context_aware_representations = x.mean(dim=1)
|
437 |
+
else:
|
438 |
+
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
439 |
+
context_aware_representations = (x * mask_float).sum(
|
440 |
+
dim=1
|
441 |
+
) / mask_float.sum(dim=1)
|
442 |
+
context_aware_representations = self.c_embedder(context_aware_representations)
|
443 |
+
c = timestep_aware_representations + context_aware_representations
|
444 |
+
|
445 |
+
x = self.input_embedder(x)
|
446 |
+
if self.need_CA:
|
447 |
+
y = self.input_embedder_CA(y)
|
448 |
+
x = self.individual_token_refiner(x, c, mask, y)
|
449 |
+
else:
|
450 |
+
x = self.individual_token_refiner(x, c, mask)
|
451 |
+
|
452 |
+
return x
|
453 |
+
|
454 |
+
|
455 |
+
|
456 |
+
class Qwen2Connector(torch.nn.Module):
|
457 |
+
def __init__(
|
458 |
+
self,
|
459 |
+
# biclip_dim=1024,
|
460 |
+
in_channels=3584,
|
461 |
+
hidden_size=4096,
|
462 |
+
heads_num=32,
|
463 |
+
depth=2,
|
464 |
+
need_CA=False,
|
465 |
+
device=None,
|
466 |
+
dtype=torch.bfloat16,
|
467 |
+
):
|
468 |
+
super().__init__()
|
469 |
+
factory_kwargs = {"device": device, "dtype":dtype}
|
470 |
+
|
471 |
+
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
|
472 |
+
self.global_proj_out=nn.Linear(in_channels,768)
|
473 |
+
|
474 |
+
self.scale_factor = nn.Parameter(torch.zeros(1))
|
475 |
+
with torch.no_grad():
|
476 |
+
self.scale_factor.data += -(1 - 0.09)
|
477 |
+
|
478 |
+
def forward(self, x,t,mask):
|
479 |
+
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
|
480 |
+
x_mean = (x * mask_float).sum(
|
481 |
+
dim=1
|
482 |
+
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
|
483 |
+
|
484 |
+
global_out=self.global_proj_out(x_mean)
|
485 |
+
encoder_hidden_states = self.S(x,t,mask)
|
486 |
+
return encoder_hidden_states,global_out
|
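A minimal, self-contained smoke test for the connector above (my own sketch, not part of the commit). The tiny sizes are illustrative only; the class defaults are in_channels=3584 and hidden_size=4096, and the idea that `x` holds per-token LLM hidden states is an assumption based on the naming. It runs on CPU with the classes defined in this file.

import torch

connector = Qwen2Connector(
    in_channels=64, hidden_size=128, heads_num=4, depth=1,
    device="cpu", dtype=torch.float32,
)
b, s = 2, 16
x = torch.randn(b, s, 64)                    # per-token features to refine
t = torch.zeros(b)                           # one timestep per sample
mask = torch.ones(b, s, dtype=torch.long)    # 1 = real token, 0 = padding

tokens, pooled = connector(x, t, mask)
print(tokens.shape)  # torch.Size([2, 16, 128]) refined per-token embeddings
print(pooled.shape)  # torch.Size([2, 768])     pooled global projection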
modules/layers.py
ADDED
@@ -0,0 +1,639 @@
# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math  # noqa: I001
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from torch import Tensor, nn


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None


MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_varlen_func is not None
        x: torch.Tensor = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )  # type: ignore
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # type: ignore  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, (
                "Causal mask and attn_mask cannot be used together"
            )
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out


def apply_gate(x, gate=None, tanh=False):
    """Apply a per-sample gate to the input tensor.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to apply tanh to the gate. Defaults to False.

    Returns:
        torch.Tensor: the output tensor after applying the gate.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)


class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
            ),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            dim (int): the dimension of the output.
            max_period (int): controls the minimum frequency of the embeddings.

        Returns:
            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(
            t, self.frequency_embedding_size, self.max_period
        ).type(self.mlp[0].weight.dtype)  # type: ignore
        t_emb = self.mlp(t_freq)
        return t_emb


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


def rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def attention_after_rope(q, k, v, pe):
    q, k = apply_rope(q, k, pe)

    from .attention import attention

    x = attention(q, k, v, mode="flash")
    return x


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def apply_rope(xq, xk, freqs_cis):
    # Swap num_heads and seq_len back to the ordering the original function expects
    xq = xq.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
    xk = xk.transpose(1, 2)

    # Split head_dim into complex components (real and imaginary parts)
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)

    # Apply the rotary position embedding (complex multiplication)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]

    # Restore tensor shapes and transpose back to the target dimension order
    xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
    xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)

    return xq_out, xk_out


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def scale_add_residual(
    x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
    return x * scale + residual


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def layernorm_and_scale_shift(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention_after_rope(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    @staticmethod
    def rms_norm_fast(x, weight, eps):
        return LigerRMSNormFunction.apply(
            x,
            weight,
            eps,
            0.0,
            "gemma",
            True,
        )

    @staticmethod
    def rms_norm(x, weight, eps):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        return (x * rrms).to(dtype=x_dtype) * weight

    def forward(self, x: Tensor):
        return self.rms_norm_fast(x, self.scale, 1e-6)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=1)
        k = torch.cat((txt_k, img_k), dim=1)
        v = torch.cat((txt_v, img_v), dim=1)

        attn = attention_after_rope(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp = self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )
        img = scale_add_residual(img_mlp, img_mod2.gate, img)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp = self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention_after_rope(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return scale_add_residual(output, mod.gate, x)


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
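A small shape-oriented sketch of my own (not from the commit) for the rotary-embedding helpers above: EmbedND stacks one rope() table per positional axis, apply_rope rotates q/k with it, and attention() in "torch" mode then runs plain scaled-dot-product attention on [B, L, H, D] inputs. The toy sizes are made up; note that apply_rope is wrapped in torch.compile, so the first call JIT-compiles (it needs a working inductor toolchain).

import torch

B, L, H, D = 1, 8, 2, 16   # toy sizes; D must equal sum(axes_dim), each axis even
ids = torch.arange(L, dtype=torch.float32).view(1, L, 1).expand(B, L, 3)
pe = EmbedND(dim=D, theta=10_000, axes_dim=[4, 6, 6])(ids)
print(pe.shape)            # torch.Size([1, 1, 8, 8, 2, 2]) -> (B, 1, L, D/2, 2, 2)

q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)
q_rot, k_rot = apply_rope(q, k, pe)       # rotated copies, same shapes as q and k
out = attention(q_rot, k_rot, v, mode="torch")
print(out.shape)           # torch.Size([1, 8, 32]) -> (B, L, H*D)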
modules/model_edit.py
ADDED
@@ -0,0 +1,143 @@
import math
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor, nn

from .connector_edit import Qwen2Connector
from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock


@dataclass
class Step1XParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool


class Step1XEdit(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: Step1XParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        self.connector = Qwen2Connector()

    @staticmethod
    def timestep_embedding(
        t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
    ):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        t = time_factor * t
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(t.device)

        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        if torch.is_floating_point(t):
            embedding = embedding.to(t)
        return embedding

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)
        vec = self.time_in(self.timestep_embedding(timesteps, 256))

        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
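An illustrative sketch (mine, not from the commit) of wiring a Step1XParams through the constructor and exercising the CPU-friendly pieces. The sizes below are made up for the example; the only hard constraints are hidden_size % num_heads == 0 and sum(axes_dim) == hidden_size // num_heads. Be aware that __init__ also builds a default-sized Qwen2Connector (on the order of a gigabyte of weights), and a full forward pass goes through the flash-attention path in modules/layers.py, so it needs a CUDA setup with flash-attn installed.

import torch

params = Step1XParams(
    in_channels=16, out_channels=16, vec_in_dim=24, context_in_dim=32,
    hidden_size=128, mlp_ratio=4.0, num_heads=4,
    depth=1, depth_single_blocks=1,
    axes_dim=[8, 12, 12],   # sums to 128 // 4 == 32
    theta=10_000, qkv_bias=True,
)
model = Step1XEdit(params)

# Sinusoidal timestep embedding used to build `vec` inside forward()
emb = Step1XEdit.timestep_embedding(torch.tensor([0.25, 0.75]), dim=256)
print(emb.shape)  # torch.Size([2, 256])

# forward() expects packed latents plus positional ids:
#   img: (B, L_img, in_channels), img_ids / txt_ids: (B, L, 3),
#   txt: (B, L_txt, context_in_dim), timesteps: (B,), y: (B, vec_in_dim)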
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==2.3.1
liger_kernel==0.5.4
einops==0.8.1
transformers==4.49.0
qwen_vl_utils==0.0.10
safetensors==0.4.5
pillow==11.1.0
huggingface_hub
sampling.py
ADDED
@@ -0,0 +1,47 @@
import math
from collections.abc import Callable

import torch
from torch import Tensor


def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int):
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
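A short usage sketch for these helpers (my own; the assumption is that a Flux-style sampler consumes them this way): draw a packed latent with get_noise, derive the shifted timestep schedule from the image sequence length, then step pairwise through adjacent timesteps.

import torch

latent = get_noise(num_samples=1, height=512, width=512,
                   device=torch.device("cpu"), dtype=torch.float32, seed=0)
print(latent.shape)                                        # torch.Size([1, 16, 64, 64])

image_seq_len = latent.shape[-2] * latent.shape[-1] // 4   # tokens after 2x2 packing
timesteps = get_schedule(num_steps=4, image_seq_len=image_seq_len)
print([round(t, 3) for t in timesteps])                    # 5 values from 1.0 down to 0.0,
                                                           # bent toward high t by the shift

# A denoising loop would then walk the schedule in adjacent pairs:
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    pass  # predict the velocity at t_curr and update the latent toward t_prev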