listen2you committed
Commit c41b22c · 1 Parent(s): dacfebb
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ */__pycache__/
Dockerfile ADDED
@@ -0,0 +1,46 @@
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.10 \
+     python3-pip \
+     git \
+     ffmpeg \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+ COPY ./get_flash_attn.py /code/get_flash_attn.py
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ RUN pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+ # ARG values are not shell-expanded at build time, so the matching flash-attn wheel name
+ # (printed by get_flash_attn.py) has to be computed inside the same RUN step that installs it.
+ RUN pip3 install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/$(python3 /code/get_flash_attn.py)
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "app.py"]
README.md CHANGED
@@ -3,8 +3,7 @@ title: Test
 emoji: 🚀
 colorFrom: indigo
 colorTo: pink
- sdk: gradio
- sdk_version: 5.26.0
+ sdk: docker
 app_file: app.py
 pinned: false
 license: mit
app.py ADDED
@@ -0,0 +1,461 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import itertools
5
+ import math
6
+ import os
7
+ import time
8
+ from pathlib import Path
9
+
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ from einops import rearrange, repeat
15
+ from huggingface_hub import snapshot_download
16
+ from PIL import Image, ImageOps
17
+ from safetensors.torch import load_file
18
+ from torchvision.transforms import functional as F
19
+ from tqdm import tqdm
20
+
21
+ import sampling
22
+ from modules.autoencoder import AutoEncoder
23
+ from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
24
+ from modules.model_edit import Step1XParams, Step1XEdit
25
+
26
+ print("TORCH_CUDA", torch.cuda.is_available())
27
+
28
+ def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
29
+ if Path(ckpt_path).suffix == ".safetensors":
30
+ state_dict = load_file(ckpt_path, device)
31
+ else:
32
+ state_dict = torch.load(ckpt_path, map_location="cpu")
33
+
34
+ missing, unexpected = model.load_state_dict(
35
+ state_dict, strict=strict, assign=assign
36
+ )
37
+ if len(missing) > 0 and len(unexpected) > 0:
38
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
39
+ print("\n" + "-" * 79 + "\n")
40
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
41
+ elif len(missing) > 0:
42
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
43
+ elif len(unexpected) > 0:
44
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
45
+ return model
46
+
47
+
48
+ def load_models(
49
+ dit_path=None,
50
+ ae_path=None,
51
+ qwen2vl_model_path=None,
52
+ device="cuda",
53
+ max_length=256,
54
+ dtype=torch.bfloat16,
55
+ ):
56
+ qwen2vl_encoder = Qwen2VLEmbedder(
57
+ qwen2vl_model_path,
58
+ device=device,
59
+ max_length=max_length,
60
+ dtype=dtype,
61
+ )
62
+
63
+ with torch.device("meta"):
64
+ ae = AutoEncoder(
65
+ resolution=256,
66
+ in_channels=3,
67
+ ch=128,
68
+ out_ch=3,
69
+ ch_mult=[1, 2, 4, 4],
70
+ num_res_blocks=2,
71
+ z_channels=16,
72
+ scale_factor=0.3611,
73
+ shift_factor=0.1159,
74
+ )
75
+
76
+ step1x_params = Step1XParams(
77
+ in_channels=64,
78
+ out_channels=64,
79
+ vec_in_dim=768,
80
+ context_in_dim=4096,
81
+ hidden_size=3072,
82
+ mlp_ratio=4.0,
83
+ num_heads=24,
84
+ depth=19,
85
+ depth_single_blocks=38,
86
+ axes_dim=[16, 56, 56],
87
+ theta=10_000,
88
+ qkv_bias=True,
89
+ )
90
+ dit = Step1XEdit(step1x_params)
91
+
92
+ ae = load_state_dict(ae, ae_path)
93
+ dit = load_state_dict(
94
+ dit, dit_path
95
+ )
96
+
97
+ dit = dit.to(device=device, dtype=dtype)
98
+ ae = ae.to(device=device, dtype=torch.float32)
99
+
100
+ return ae, dit, qwen2vl_encoder
101
+
102
+
103
+ class ImageGenerator:
104
+ def __init__(
105
+ self,
106
+ dit_path=None,
107
+ ae_path=None,
108
+ qwen2vl_model_path=None,
109
+ device="cuda",
110
+ max_length=640,
111
+ dtype=torch.bfloat16,
112
+ ) -> None:
113
+ self.device = torch.device(device)
114
+ self.ae, self.dit, self.llm_encoder = load_models(
115
+ dit_path=dit_path,
116
+ ae_path=ae_path,
117
+ qwen2vl_model_path=qwen2vl_model_path,
118
+ max_length=max_length,
119
+ dtype=dtype,
120
+ )
121
+
122
+ def prepare(self, prompt, img, ref_image, ref_image_raw):
123
+ bs, _, h, w = img.shape
124
+ bs, _, ref_h, ref_w = ref_image.shape
125
+
126
+ assert h == ref_h and w == ref_w
127
+
128
+ if bs == 1 and not isinstance(prompt, str):
129
+ bs = len(prompt)
130
+ elif bs >= 1 and isinstance(prompt, str):
131
+ prompt = [prompt] * bs
132
+
133
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
134
+ ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
135
+ if img.shape[0] == 1 and bs > 1:
136
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
137
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
138
+
139
+ img_ids = torch.zeros(h // 2, w // 2, 3)
140
+
141
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
142
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
143
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
144
+
145
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
146
+
147
+ ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
148
+ ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
149
+ ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)
150
+
151
+ if isinstance(prompt, str):
152
+ prompt = [prompt]
153
+
154
+ txt, mask = self.llm_encoder(prompt, ref_image_raw)
155
+
156
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
157
+
158
+ img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
159
+ img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)
160
+
161
+
162
+ return {
163
+ "img": img,
164
+ "mask": mask,
165
+ "img_ids": img_ids.to(img.device),
166
+ "llm_embedding": txt.to(img.device),
167
+ "txt_ids": txt_ids.to(img.device),
168
+ }
169
+
170
+ @staticmethod
171
+ def process_diff_norm(diff_norm, k):
172
+ pow_result = torch.pow(diff_norm, k)
173
+
174
+ result = torch.where(
175
+ diff_norm > 1.0,
176
+ pow_result,
177
+ torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
178
+ )
179
+ return result
180
+
181
+ def denoise(
182
+ self,
183
+ img: torch.Tensor,
184
+ img_ids: torch.Tensor,
185
+ llm_embedding: torch.Tensor,
186
+ txt_ids: torch.Tensor,
187
+ timesteps: list[float],
188
+ cfg_guidance: float = 4.5,
189
+ mask=None,
190
+ show_progress=False,
191
+ timesteps_truncate=1.0,
192
+ ):
193
+ if show_progress:
194
+ pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
195
+ else:
196
+ pbar = itertools.pairwise(timesteps)
197
+ for t_curr, t_prev in pbar:
198
+ if img.shape[0] == 1 and cfg_guidance != -1:
199
+ img = torch.cat([img, img], dim=0)
200
+ t_vec = torch.full(
201
+ (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
202
+ )
203
+
204
+ txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
205
+
206
+
207
+ pred = self.dit(
208
+ img=img,
209
+ img_ids=img_ids,
210
+ txt=txt,
211
+ txt_ids=txt_ids,
212
+ y=vec,
213
+ timesteps=t_vec,
214
+ )
215
+
216
+ if cfg_guidance != -1:
217
+ cond, uncond = (
218
+ pred[0 : pred.shape[0] // 2, :],
219
+ pred[pred.shape[0] // 2 :, :],
220
+ )
221
+ if t_curr > timesteps_truncate:
222
+ diff = cond - uncond
223
+ diff_norm = torch.norm(diff, dim=(2), keepdim=True)
224
+ pred = uncond + cfg_guidance * (
225
+ cond - uncond
226
+ ) / self.process_diff_norm(diff_norm, k=0.4)
227
+ else:
228
+ pred = uncond + cfg_guidance * (cond - uncond)
229
+ tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
230
+ img_input_length = img.shape[1] // 2
231
+ img = torch.cat(
232
+ [
233
+ tem_img[:, :img_input_length],
234
+ img[ : img.shape[0] // 2, img_input_length:],
235
+ ], dim=1
236
+ )
237
+
238
+ return img[:, :img.shape[1] // 2]
239
+
240
+ @staticmethod
241
+ def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
242
+ return rearrange(
243
+ x,
244
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
245
+ h=math.ceil(height / 16),
246
+ w=math.ceil(width / 16),
247
+ ph=2,
248
+ pw=2,
249
+ )
250
+
251
+ @staticmethod
252
+ def load_image(image):
253
+ from PIL import Image
254
+
255
+ if isinstance(image, np.ndarray):
256
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
257
+ image = image.unsqueeze(0)
258
+ return image
259
+ elif isinstance(image, Image.Image):
260
+ image = F.to_tensor(image.convert("RGB"))
261
+ image = image.unsqueeze(0)
262
+ return image
263
+ elif isinstance(image, torch.Tensor):
264
+ return image
265
+ elif isinstance(image, str):
266
+ image = F.to_tensor(Image.open(image).convert("RGB"))
267
+ image = image.unsqueeze(0)
268
+ return image
269
+ else:
270
+ raise ValueError(f"Unsupported image type: {type(image)}")
271
+
272
+ def output_process_image(self, resize_img, image_size):
273
+ res_image = resize_img.resize(image_size)
274
+ return res_image
275
+
276
+ def input_process_image(self, img, img_size=512):
277
+ # 1. Read the input image size and compute its aspect ratio
278
+ w, h = img.size
279
+ r = w / h
280
+
281
+ if w > h:
282
+ w_new = math.ceil(math.sqrt(img_size * img_size * r))
283
+ h_new = math.ceil(w_new / r)
284
+ else:
285
+ h_new = math.ceil(math.sqrt(img_size * img_size / r))
286
+ w_new = math.ceil(h_new * r)
287
+ h_new = math.ceil(h_new) // 16 * 16
288
+ w_new = math.ceil(w_new) // 16 * 16
289
+
290
+ img_resized = img.resize((w_new, h_new))
291
+ return img_resized, img.size
292
+
293
+ @torch.inference_mode()
294
+ def generate_image(
295
+ self,
296
+ prompt,
297
+ negative_prompt,
298
+ ref_images,
299
+ num_steps,
300
+ cfg_guidance,
301
+ seed,
302
+ num_samples=1,
303
+ init_image=None,
304
+ image2image_strength=0.0,
305
+ show_progress=False,
306
+ size_level=512,
307
+ ):
308
+ assert num_samples == 1, "num_samples > 1 is not supported yet."
309
+ ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)
310
+
311
+ width, height = ref_images_raw.width, ref_images_raw.height
312
+
313
+
314
+ ref_images_raw = self.load_image(ref_images_raw)
315
+ ref_images_raw = ref_images_raw.to(self.device)
316
+ ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
317
+
318
+ seed = int(seed)
319
+ seed = torch.Generator(device="cpu").seed() if seed < 0 else seed
320
+
321
+ t0 = time.perf_counter()
322
+
323
+ if init_image is not None:
324
+ init_image = self.load_image(init_image)
325
+ init_image = init_image.to(self.device)
326
+ init_image = torch.nn.functional.interpolate(init_image, (height, width))
327
+ init_image = self.ae.encode(init_image.to() * 2 - 1)
328
+
329
+ x = torch.randn(
330
+ num_samples,
331
+ 16,
332
+ height // 8,
333
+ width // 8,
334
+ device=self.device,
335
+ dtype=torch.bfloat16,
336
+ generator=torch.Generator(device=self.device).manual_seed(seed),
337
+ )
338
+
339
+ timesteps = sampling.get_schedule(
340
+ num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
341
+ )
342
+
343
+ if init_image is not None:
344
+ t_idx = int((1 - image2image_strength) * num_steps)
345
+ t = timesteps[t_idx]
346
+ timesteps = timesteps[t_idx:]
347
+ x = t * x + (1.0 - t) * init_image.to(x.dtype)
348
+
349
+ x = torch.cat([x, x], dim=0)
350
+ ref_images = torch.cat([ref_images, ref_images], dim=0)
351
+ ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
352
+ inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)
353
+
354
+ x = self.denoise(
355
+ **inputs,
356
+ cfg_guidance=cfg_guidance,
357
+ timesteps=timesteps,
358
+ show_progress=show_progress,
359
+ timesteps_truncate=1.0,
360
+ )
361
+ x = self.unpack(x.float(), height, width)
362
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
363
+ x = self.ae.decode(x)
364
+ x = x.clamp(-1, 1)
365
+ x = x.mul(0.5).add(0.5)
366
+
367
+ t1 = time.perf_counter()
368
+ print(f"Done in {t1 - t0:.1f}s.")
369
+ images_list = []
370
+ for img in x.float():
371
+ images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
372
+ return images_list
373
+
374
+
375
+ def prepare_infer_func():
376
+ # Hugging Face model repository ID
377
+ model_repo = "stepfun-ai/Step1X-Edit"
378
+ # Local directory for the downloaded weights
379
+ model_path = "./model_weights"
380
+ os.makedirs(model_path, exist_ok=True)
381
+
382
+
383
+ # Download the model (all files)
384
+ snapshot_download(
385
+ repo_id=model_repo,
386
+ local_dir=model_path,
387
+ local_dir_use_symlinks=False  # avoid symlinks
388
+ )
389
+
390
+
391
+ image_edit = ImageGenerator(
392
+ ae_path=os.path.join(model_path, 'vae.safetensors'),
393
+ dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
394
+ qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
395
+ max_length=640,
396
+ )
397
+
398
+ return image_edit.generate_image
399
+
400
+ def inference(infer_func, prompt, ref_images, seed, size_level):
401
+ start_time = time.time()
402
+
403
+ image = infer_func(
404
+ prompt,
405
+ negative_prompt="",
406
+ ref_images=ref_images,
407
+ num_samples=1,
408
+ num_steps=28,
409
+ cfg_guidance=6.0,
410
+ seed=seed,
411
+ show_progress=True,
412
+ size_level=size_level,
413
+ )[0]
414
+
415
+ print(f"Time taken: {time.time() - start_time:.2f} seconds")
416
+ return image, seed
417
+
418
+
419
+ def create_demo():
420
+ inference_func = prepare_infer_func()
421
+ with gr.Blocks() as demo:
422
+ gr.Markdown(
423
+ """
424
+ # Step1X-Edit
425
+ """
426
+ )
427
+ with gr.Row():
428
+ with gr.Column():
429
+ prompt = gr.Textbox(
430
+ label="Edit Instruction",
431
+ value='Remove the person from the image.',
432
+ )
433
+ init_image = gr.Image(label="Input Image", type='pil')
434
+
435
+ random_seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
436
+
437
+ size_level = gr.Number(label="size level (recommend 512, 768, 1024, min 512)", value=512, minimum=512)
438
+
439
+ generate_btn = gr.Button("Generate")
440
+
441
+ with gr.Column():
442
+ output_image = gr.Image(label="Generated Image",type='pil',image_mode='RGB')
443
+ output_random_seed = gr.Textbox(label="Used Seed", lines=5)
444
+ from functools import partial
445
+ generate_btn.click(
446
+ fn=partial(inference, inference_func),
447
+ inputs=[
448
+ prompt,
449
+ init_image,
450
+ random_seed,
451
+ size_level,
452
+ ],
453
+ outputs=[output_image, output_random_seed],
454
+ )
455
+
456
+ return demo
457
+
458
+
459
+ if __name__ == "__main__":
460
+ demo = create_demo()
461
+ demo.launch()
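Note: prepare() and unpack() in app.py above pack VAE latents into 2x2 patch tokens and back before/after denoising. The following is a minimal, self-contained sketch of that round trip; the shapes are illustrative and not taken from a real run.

import torch
from einops import rearrange

lat = torch.randn(1, 16, 64, 64)  # (b, c, H/8, W/8) latent, as produced by ae.encode()
tokens = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(tokens.shape)  # torch.Size([1, 1024, 64]); 64 = 16 channels * 2 * 2 patch
restored = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=32, w=32, ph=2, pw=2)
assert torch.equal(lat, restored)  # packing is lossless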
get_flash_attn.py ADDED
@@ -0,0 +1,59 @@
+ import platform
+ import sys
+
+ import torch
+
+
+ def get_cuda_version():
+     if torch.cuda.is_available():
+         cuda_version = torch.version.cuda
+         return f"cu{cuda_version.replace('.', '')[:2]}"  # e.g. cu12
+     return "cpu"
+
+
+ def get_torch_version():
+     return f"torch{torch.__version__.split('+')[0]}"[:-2]  # e.g. torch2.3
+
+
+ def get_python_version():
+     version = sys.version_info
+     return f"cp{version.major}{version.minor}"  # e.g. cp310
+
+
+ def get_abi_flag():
+     return "abiTRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "abiFALSE"
+
+
+ def get_platform():
+     system = platform.system().lower()
+     machine = platform.machine().lower()
+     if system == "linux" and machine == "x86_64":
+         return "linux_x86_64"
+     elif system == "windows" and machine == "amd64":
+         return "win_amd64"
+     elif system == "darwin" and machine == "x86_64":
+         return "macosx_x86_64"
+     else:
+         raise ValueError(f"Unsupported platform: {system}_{machine}")
+
+
+ def generate_flash_attn_filename(flash_attn_version="2.7.2.post1"):
+     cuda_version = get_cuda_version()
+     torch_version = get_torch_version()
+     python_version = get_python_version()
+     abi_flag = get_abi_flag()
+     platform_tag = get_platform()
+
+     filename = (
+         f"flash_attn-{flash_attn_version}+{cuda_version}{torch_version}cxx11{abi_flag}-"
+         f"{python_version}-{python_version}-{platform_tag}.whl"
+     )
+     return filename
+
+
+ if __name__ == "__main__":
+     try:
+         filename = generate_flash_attn_filename()
+         print(filename)
+     except Exception as e:
+         print("Error generating filename:", e)
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes). View file
 
modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (3.13 kB). View file
 
modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (8.78 kB). View file
 
modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (4.96 kB). View file
 
modules/__pycache__/connector_edit.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
modules/__pycache__/model_edit.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
modules/attention.py ADDED
@@ -0,0 +1,133 @@
+ import math
+
+ import torch
+ import torch.nn.functional as F
+
+
+ try:
+     import flash_attn
+     from flash_attn.flash_attn_interface import (
+         _flash_attn_forward,
+         flash_attn_func,
+         flash_attn_varlen_func,
+     )
+ except ImportError:
+     flash_attn = None
+     flash_attn_varlen_func = None
+     _flash_attn_forward = None
+     flash_attn_func = None
+
+ MEMORY_LAYOUT = {
+     # "flash" mode:
+     #   pre-process:  input stays [batch_size, seq_len, num_heads, head_dim]
+     #   post-process: shape unchanged
+     "flash": (
+         lambda x: x,  # keep shape
+         lambda x: x,  # keep shape
+     ),
+     # "torch"/"vanilla" modes:
+     #   pre-process:  swap sequence and head dims [B,S,A,D] -> [B,A,S,D]
+     #   post-process: swap back [B,A,S,D] -> [B,S,A,D]
+     "torch": (
+         lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
+         lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
+     ),
+     "vanilla": (
+         lambda x: x.transpose(1, 2),
+         lambda x: x.transpose(1, 2),
+     ),
+ }
+
+
+ def attention(
+     q,
+     k,
+     v,
+     mode="flash",
+     drop_rate=0,
+     attn_mask=None,
+     causal=False,
+ ):
+     """
+     Compute QKV self-attention.
+
+     Args:
+         q (torch.Tensor): query tensor, shape [batch_size, seq_len, num_heads, head_dim]
+         k (torch.Tensor): key tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
+         v (torch.Tensor): value tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
+         mode (str): attention mode, one of 'flash', 'torch', 'vanilla'
+         drop_rate (float): dropout probability applied to the attention matrix
+         attn_mask (torch.Tensor): attention mask; its expected shape depends on the mode
+         causal (bool): whether to use causal attention (attend only to earlier positions)
+
+     Returns:
+         torch.Tensor: attention output, shape [batch_size, seq_len, num_heads * head_dim]
+     """
+     # fetch the pre- and post-processing layout functions
+     pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+
+     # apply the pre-processing transform
+     q = pre_attn_layout(q)  # resulting shape depends on the mode
+     k = pre_attn_layout(k)
+     v = pre_attn_layout(v)
+
+     if mode == "torch":
+         # use PyTorch's native scaled_dot_product_attention
+         if attn_mask is not None and attn_mask.dtype != torch.bool:
+             attn_mask = attn_mask.to(q.dtype)
+         x = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
+         )
+     elif mode == "flash":
+         assert flash_attn_func is not None, "flash_attn_func is not available"
+         assert attn_mask is None, "attn_mask is not supported in flash mode"
+         x: torch.Tensor = flash_attn_func(
+             q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
+         )  # type: ignore
+     elif mode == "vanilla":
+         # manual attention implementation
+         scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)
+
+         b, a, s, _ = q.shape  # shape parameters
+         s1 = k.size(2)  # key/value sequence length
+
+         # initialize the attention bias
+         attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+
+         # handle the causal mask
+         if causal:
+             assert attn_mask is None, "causal masking cannot be combined with attn_mask"
+             # build a lower-triangular causal mask
+             temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
+                 diagonal=0
+             )
+             attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+             attn_bias = attn_bias.to(q.dtype)
+
+         # handle a user-provided attention mask
+         if attn_mask is not None:
+             if attn_mask.dtype == torch.bool:
+                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+             else:
+                 attn_bias += attn_mask  # allows ALiBi-style positional biases
+
+         # compute the attention matrix
+         attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
+         attn += attn_bias
+
+         # softmax and dropout
+         attn = attn.softmax(dim=-1)
+         attn = torch.dropout(attn, p=drop_rate, train=True)
+
+         # compute the output
+         x = attn @ v  # [B,A,S,D]
+     else:
+         raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+     # apply the post-processing transform
+     x = post_attn_layout(x)  # restore the original dim order
+
+     # merge the head dimension
+     b, s, a, d = x.shape
+     out = x.reshape(b, s, -1)  # [B,S,A*D]
+     return out
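Note: a small usage sketch of the attention() helper above in "torch" mode; shapes follow the docstring and the tensors are random, purely illustrative.

import torch
from modules.attention import attention

# (batch, seq_len, num_heads, head_dim) in, (batch, seq_len, num_heads * head_dim) out
q = k = v = torch.randn(2, 128, 24, 128)
out = attention(q, k, v, mode="torch")
print(out.shape)  # torch.Size([2, 128, 3072])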
modules/autoencoder.py ADDED
@@ -0,0 +1,326 @@
1
+ # Modified from Flux
2
+ #
3
+ # Copyright 2024 Black Forest Labs
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # This source code is licensed under the license found in the
18
+ # LICENSE file in the root directory of this source tree.
19
+ import torch
20
+ from einops import rearrange
21
+ from torch import Tensor, nn
22
+
23
+
24
+ def swish(x: Tensor) -> Tensor:
25
+ return x * torch.sigmoid(x)
26
+
27
+
28
+ class AttnBlock(nn.Module):
29
+ def __init__(self, in_channels: int):
30
+ super().__init__()
31
+ self.in_channels = in_channels
32
+
33
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
+
35
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
37
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
38
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39
+
40
+ def attention(self, h_: Tensor) -> Tensor:
41
+ h_ = self.norm(h_)
42
+ q = self.q(h_)
43
+ k = self.k(h_)
44
+ v = self.v(h_)
45
+
46
+ b, c, h, w = q.shape
47
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
48
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
49
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
50
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
51
+
52
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
53
+
54
+ def forward(self, x: Tensor) -> Tensor:
55
+ return x + self.proj_out(self.attention(x))
56
+
57
+
58
+ class ResnetBlock(nn.Module):
59
+ def __init__(self, in_channels: int, out_channels: int):
60
+ super().__init__()
61
+ self.in_channels = in_channels
62
+ out_channels = in_channels if out_channels is None else out_channels
63
+ self.out_channels = out_channels
64
+
65
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
66
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
67
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
68
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
69
+ if self.in_channels != self.out_channels:
70
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
71
+
72
+ def forward(self, x):
73
+ h = x
74
+ h = self.norm1(h)
75
+ h = swish(h)
76
+ h = self.conv1(h)
77
+
78
+ h = self.norm2(h)
79
+ h = swish(h)
80
+ h = self.conv2(h)
81
+
82
+ if self.in_channels != self.out_channels:
83
+ x = self.nin_shortcut(x)
84
+
85
+ return x + h
86
+
87
+
88
+ class Downsample(nn.Module):
89
+ def __init__(self, in_channels: int):
90
+ super().__init__()
91
+ # no asymmetric padding in torch conv, must do it ourselves
92
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
93
+
94
+ def forward(self, x: Tensor):
95
+ pad = (0, 1, 0, 1)
96
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
97
+ x = self.conv(x)
98
+ return x
99
+
100
+
101
+ class Upsample(nn.Module):
102
+ def __init__(self, in_channels: int):
103
+ super().__init__()
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
105
+
106
+ def forward(self, x: Tensor):
107
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
108
+ x = self.conv(x)
109
+ return x
110
+
111
+
112
+ class Encoder(nn.Module):
113
+ def __init__(
114
+ self,
115
+ resolution: int,
116
+ in_channels: int,
117
+ ch: int,
118
+ ch_mult: list[int],
119
+ num_res_blocks: int,
120
+ z_channels: int,
121
+ ):
122
+ super().__init__()
123
+ self.ch = ch
124
+ self.num_resolutions = len(ch_mult)
125
+ self.num_res_blocks = num_res_blocks
126
+ self.resolution = resolution
127
+ self.in_channels = in_channels
128
+ # downsampling
129
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
130
+
131
+ curr_res = resolution
132
+ in_ch_mult = (1, *tuple(ch_mult))
133
+ self.in_ch_mult = in_ch_mult
134
+ self.down = nn.ModuleList()
135
+ block_in = self.ch
136
+ for i_level in range(self.num_resolutions):
137
+ block = nn.ModuleList()
138
+ attn = nn.ModuleList()
139
+ block_in = ch * in_ch_mult[i_level]
140
+ block_out = ch * ch_mult[i_level]
141
+ for _ in range(self.num_res_blocks):
142
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
143
+ block_in = block_out
144
+ down = nn.Module()
145
+ down.block = block
146
+ down.attn = attn
147
+ if i_level != self.num_resolutions - 1:
148
+ down.downsample = Downsample(block_in)
149
+ curr_res = curr_res // 2
150
+ self.down.append(down)
151
+
152
+ # middle
153
+ self.mid = nn.Module()
154
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
155
+ self.mid.attn_1 = AttnBlock(block_in)
156
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
157
+
158
+ # end
159
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
160
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
161
+
162
+ def forward(self, x: Tensor) -> Tensor:
163
+ # downsampling
164
+ hs = [self.conv_in(x)]
165
+ for i_level in range(self.num_resolutions):
166
+ for i_block in range(self.num_res_blocks):
167
+ h = self.down[i_level].block[i_block](hs[-1])
168
+ if len(self.down[i_level].attn) > 0:
169
+ h = self.down[i_level].attn[i_block](h)
170
+ hs.append(h)
171
+ if i_level != self.num_resolutions - 1:
172
+ hs.append(self.down[i_level].downsample(hs[-1]))
173
+
174
+ # middle
175
+ h = hs[-1]
176
+ h = self.mid.block_1(h)
177
+ h = self.mid.attn_1(h)
178
+ h = self.mid.block_2(h)
179
+ # end
180
+ h = self.norm_out(h)
181
+ h = swish(h)
182
+ h = self.conv_out(h)
183
+ return h
184
+
185
+
186
+ class Decoder(nn.Module):
187
+ def __init__(
188
+ self,
189
+ ch: int,
190
+ out_ch: int,
191
+ ch_mult: list[int],
192
+ num_res_blocks: int,
193
+ in_channels: int,
194
+ resolution: int,
195
+ z_channels: int,
196
+ ):
197
+ super().__init__()
198
+ self.ch = ch
199
+ self.num_resolutions = len(ch_mult)
200
+ self.num_res_blocks = num_res_blocks
201
+ self.resolution = resolution
202
+ self.in_channels = in_channels
203
+ self.ffactor = 2 ** (self.num_resolutions - 1)
204
+
205
+ # compute in_ch_mult, block_in and curr_res at lowest res
206
+ block_in = ch * ch_mult[self.num_resolutions - 1]
207
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
208
+ self.z_shape = (1, z_channels, curr_res, curr_res)
209
+
210
+ # z to block_in
211
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
212
+
213
+ # middle
214
+ self.mid = nn.Module()
215
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
216
+ self.mid.attn_1 = AttnBlock(block_in)
217
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
218
+
219
+ # upsampling
220
+ self.up = nn.ModuleList()
221
+ for i_level in reversed(range(self.num_resolutions)):
222
+ block = nn.ModuleList()
223
+ attn = nn.ModuleList()
224
+ block_out = ch * ch_mult[i_level]
225
+ for _ in range(self.num_res_blocks + 1):
226
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
227
+ block_in = block_out
228
+ up = nn.Module()
229
+ up.block = block
230
+ up.attn = attn
231
+ if i_level != 0:
232
+ up.upsample = Upsample(block_in)
233
+ curr_res = curr_res * 2
234
+ self.up.insert(0, up) # prepend to get consistent order
235
+
236
+ # end
237
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
238
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
239
+
240
+ def forward(self, z: Tensor) -> Tensor:
241
+ # z to block_in
242
+ h = self.conv_in(z)
243
+
244
+ # middle
245
+ h = self.mid.block_1(h)
246
+ h = self.mid.attn_1(h)
247
+ h = self.mid.block_2(h)
248
+
249
+ # upsampling
250
+ for i_level in reversed(range(self.num_resolutions)):
251
+ for i_block in range(self.num_res_blocks + 1):
252
+ h = self.up[i_level].block[i_block](h)
253
+ if len(self.up[i_level].attn) > 0:
254
+ h = self.up[i_level].attn[i_block](h)
255
+ if i_level != 0:
256
+ h = self.up[i_level].upsample(h)
257
+
258
+ # end
259
+ h = self.norm_out(h)
260
+ h = swish(h)
261
+ h = self.conv_out(h)
262
+ return h
263
+
264
+
265
+ class DiagonalGaussian(nn.Module):
266
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
267
+ super().__init__()
268
+ self.sample = sample
269
+ self.chunk_dim = chunk_dim
270
+
271
+ def forward(self, z: Tensor) -> Tensor:
272
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
273
+ if self.sample:
274
+ std = torch.exp(0.5 * logvar)
275
+ return mean + std * torch.randn_like(mean)
276
+ else:
277
+ return mean
278
+
279
+
280
+ class AutoEncoder(nn.Module):
281
+ def __init__(
282
+ self,
283
+ resolution: int,
284
+ in_channels: int,
285
+ ch: int,
286
+ out_ch: int,
287
+ ch_mult: list[int],
288
+ num_res_blocks: int,
289
+ z_channels: int,
290
+ scale_factor: float,
291
+ shift_factor: float,
292
+ ):
293
+ super().__init__()
294
+ self.encoder = Encoder(
295
+ resolution=resolution,
296
+ in_channels=in_channels,
297
+ ch=ch,
298
+ ch_mult=ch_mult,
299
+ num_res_blocks=num_res_blocks,
300
+ z_channels=z_channels,
301
+ )
302
+ self.decoder = Decoder(
303
+ resolution=resolution,
304
+ in_channels=in_channels,
305
+ ch=ch,
306
+ out_ch=out_ch,
307
+ ch_mult=ch_mult,
308
+ num_res_blocks=num_res_blocks,
309
+ z_channels=z_channels,
310
+ )
311
+ self.reg = DiagonalGaussian()
312
+
313
+ self.scale_factor = scale_factor
314
+ self.shift_factor = shift_factor
315
+
316
+ def encode(self, x: Tensor) -> Tensor:
317
+ z = self.reg(self.encoder(x))
318
+ z = self.scale_factor * (z - self.shift_factor)
319
+ return z
320
+
321
+ def decode(self, z: Tensor) -> Tensor:
322
+ z = z / self.scale_factor + self.shift_factor
323
+ return self.decoder(z)
324
+
325
+ def forward(self, x: Tensor) -> Tensor:
326
+ return self.decode(self.encode(x))
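Note: a minimal sketch of how app.py instantiates this AutoEncoder and of the shapes involved. The weights are random here (vae.safetensors is normally loaded on top), so the outputs are meaningless but the shapes hold.

import torch
from modules.autoencoder import AutoEncoder

ae = AutoEncoder(resolution=256, in_channels=3, ch=128, out_ch=3,
                 ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
                 scale_factor=0.3611, shift_factor=0.1159)
img = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs are mapped to [-1, 1] before encoding
z = ae.encode(img)                          # (1, 16, 32, 32): 8x spatial downsampling
rec = ae.decode(z)                          # (1, 3, 256, 256)
print(z.shape, rec.shape)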
modules/conditioner.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ from qwen_vl_utils import process_vision_info
3
+ from transformers import (
4
+ AutoProcessor,
5
+ Qwen2VLForConditionalGeneration,
6
+ Qwen2_5_VLForConditionalGeneration,
7
+ )
8
+ from torchvision.transforms import ToPILImage
9
+
10
+ to_pil = ToPILImage()
11
+
12
+ Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
13
+ - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
14
+ - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
15
+ Here are examples of how to transform or refine prompts:
16
+ - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
17
+ - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
18
+ Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
19
+ User Prompt:'''
20
+
21
+
22
+ def split_string(s):
+     # replace Chinese quotation marks with English quotes
+     s = s.replace("“", '"').replace("”", '"')  # use english quotes
+     result = []
+     # track whether we are currently inside quotes
+     in_quotes = False
+     temp = ""
+
+     # iterate over every character and its index
+     for idx, char in enumerate(s):
+         # if the character is a quote and its index is past 155
+         if char == '"' and idx > 155:
+             # append the quote to the temporary string
+             temp += char
+             # if we are not inside quotes
+             if not in_quotes:
+                 # push the temporary string onto the result list
+                 result.append(temp)
+                 # and reset it
+                 temp = ""
+
+             # toggle the quote state
+             in_quotes = not in_quotes
+             continue
+         # inside quotes
+         if in_quotes:
+             # spaces are kept as their own wrapped tokens as well
+             if char.isspace():
+                 pass  # have space token
+
+             # wrap the character in Chinese quotes and append it to the result
+             result.append("“" + char + "”")
+         else:
+             # otherwise accumulate the character in the temporary string
+             temp += char
+
+     # flush any remaining temporary string
+     if temp:
+         result.append(temp)
+
+     return result
64
+
65
+
66
+ class Qwen25VL_7b_Embedder(torch.nn.Module):
67
+ def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
68
+ super(Qwen25VL_7b_Embedder, self).__init__()
69
+ self.max_length = max_length
70
+ self.dtype = dtype
71
+ self.device = device
72
+
73
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
74
+ model_path,
75
+ torch_dtype=dtype,
76
+ attn_implementation="flash_attention_2",
77
+ ).to(torch.cuda.current_device())
78
+
79
+ self.model.requires_grad_(False)
80
+ self.processor = AutoProcessor.from_pretrained(
81
+ model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
82
+ )
83
+
84
+ self.prefix = Qwen25VL_7b_PREFIX
85
+
86
+ def forward(self, caption, ref_images):
87
+ text_list = caption
88
+ embs = torch.zeros(
89
+ len(text_list),
90
+ self.max_length,
91
+ self.model.config.hidden_size,
92
+ dtype=torch.bfloat16,
93
+ device=torch.cuda.current_device(),
94
+ )
95
+ hidden_states = torch.zeros(
96
+ len(text_list),
97
+ self.max_length,
98
+ self.model.config.hidden_size,
99
+ dtype=torch.bfloat16,
100
+ device=torch.cuda.current_device(),
101
+ )
102
+ masks = torch.zeros(
103
+ len(text_list),
104
+ self.max_length,
105
+ dtype=torch.long,
106
+ device=torch.cuda.current_device(),
107
+ )
108
+ input_ids_list = []
109
+ attention_mask_list = []
110
+ emb_list = []
111
+
112
+ def split_string(s):
113
+ s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
114
+ result = []
115
+ in_quotes = False
116
+ temp = ""
117
+
118
+ for idx,char in enumerate(s):
119
+ if char == '"' and idx>155:
120
+ temp += char
121
+ if not in_quotes:
122
+ result.append(temp)
123
+ temp = ""
124
+
125
+ in_quotes = not in_quotes
126
+ continue
127
+ if in_quotes:
128
+ if char.isspace():
129
+ pass # have space token
130
+
131
+ result.append("“" + char + "”")
132
+ else:
133
+ temp += char
134
+
135
+ if temp:
136
+ result.append(temp)
137
+
138
+ return result
139
+
140
+ for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
141
+
142
+ messages = [{"role": "user", "content": []}]
143
+
144
+ messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
145
+
146
+ messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
147
+
148
+ # then append the text prompt
149
+ messages[0]["content"].append({"type": "text", "text": f"{txt}"})
150
+
151
+ # Preparation for inference
152
+ text = self.processor.apply_chat_template(
153
+ messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
154
+ )
155
+
156
+ image_inputs, video_inputs = process_vision_info(messages)
157
+
158
+ inputs = self.processor(
159
+ text=[text],
160
+ images=image_inputs,
161
+ padding=True,
162
+ return_tensors="pt",
163
+ )
164
+
165
+ old_inputs_ids = inputs.input_ids
166
+ text_split_list = split_string(text)
167
+
168
+ token_list = []
169
+ for text_each in text_split_list:
170
+ txt_inputs = self.processor(
171
+ text=text_each,
172
+ images=None,
173
+ videos=None,
174
+ padding=True,
175
+ return_tensors="pt",
176
+ )
177
+ token_each = txt_inputs.input_ids
178
+ if token_each[0][0] == 2073 and token_each[0][-1] == 854:
179
+ token_each = token_each[:, 1:-1]
180
+ token_list.append(token_each)
181
+ else:
182
+ token_list.append(token_each)
183
+
184
+ new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
185
+
186
+ new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
187
+
188
+ idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
189
+ idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
190
+ inputs.input_ids = (
191
+ torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
192
+ .unsqueeze(0)
193
+ .to("cuda")
194
+ )
195
+ inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
196
+ outputs = self.model(
197
+ input_ids=inputs.input_ids,
198
+ attention_mask=inputs.attention_mask,
199
+ pixel_values=inputs.pixel_values.to("cuda"),
200
+ image_grid_thw=inputs.image_grid_thw.to("cuda"),
201
+ output_hidden_states=True,
202
+ )
203
+
204
+ emb = outputs["hidden_states"][-1]
205
+
206
+ embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
207
+ : self.max_length
208
+ ]
209
+
210
+ masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
211
+ (min(self.max_length, emb.shape[1] - 217)),
212
+ dtype=torch.long,
213
+ device=torch.cuda.current_device(),
214
+ )
215
+
216
+ return embs, masks
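Note: a hypothetical illustration of split_string() above, assuming the Space's requirements are installed. Beyond character index 155 (i.e. past the fixed instruction prefix), double-quoted spans are broken into per-character “x” tokens so that quoted words are later tokenized character by character. The prefix below is a stand-in; the real prompt prefix is longer than 155 characters.

from modules.conditioner import split_string

prefix = " " * 156  # stand-in for the real instruction prefix
parts = split_string(prefix + 'replace "cat" with a dog')
print(parts[-4:])  # ['“c”', '“a”', '“t”', '" with a dog']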
modules/connector_edit.py ADDED
@@ -0,0 +1,486 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
9
+
10
+
11
+ class RMSNorm(nn.Module):
12
+ def __init__(
13
+ self,
14
+ dim: int,
15
+ elementwise_affine=True,
16
+ eps: float = 1e-6,
17
+ device=None,
18
+ dtype=None,
19
+ ):
20
+ """
21
+ Initialize the RMSNorm normalization layer.
22
+
23
+ Args:
24
+ dim (int): The dimension of the input tensor.
25
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
26
+
27
+ Attributes:
28
+ eps (float): A small value added to the denominator for numerical stability.
29
+ weight (nn.Parameter): Learnable scaling parameter.
30
+
31
+ """
32
+ factory_kwargs = {"device": device, "dtype": dtype}
33
+ super().__init__()
34
+ self.eps = eps
35
+ if elementwise_affine:
36
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
37
+
38
+ def _norm(self, x):
39
+ """
40
+ Apply the RMSNorm normalization to the input tensor.
41
+
42
+ Args:
43
+ x (torch.Tensor): The input tensor.
44
+
45
+ Returns:
46
+ torch.Tensor: The normalized tensor.
47
+
48
+ """
49
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass through the RMSNorm layer.
54
+
55
+ Args:
56
+ x (torch.Tensor): The input tensor.
57
+
58
+ Returns:
59
+ torch.Tensor: The output tensor after applying RMSNorm.
60
+
61
+ """
62
+ output = self._norm(x.float()).type_as(x)
63
+ if hasattr(self, "weight"):
64
+ output = output * self.weight
65
+ return output
66
+
67
+
68
+ def get_norm_layer(norm_layer):
69
+ """
70
+ Get the normalization layer.
71
+
72
+ Args:
73
+ norm_layer (str): The type of normalization layer.
74
+
75
+ Returns:
76
+ norm_layer (nn.Module): The normalization layer.
77
+ """
78
+ if norm_layer == "layer":
79
+ return nn.LayerNorm
80
+ elif norm_layer == "rms":
81
+ return RMSNorm
82
+ else:
83
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
84
+
85
+
86
+ def get_activation_layer(act_type):
87
+ """get activation layer
88
+
89
+ Args:
90
+ act_type (str): the activation type
91
+
92
+ Returns:
93
+ torch.nn.functional: the activation layer
94
+ """
95
+ if act_type == "gelu":
96
+ return lambda: nn.GELU()
97
+ elif act_type == "gelu_tanh":
98
+ return lambda: nn.GELU(approximate="tanh")
99
+ elif act_type == "relu":
100
+ return nn.ReLU
101
+ elif act_type == "silu":
102
+ return nn.SiLU
103
+ else:
104
+ raise ValueError(f"Unknown activation type: {act_type}")
105
+
106
+ class IndividualTokenRefinerBlock(torch.nn.Module):
107
+ def __init__(
108
+ self,
109
+ hidden_size,
110
+ heads_num,
111
+ mlp_width_ratio: str = 4.0,
112
+ mlp_drop_rate: float = 0.0,
113
+ act_type: str = "silu",
114
+ qk_norm: bool = False,
115
+ qk_norm_type: str = "layer",
116
+ qkv_bias: bool = True,
117
+ need_CA: bool = False,
118
+ dtype: Optional[torch.dtype] = None,
119
+ device: Optional[torch.device] = None,
120
+ ):
121
+ factory_kwargs = {"device": device, "dtype": dtype}
122
+ super().__init__()
123
+ self.need_CA = need_CA
124
+ self.heads_num = heads_num
125
+ head_dim = hidden_size // heads_num
126
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
127
+
128
+ self.norm1 = nn.LayerNorm(
129
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
130
+ )
131
+ self.self_attn_qkv = nn.Linear(
132
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
133
+ )
134
+ qk_norm_layer = get_norm_layer(qk_norm_type)
135
+ self.self_attn_q_norm = (
136
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
137
+ if qk_norm
138
+ else nn.Identity()
139
+ )
140
+ self.self_attn_k_norm = (
141
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
142
+ if qk_norm
143
+ else nn.Identity()
144
+ )
145
+ self.self_attn_proj = nn.Linear(
146
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
147
+ )
148
+
149
+ self.norm2 = nn.LayerNorm(
150
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
151
+ )
152
+ act_layer = get_activation_layer(act_type)
153
+ self.mlp = MLP(
154
+ in_channels=hidden_size,
155
+ hidden_channels=mlp_hidden_dim,
156
+ act_layer=act_layer,
157
+ drop=mlp_drop_rate,
158
+ **factory_kwargs,
159
+ )
160
+
161
+ self.adaLN_modulation = nn.Sequential(
162
+ act_layer(),
163
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
164
+ )
165
+
166
+ if self.need_CA:
167
+ self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
168
+ heads_num=heads_num,
169
+ mlp_width_ratio=mlp_width_ratio,
170
+ mlp_drop_rate=mlp_drop_rate,
171
+ act_type=act_type,
172
+ qk_norm=qk_norm,
173
+ qk_norm_type=qk_norm_type,
174
+ qkv_bias=qkv_bias,
175
+ **factory_kwargs,)
176
+ # Zero-initialize the modulation
177
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
178
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
179
+
180
+ def forward(
181
+ self,
182
+ x: torch.Tensor,
183
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
184
+ attn_mask: torch.Tensor = None,
185
+ y: torch.Tensor = None,
186
+ ):
187
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
188
+
189
+ norm_x = self.norm1(x)
190
+ qkv = self.self_attn_qkv(norm_x)
191
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
192
+ # Apply QK-Norm if needed
193
+ q = self.self_attn_q_norm(q).to(v)
194
+ k = self.self_attn_k_norm(k).to(v)
195
+
196
+ # Self-Attention
197
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
198
+
199
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
200
+
201
+ if self.need_CA:
202
+ x = self.cross_attnblock(x, c, attn_mask, y)
203
+
204
+ # FFN Layer
205
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
206
+
207
+ return x
208
+
209
+
210
+
211
+
212
+ class CrossAttnBlock(torch.nn.Module):
213
+ def __init__(
214
+ self,
215
+ hidden_size,
216
+ heads_num,
217
+ mlp_width_ratio: str = 4.0,
218
+ mlp_drop_rate: float = 0.0,
219
+ act_type: str = "silu",
220
+ qk_norm: bool = False,
221
+ qk_norm_type: str = "layer",
222
+ qkv_bias: bool = True,
223
+ dtype: Optional[torch.dtype] = None,
224
+ device: Optional[torch.device] = None,
225
+ ):
226
+ factory_kwargs = {"device": device, "dtype": dtype}
227
+ super().__init__()
228
+ self.heads_num = heads_num
229
+ head_dim = hidden_size // heads_num
230
+
231
+ self.norm1 = nn.LayerNorm(
232
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
233
+ )
234
+ self.norm1_2 = nn.LayerNorm(
235
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
236
+ )
237
+ self.self_attn_q = nn.Linear(
238
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
239
+ )
240
+ self.self_attn_kv = nn.Linear(
241
+ hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
242
+ )
243
+ qk_norm_layer = get_norm_layer(qk_norm_type)
244
+ self.self_attn_q_norm = (
245
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
246
+ if qk_norm
247
+ else nn.Identity()
248
+ )
249
+ self.self_attn_k_norm = (
250
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
251
+ if qk_norm
252
+ else nn.Identity()
253
+ )
254
+ self.self_attn_proj = nn.Linear(
255
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
256
+ )
257
+
258
+ self.norm2 = nn.LayerNorm(
259
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
260
+ )
261
+ act_layer = get_activation_layer(act_type)
262
+
263
+ self.adaLN_modulation = nn.Sequential(
264
+ act_layer(),
265
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
266
+ )
267
+ # Zero-initialize the modulation
268
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
269
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
270
+
271
+ def forward(
272
+ self,
273
+ x: torch.Tensor,
274
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
275
+ attn_mask: torch.Tensor = None,
276
+ y: torch.Tensor=None,
277
+
278
+ ):
279
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
280
+
281
+ norm_x = self.norm1(x)
282
+ norm_y = self.norm1_2(y)
283
+ q = self.self_attn_q(norm_x)
284
+ q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
285
+ kv = self.self_attn_kv(norm_y)
286
+ k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
287
+ # Apply QK-Norm if needed
288
+ q = self.self_attn_q_norm(q).to(v)
289
+ k = self.self_attn_k_norm(k).to(v)
290
+
291
+ # Self-Attention
292
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
293
+
294
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
295
+
296
+ return x
297
+
298
+
299
+
300
+ class IndividualTokenRefiner(torch.nn.Module):
301
+ def __init__(
302
+ self,
303
+ hidden_size,
304
+ heads_num,
305
+ depth,
306
+ mlp_width_ratio: float = 4.0,
307
+ mlp_drop_rate: float = 0.0,
308
+ act_type: str = "silu",
309
+ qk_norm: bool = False,
310
+ qk_norm_type: str = "layer",
311
+ qkv_bias: bool = True,
312
+ need_CA:bool=False,
313
+ dtype: Optional[torch.dtype] = None,
314
+ device: Optional[torch.device] = None,
315
+ ):
316
+
317
+ factory_kwargs = {"device": device, "dtype": dtype}
318
+ super().__init__()
319
+ self.need_CA = need_CA
320
+ self.blocks = nn.ModuleList(
321
+ [
322
+ IndividualTokenRefinerBlock(
323
+ hidden_size=hidden_size,
324
+ heads_num=heads_num,
325
+ mlp_width_ratio=mlp_width_ratio,
326
+ mlp_drop_rate=mlp_drop_rate,
327
+ act_type=act_type,
328
+ qk_norm=qk_norm,
329
+ qk_norm_type=qk_norm_type,
330
+ qkv_bias=qkv_bias,
331
+ need_CA=self.need_CA,
332
+ **factory_kwargs,
333
+ )
334
+ for _ in range(depth)
335
+ ]
336
+ )
337
+
338
+
339
+ def forward(
340
+ self,
341
+ x: torch.Tensor,
342
+ c: torch.LongTensor,
343
+ mask: Optional[torch.Tensor] = None,
344
+ y:torch.Tensor=None,
345
+ ):
346
+ self_attn_mask = None
347
+ if mask is not None:
348
+ batch_size = mask.shape[0]
349
+ seq_len = mask.shape[1]
350
+ mask = mask.to(x.device)
351
+ # batch_size x 1 x seq_len x seq_len
352
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
353
+ 1, 1, seq_len, 1
354
+ )
355
+ # batch_size x 1 x seq_len x seq_len
356
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
357
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
358
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
359
+ # avoids self-attention weight being NaN for padding tokens
360
+ self_attn_mask[:, :, :, 0] = True
361
+
362
+
363
+ for block in self.blocks:
364
+ x = block(x, c, self_attn_mask,y)
365
+
366
+ return x
367
+
368
+
369
+ class SingleTokenRefiner(torch.nn.Module):
370
+ """
371
+ A single token refiner block for llm text embedding refine.
372
+ """
373
+ def __init__(
374
+ self,
375
+ in_channels,
376
+ hidden_size,
377
+ heads_num,
378
+ depth,
379
+ mlp_width_ratio: float = 4.0,
380
+ mlp_drop_rate: float = 0.0,
381
+ act_type: str = "silu",
382
+ qk_norm: bool = False,
383
+ qk_norm_type: str = "layer",
384
+ qkv_bias: bool = True,
385
+ need_CA:bool=False,
386
+ attn_mode: str = "torch",
387
+ dtype: Optional[torch.dtype] = None,
388
+ device: Optional[torch.device] = None,
389
+ ):
390
+ factory_kwargs = {"device": device, "dtype": dtype}
391
+ super().__init__()
392
+ self.attn_mode = attn_mode
393
+ self.need_CA = need_CA
394
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
395
+
396
+ self.input_embedder = nn.Linear(
397
+ in_channels, hidden_size, bias=True, **factory_kwargs
398
+ )
399
+ if self.need_CA:
400
+ self.input_embedder_CA = nn.Linear(
401
+ in_channels, hidden_size, bias=True, **factory_kwargs
402
+ )
403
+
404
+ act_layer = get_activation_layer(act_type)
405
+ # Build timestep embedding layer
406
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
407
+ # Build context embedding layer
408
+ self.c_embedder = TextProjection(
409
+ in_channels, hidden_size, act_layer, **factory_kwargs
410
+ )
411
+
412
+ self.individual_token_refiner = IndividualTokenRefiner(
413
+ hidden_size=hidden_size,
414
+ heads_num=heads_num,
415
+ depth=depth,
416
+ mlp_width_ratio=mlp_width_ratio,
417
+ mlp_drop_rate=mlp_drop_rate,
418
+ act_type=act_type,
419
+ qk_norm=qk_norm,
420
+ qk_norm_type=qk_norm_type,
421
+ qkv_bias=qkv_bias,
422
+ need_CA=need_CA,
423
+ **factory_kwargs,
424
+ )
425
+
426
+ def forward(
427
+ self,
428
+ x: torch.Tensor,
429
+ t: torch.LongTensor,
430
+ mask: Optional[torch.LongTensor] = None,
431
+ y: torch.LongTensor=None,
432
+ ):
433
+ timestep_aware_representations = self.t_embedder(t)
434
+
435
+ if mask is None:
436
+ context_aware_representations = x.mean(dim=1)
437
+ else:
438
+ mask_float = mask.unsqueeze(-1) # [b, s1, 1]
439
+ context_aware_representations = (x * mask_float).sum(
440
+ dim=1
441
+ ) / mask_float.sum(dim=1)
442
+ context_aware_representations = self.c_embedder(context_aware_representations)
443
+ c = timestep_aware_representations + context_aware_representations
444
+
445
+ x = self.input_embedder(x)
446
+ if self.need_CA:
447
+ y = self.input_embedder_CA(y)
448
+ x = self.individual_token_refiner(x, c, mask, y)
449
+ else:
450
+ x = self.individual_token_refiner(x, c, mask)
451
+
452
+ return x
453
+
454
+
455
+
456
+ class Qwen2Connector(torch.nn.Module):
457
+ def __init__(
458
+ self,
459
+ # biclip_dim=1024,
460
+ in_channels=3584,
461
+ hidden_size=4096,
462
+ heads_num=32,
463
+ depth=2,
464
+ need_CA=False,
465
+ device=None,
466
+ dtype=torch.bfloat16,
467
+ ):
468
+ super().__init__()
469
+ factory_kwargs = {"device": device, "dtype":dtype}
470
+
471
+ self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
472
+ self.global_proj_out=nn.Linear(in_channels,768)
473
+
474
+ self.scale_factor = nn.Parameter(torch.zeros(1))
475
+ with torch.no_grad():
476
+ self.scale_factor.data += -(1 - 0.09)
477
+
478
+ def forward(self, x, t, mask):
479
+ mask_float = mask.unsqueeze(-1) # [b, s1, 1]
480
+ x_mean = (x * mask_float).sum(
481
+ dim=1
482
+ ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
483
+
484
+ global_out = self.global_proj_out(x_mean)
485
+ encoder_hidden_states = self.S(x, t, mask)
486
+ return encoder_hidden_states, global_out
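For orientation, a minimal usage sketch (not part of the commit) of Qwen2Connector: it takes the LLM token embeddings, one timestep per sample, and a padding mask, and returns refined per-token states plus a 768-dim pooled vector. The tiny dimensions are hypothetical (the defaults, in_channels=3584 and hidden_size=4096, presumably match the Qwen2-VL text encoder used by the pipeline), and it assumes the refiner's default "torch" attention path and the repo's requirements are available.

import torch
from modules.connector_edit import Qwen2Connector

connector = Qwen2Connector(
    in_channels=64,     # default 3584
    hidden_size=128,    # default 4096
    heads_num=8,
    depth=1,
    dtype=torch.float32,
)

b, s = 2, 10
x = torch.randn(b, s, 64)                   # LLM token embeddings
t = torch.zeros(b, dtype=torch.long)        # one diffusion timestep per sample
mask = torch.ones(b, s, dtype=torch.long)   # 1 = valid token, 0 = padding

tokens, pooled = connector(x, t, mask)
# tokens: [2, 10, 128] refined per-token states; pooled: [2, 768] global vector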
modules/layers.py ADDED
@@ -0,0 +1,639 @@
1
+ # Modified from Flux
2
+ #
3
+ # Copyright 2024 Black Forest Labs
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # This source code is licensed under the license found in the
18
+ # LICENSE file in the root directory of this source tree.
19
+
20
+ import math # noqa: I001
21
+ from dataclasses import dataclass
22
+ from functools import partial
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from einops import rearrange
27
+ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
28
+ from torch import Tensor, nn
29
+
30
+
31
+ try:
32
+ import flash_attn
33
+ from flash_attn.flash_attn_interface import (
34
+ _flash_attn_forward,
35
+ flash_attn_varlen_func,
36
+ )
37
+ except ImportError:
38
+ flash_attn = None
39
+ flash_attn_varlen_func = None
40
+ _flash_attn_forward = None
41
+
42
+
43
+ MEMORY_LAYOUT = {
44
+ "flash": (
45
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
46
+ lambda x: x,
47
+ ),
48
+ "torch": (
49
+ lambda x: x.transpose(1, 2),
50
+ lambda x: x.transpose(1, 2),
51
+ ),
52
+ "vanilla": (
53
+ lambda x: x.transpose(1, 2),
54
+ lambda x: x.transpose(1, 2),
55
+ ),
56
+ }
57
+
58
+
59
+ def attention(
60
+ q,
61
+ k,
62
+ v,
63
+ mode="flash",
64
+ drop_rate=0,
65
+ attn_mask=None,
66
+ causal=False,
67
+ cu_seqlens_q=None,
68
+ cu_seqlens_kv=None,
69
+ max_seqlen_q=None,
70
+ max_seqlen_kv=None,
71
+ batch_size=1,
72
+ ):
73
+ """
74
+ Perform QKV self attention.
75
+
76
+ Args:
77
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
78
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
79
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
80
+ mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
81
+ drop_rate (float): Dropout rate in attention map. (default: 0)
82
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
83
+ (default: None)
84
+ causal (bool): Whether to use causal attention. (default: False)
85
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
86
+ used to index into q.
87
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
88
+ used to index into kv.
89
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
90
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
91
+
92
+ Returns:
93
+ torch.Tensor: Output tensor after self attention with shape [b, s, a * d]
94
+ """
95
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
96
+ q = pre_attn_layout(q)
97
+ k = pre_attn_layout(k)
98
+ v = pre_attn_layout(v)
99
+
100
+ if mode == "torch":
101
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
102
+ attn_mask = attn_mask.to(q.dtype)
103
+ x = F.scaled_dot_product_attention(
104
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
105
+ )
106
+ elif mode == "flash":
107
+ assert flash_attn_varlen_func is not None
108
+ x: torch.Tensor = flash_attn_varlen_func(
109
+ q,
110
+ k,
111
+ v,
112
+ cu_seqlens_q,
113
+ cu_seqlens_kv,
114
+ max_seqlen_q,
115
+ max_seqlen_kv,
116
+ ) # type: ignore
117
+ # x with shape [(b * s), a, d]
118
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # type: ignore # reshape x to [b, s, a, d]
119
+ elif mode == "vanilla":
120
+ scale_factor = 1 / math.sqrt(q.size(-1))
121
+
122
+ b, a, s, _ = q.shape
123
+ s1 = k.size(2)
124
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
125
+ if causal:
126
+ # Only applied to self attention
127
+ assert attn_mask is None, (
128
+ "Causal mask and attn_mask cannot be used together"
129
+ )
130
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
131
+ diagonal=0
132
+ )
133
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
134
+ attn_bias = attn_bias.to(q.dtype)
135
+
136
+ if attn_mask is not None:
137
+ if attn_mask.dtype == torch.bool:
138
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
139
+ else:
140
+ attn_bias += attn_mask
141
+
142
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
143
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
144
+ attn += attn_bias
145
+ attn = attn.softmax(dim=-1)
146
+ attn = torch.dropout(attn, p=drop_rate, train=True)
147
+ x = attn @ v
148
+ else:
149
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
150
+
151
+ x = post_attn_layout(x)
152
+ b, s, a, d = x.shape
153
+ out = x.reshape(b, s, -1)
154
+ return out
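As a quick illustration of the layout handling above, a hypothetical call in "torch" mode, which only needs PyTorch's scaled_dot_product_attention (no flash-attn); it assumes the repo's requirements (einops, liger_kernel) are installed so that modules.layers imports cleanly.

import torch
from modules.layers import attention  # the function defined above

b, s, heads, head_dim = 2, 16, 4, 32
q = torch.randn(b, s, heads, head_dim)
k = torch.randn(b, s, heads, head_dim)
v = torch.randn(b, s, heads, head_dim)

out = attention(q, k, v, mode="torch")
# out: [b, s, heads * head_dim] = [2, 16, 128]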
155
+
156
+
157
+ def apply_gate(x, gate=None, tanh=False):
158
+ """Apply an optional per-sample gate to the input tensor.
159
+
160
+ Args:
161
+ x (torch.Tensor): input tensor.
162
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
163
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
164
+
165
+ Returns:
166
+ torch.Tensor: the output tensor after applying the gate.
167
+ """
168
+ if gate is None:
169
+ return x
170
+ if tanh:
171
+ return x * gate.unsqueeze(1).tanh()
172
+ else:
173
+ return x * gate.unsqueeze(1)
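The gate is one vector per sample and is broadcast over the sequence dimension; a tiny hypothetical example:

import torch
from modules.layers import apply_gate  # as defined above

x = torch.ones(2, 4, 8)          # [batch, seq, dim]
gate = torch.full((2, 8), 0.5)   # one gate vector per sample
out = apply_gate(x, gate)        # broadcast over seq; every value becomes 0.5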
174
+
175
+
176
+ class MLP(nn.Module):
177
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
178
+
179
+ def __init__(
180
+ self,
181
+ in_channels,
182
+ hidden_channels=None,
183
+ out_features=None,
184
+ act_layer=nn.GELU,
185
+ norm_layer=None,
186
+ bias=True,
187
+ drop=0.0,
188
+ use_conv=False,
189
+ device=None,
190
+ dtype=None,
191
+ ):
192
+ super().__init__()
193
+ out_features = out_features or in_channels
194
+ hidden_channels = hidden_channels or in_channels
195
+ bias = (bias, bias)
196
+ drop_probs = (drop, drop)
197
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
198
+
199
+ self.fc1 = linear_layer(
200
+ in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
201
+ )
202
+ self.act = act_layer()
203
+ self.drop1 = nn.Dropout(drop_probs[0])
204
+ self.norm = (
205
+ norm_layer(hidden_channels, device=device, dtype=dtype)
206
+ if norm_layer is not None
207
+ else nn.Identity()
208
+ )
209
+ self.fc2 = linear_layer(
210
+ hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
211
+ )
212
+ self.drop2 = nn.Dropout(drop_probs[1])
213
+
214
+ def forward(self, x):
215
+ x = self.fc1(x)
216
+ x = self.act(x)
217
+ x = self.drop1(x)
218
+ x = self.norm(x)
219
+ x = self.fc2(x)
220
+ x = self.drop2(x)
221
+ return x
222
+
223
+
224
+ class TextProjection(nn.Module):
225
+ """
226
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
227
+
228
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
229
+ """
230
+
231
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
232
+ factory_kwargs = {"dtype": dtype, "device": device}
233
+ super().__init__()
234
+ self.linear_1 = nn.Linear(
235
+ in_features=in_channels,
236
+ out_features=hidden_size,
237
+ bias=True,
238
+ **factory_kwargs,
239
+ )
240
+ self.act_1 = act_layer()
241
+ self.linear_2 = nn.Linear(
242
+ in_features=hidden_size,
243
+ out_features=hidden_size,
244
+ bias=True,
245
+ **factory_kwargs,
246
+ )
247
+
248
+ def forward(self, caption):
249
+ hidden_states = self.linear_1(caption)
250
+ hidden_states = self.act_1(hidden_states)
251
+ hidden_states = self.linear_2(hidden_states)
252
+ return hidden_states
253
+
254
+
255
+ class TimestepEmbedder(nn.Module):
256
+ """
257
+ Embeds scalar timesteps into vector representations.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ hidden_size,
263
+ act_layer,
264
+ frequency_embedding_size=256,
265
+ max_period=10000,
266
+ out_size=None,
267
+ dtype=None,
268
+ device=None,
269
+ ):
270
+ factory_kwargs = {"dtype": dtype, "device": device}
271
+ super().__init__()
272
+ self.frequency_embedding_size = frequency_embedding_size
273
+ self.max_period = max_period
274
+ if out_size is None:
275
+ out_size = hidden_size
276
+
277
+ self.mlp = nn.Sequential(
278
+ nn.Linear(
279
+ frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
280
+ ),
281
+ act_layer(),
282
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
283
+ )
284
+ nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
285
+ nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
286
+
287
+ @staticmethod
288
+ def timestep_embedding(t, dim, max_period=10000):
289
+ """
290
+ Create sinusoidal timestep embeddings.
291
+
292
+ Args:
293
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
294
+ dim (int): the dimension of the output.
295
+ max_period (int): controls the minimum frequency of the embeddings.
296
+
297
+ Returns:
298
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
299
+
300
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
301
+ """
302
+ half = dim // 2
303
+ freqs = torch.exp(
304
+ -math.log(max_period)
305
+ * torch.arange(start=0, end=half, dtype=torch.float32)
306
+ / half
307
+ ).to(device=t.device)
308
+ args = t[:, None].float() * freqs[None]
309
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
310
+ if dim % 2:
311
+ embedding = torch.cat(
312
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
313
+ )
314
+ return embedding
315
+
316
+ def forward(self, t):
317
+ t_freq = self.timestep_embedding(
318
+ t, self.frequency_embedding_size, self.max_period
319
+ ).type(self.mlp[0].weight.dtype) # type: ignore
320
+ t_emb = self.mlp(t_freq)
321
+ return t_emb
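A small sketch of the sinusoidal part on its own (the MLP on top only projects it to hidden_size); shapes are illustrative:

import torch
from modules.layers import TimestepEmbedder

t = torch.tensor([0.0, 500.0])
emb = TimestepEmbedder.timestep_embedding(t, dim=8)
# emb: [2, 8] with cosines in the first half and sines in the second;
# t = 0 yields all-ones cosines and all-zero sines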
322
+
323
+
324
+ class EmbedND(nn.Module):
325
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
326
+ super().__init__()
327
+ self.dim = dim
328
+ self.theta = theta
329
+ self.axes_dim = axes_dim
330
+
331
+ def forward(self, ids: Tensor) -> Tensor:
332
+ n_axes = ids.shape[-1]
333
+ emb = torch.cat(
334
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
335
+ dim=-3,
336
+ )
337
+
338
+ return emb.unsqueeze(1)
339
+
340
+
341
+ class MLPEmbedder(nn.Module):
342
+ def __init__(self, in_dim: int, hidden_dim: int):
343
+ super().__init__()
344
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
345
+ self.silu = nn.SiLU()
346
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
347
+
348
+ def forward(self, x: Tensor) -> Tensor:
349
+ return self.out_layer(self.silu(self.in_layer(x)))
350
+
351
+
352
+ def rope(pos, dim: int, theta: int):
353
+ assert dim % 2 == 0
354
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
355
+ omega = 1.0 / (theta**scale)
356
+ out = torch.einsum("...n,d->...nd", pos, omega)
357
+ out = torch.stack(
358
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
359
+ )
360
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
361
+ return out.float()
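rope() produces one 2x2 rotation matrix per position and frequency pair; a hypothetical shape check:

import torch
from modules.layers import rope

pos = torch.arange(6, dtype=torch.float64)[None]  # [1, 6] token positions
pe = rope(pos, dim=16, theta=10_000)
# pe: [1, 6, 8, 2, 2] -> dim / 2 = 8 frequency pairs, each a 2x2 rotation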
362
+
363
+
364
+ def attention_after_rope(q, k, v, pe):
365
+ q, k = apply_rope(q, k, pe)
366
+
367
+ from .attention import attention
368
+
369
+ x = attention(q, k, v, mode="flash")
370
+ return x
371
+
372
+
373
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
374
+ def apply_rope(xq, xk, freqs_cis):
375
+ # swap the num_heads and seq_len dims back to the layout the original function expects
376
+ xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
377
+ xk = xk.transpose(1, 2)
378
+
379
+ # split head_dim into complex components (real and imaginary parts)
380
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
381
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
382
+
383
+ # apply the rotary position embedding (complex multiplication)
384
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
385
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
386
+
387
+ # restore the tensor shapes and transpose back to the target dim order
388
+ xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
389
+ xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
390
+
391
+ return xq_out, xk_out
392
+
393
+
394
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
395
+ def scale_add_residual(
396
+ x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
397
+ ) -> torch.Tensor:
398
+ return x * scale + residual
399
+
400
+
401
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
402
+ def layernorm_and_scale_shift(
403
+ x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
404
+ ) -> torch.Tensor:
405
+ return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
406
+
407
+
408
+ class SelfAttention(nn.Module):
409
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
410
+ super().__init__()
411
+ self.num_heads = num_heads
412
+ head_dim = dim // num_heads
413
+
414
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
415
+ self.norm = QKNorm(head_dim)
416
+ self.proj = nn.Linear(dim, dim)
417
+
418
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
419
+ qkv = self.qkv(x)
420
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
421
+ q, k = self.norm(q, k, v)
422
+ x = attention_after_rope(q, k, v, pe=pe)
423
+ x = self.proj(x)
424
+ return x
425
+
426
+
427
+ @dataclass
428
+ class ModulationOut:
429
+ shift: Tensor
430
+ scale: Tensor
431
+ gate: Tensor
432
+
433
+
434
+ class RMSNorm(torch.nn.Module):
435
+ def __init__(self, dim: int):
436
+ super().__init__()
437
+ self.scale = nn.Parameter(torch.ones(dim))
438
+
439
+ @staticmethod
440
+ def rms_norm_fast(x, weight, eps):
441
+ return LigerRMSNormFunction.apply(
442
+ x,
443
+ weight,
444
+ eps,
445
+ 0.0,
446
+ "gemma",
447
+ True,
448
+ )
449
+
450
+ @staticmethod
451
+ def rms_norm(x, weight, eps):
452
+ x_dtype = x.dtype
453
+ x = x.float()
454
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
455
+ return (x * rrms).to(dtype=x_dtype) * weight
456
+
457
+ def forward(self, x: Tensor):
458
+ return self.rms_norm_fast(x, self.scale, 1e-6)
459
+
460
+
461
+ class QKNorm(torch.nn.Module):
462
+ def __init__(self, dim: int):
463
+ super().__init__()
464
+ self.query_norm = RMSNorm(dim)
465
+ self.key_norm = RMSNorm(dim)
466
+
467
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
468
+ q = self.query_norm(q)
469
+ k = self.key_norm(k)
470
+ return q.to(v), k.to(v)
471
+
472
+
473
+ class Modulation(nn.Module):
474
+ def __init__(self, dim: int, double: bool):
475
+ super().__init__()
476
+ self.is_double = double
477
+ self.multiplier = 6 if double else 3
478
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
479
+
480
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
481
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
482
+ self.multiplier, dim=-1
483
+ )
484
+
485
+ return (
486
+ ModulationOut(*out[:3]),
487
+ ModulationOut(*out[3:]) if self.is_double else None,
488
+ )
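A Modulation layer maps the conditioning vector to (shift, scale, gate) triples; with double=True it returns two triples (the blocks below use the first around attention and the second around the MLP). Shapes in this sketch are hypothetical:

import torch
from modules.layers import Modulation

mod = Modulation(dim=32, double=True)
vec = torch.randn(4, 32)      # [batch, hidden]
mod1, mod2 = mod(vec)
# mod1.shift / mod1.scale / mod1.gate: [4, 1, 32]
# mod2 would be None if double=False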
489
+
490
+
491
+ class DoubleStreamBlock(nn.Module):
492
+ def __init__(
493
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
494
+ ):
495
+ super().__init__()
496
+
497
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
498
+ self.num_heads = num_heads
499
+ self.hidden_size = hidden_size
500
+ self.img_mod = Modulation(hidden_size, double=True)
501
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
502
+ self.img_attn = SelfAttention(
503
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
504
+ )
505
+
506
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
507
+ self.img_mlp = nn.Sequential(
508
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
509
+ nn.GELU(approximate="tanh"),
510
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
511
+ )
512
+
513
+ self.txt_mod = Modulation(hidden_size, double=True)
514
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
515
+ self.txt_attn = SelfAttention(
516
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
517
+ )
518
+
519
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
520
+ self.txt_mlp = nn.Sequential(
521
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
522
+ nn.GELU(approximate="tanh"),
523
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
524
+ )
525
+
526
+ def forward(
527
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
528
+ ) -> tuple[Tensor, Tensor]:
529
+ img_mod1, img_mod2 = self.img_mod(vec)
530
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
531
+
532
+ # prepare image for attention
533
+ img_modulated = self.img_norm1(img)
534
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
535
+ img_qkv = self.img_attn.qkv(img_modulated)
536
+ img_q, img_k, img_v = rearrange(
537
+ img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
538
+ )
539
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
540
+
541
+ # prepare txt for attention
542
+ txt_modulated = self.txt_norm1(txt)
543
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
544
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
545
+ txt_q, txt_k, txt_v = rearrange(
546
+ txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
547
+ )
548
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
549
+
550
+ # run actual attention
551
+ q = torch.cat((txt_q, img_q), dim=1)
552
+ k = torch.cat((txt_k, img_k), dim=1)
553
+ v = torch.cat((txt_v, img_v), dim=1)
554
+
555
+ attn = attention_after_rope(q, k, v, pe=pe)
556
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
557
+
558
+ # calculate the img blocks
559
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
560
+ img_mlp = self.img_mlp(
561
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
562
+ )
563
+ img = scale_add_residual(img_mlp, img_mod2.gate, img)
564
+
565
+ # calculate the txt blocks
566
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
567
+ txt_mlp = self.txt_mlp(
568
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
569
+ )
570
+ txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
571
+ return img, txt
572
+
573
+
574
+ class SingleStreamBlock(nn.Module):
575
+ """
576
+ A DiT block with parallel linear layers as described in
577
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
578
+ """
579
+
580
+ def __init__(
581
+ self,
582
+ hidden_size: int,
583
+ num_heads: int,
584
+ mlp_ratio: float = 4.0,
585
+ qk_scale: float | None = None,
586
+ ):
587
+ super().__init__()
588
+ self.hidden_dim = hidden_size
589
+ self.num_heads = num_heads
590
+ head_dim = hidden_size // num_heads
591
+ self.scale = qk_scale or head_dim**-0.5
592
+
593
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
594
+ # qkv and mlp_in
595
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
596
+ # proj and mlp_out
597
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
598
+
599
+ self.norm = QKNorm(head_dim)
600
+
601
+ self.hidden_size = hidden_size
602
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
603
+
604
+ self.mlp_act = nn.GELU(approximate="tanh")
605
+ self.modulation = Modulation(hidden_size, double=False)
606
+
607
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
608
+ mod, _ = self.modulation(vec)
609
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
610
+ qkv, mlp = torch.split(
611
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
612
+ )
613
+
614
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
615
+ q, k = self.norm(q, k, v)
616
+
617
+ # compute attention
618
+ attn = attention_after_rope(q, k, v, pe=pe)
619
+ # compute activation in mlp stream, cat again and run second linear layer
620
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
621
+ return scale_add_residual(output, mod.gate, x)
622
+
623
+
624
+ class LastLayer(nn.Module):
625
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
626
+ super().__init__()
627
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
628
+ self.linear = nn.Linear(
629
+ hidden_size, patch_size * patch_size * out_channels, bias=True
630
+ )
631
+ self.adaLN_modulation = nn.Sequential(
632
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
633
+ )
634
+
635
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
636
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
637
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
638
+ x = self.linear(x)
639
+ return x
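To tie the pieces together, a hypothetical sketch of how the rotary table consumed by the blocks above is built from concatenated text and image position ids (the axes_dim split is illustrative, not the shipped configuration):

import torch
from modules.layers import EmbedND

pe_embedder = EmbedND(dim=64, theta=10_000, axes_dim=[16, 24, 24])

h, w = 4, 4
img_ids = torch.zeros(h, w, 3)
img_ids[..., 1] = torch.arange(h)[:, None]   # latent row index
img_ids[..., 2] = torch.arange(w)[None, :]   # latent column index
img_ids = img_ids.reshape(1, h * w, 3)

txt_ids = torch.zeros(1, 10, 3)              # text tokens all sit at position 0
ids = torch.cat((txt_ids, img_ids), dim=1)   # [1, 26, 3]

pe = pe_embedder(ids)
# pe: [1, 1, 26, 32, 2, 2] -> sum(axes_dim) / 2 = 32 rotation matrices per token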
modules/model_edit.py ADDED
@@ -0,0 +1,143 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from .connector_edit import Qwen2Connector
9
+ from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
10
+
11
+
12
+ @dataclass
13
+ class Step1XParams:
14
+ in_channels: int
15
+ out_channels: int
16
+ vec_in_dim: int
17
+ context_in_dim: int
18
+ hidden_size: int
19
+ mlp_ratio: float
20
+ num_heads: int
21
+ depth: int
22
+ depth_single_blocks: int
23
+ axes_dim: list[int]
24
+ theta: int
25
+ qkv_bias: bool
26
+
27
+
28
+ class Step1XEdit(nn.Module):
29
+ """
30
+ Transformer model for flow matching on sequences.
31
+ """
32
+
33
+ def __init__(self, params: Step1XParams):
34
+ super().__init__()
35
+
36
+ self.params = params
37
+ self.in_channels = params.in_channels
38
+ self.out_channels = params.out_channels
39
+ if params.hidden_size % params.num_heads != 0:
40
+ raise ValueError(
41
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
42
+ )
43
+ pe_dim = params.hidden_size // params.num_heads
44
+ if sum(params.axes_dim) != pe_dim:
45
+ raise ValueError(
46
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
47
+ )
48
+ self.hidden_size = params.hidden_size
49
+ self.num_heads = params.num_heads
50
+ self.pe_embedder = EmbedND(
51
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
52
+ )
53
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
57
+
58
+ self.double_blocks = nn.ModuleList(
59
+ [
60
+ DoubleStreamBlock(
61
+ self.hidden_size,
62
+ self.num_heads,
63
+ mlp_ratio=params.mlp_ratio,
64
+ qkv_bias=params.qkv_bias,
65
+ )
66
+ for _ in range(params.depth)
67
+ ]
68
+ )
69
+
70
+ self.single_blocks = nn.ModuleList(
71
+ [
72
+ SingleStreamBlock(
73
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
74
+ )
75
+ for _ in range(params.depth_single_blocks)
76
+ ]
77
+ )
78
+
79
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
80
+
81
+ self.connector = Qwen2Connector()
82
+
83
+ @staticmethod
84
+ def timestep_embedding(
85
+ t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
86
+ ):
87
+ """
88
+ Create sinusoidal timestep embeddings.
89
+ :param t: a 1-D Tensor of N indices, one per batch element.
90
+ These may be fractional.
91
+ :param dim: the dimension of the output.
92
+ :param max_period: controls the minimum frequency of the embeddings.
93
+ :return: an (N, D) Tensor of positional embeddings.
94
+ """
95
+ t = time_factor * t
96
+ half = dim // 2
97
+ freqs = torch.exp(
98
+ -math.log(max_period)
99
+ * torch.arange(start=0, end=half, dtype=torch.float32)
100
+ / half
101
+ ).to(t.device)
102
+
103
+ args = t[:, None].float() * freqs[None]
104
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
105
+ if dim % 2:
106
+ embedding = torch.cat(
107
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
108
+ )
109
+ if torch.is_floating_point(t):
110
+ embedding = embedding.to(t)
111
+ return embedding
112
+
113
+ def forward(
114
+ self,
115
+ img: Tensor,
116
+ img_ids: Tensor,
117
+ txt: Tensor,
118
+ txt_ids: Tensor,
119
+ timesteps: Tensor,
120
+ y: Tensor,
121
+ ) -> Tensor:
122
+ if img.ndim != 3 or txt.ndim != 3:
123
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
124
+
125
+ img = self.img_in(img)
126
+ vec = self.time_in(self.timestep_embedding(timesteps, 256))
127
+
128
+ vec = vec + self.vector_in(y)
129
+ txt = self.txt_in(txt)
130
+
131
+ ids = torch.cat((txt_ids, img_ids), dim=1)
132
+ pe = self.pe_embedder(ids)
133
+
134
+ for block in self.double_blocks:
135
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
136
+
137
+ img = torch.cat((txt, img), 1)
138
+ for block in self.single_blocks:
139
+ img = block(img, vec=vec, pe=pe)
140
+ img = img[:, txt.shape[1] :, ...]
141
+
142
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
143
+ return img
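A hypothetical configuration sketch: the real hyperparameters ship with the checkpoint and are loaded elsewhere, but the two ValueError checks in __init__ pin down the constraints (hidden_size divisible by num_heads, sum(axes_dim) equal to the per-head dimension). The Flux-style numbers below are illustrative only, and the full model is deliberately not instantiated here.

from modules.model_edit import Step1XParams

params = Step1XParams(
    in_channels=64,        # e.g. 16 latent channels x 2x2 packing (assumed)
    out_channels=64,
    vec_in_dim=768,        # width of the connector's global_proj_out output
    context_in_dim=4096,   # the connector's default hidden_size
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
)
assert sum(params.axes_dim) == params.hidden_size // params.num_heads  # 128
# Step1XEdit(params) would then build 19 double blocks, 38 single blocks
# and a Qwen2Connector.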
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch==2.3.1
2
+ liger_kernel==0.5.4
3
+ einops==0.8.1
4
+ transformers==4.49.0
5
+ qwen_vl_utils==0.0.10
6
+ safetensors==0.4.5
7
+ pillow==11.1.0
8
+ huggingface_hub
sampling.py ADDED
@@ -0,0 +1,47 @@
1
+ import math
2
+ from collections.abc import Callable
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
+ def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int):
9
+ return torch.randn(
10
+ num_samples,
11
+ 16,
12
+ # allow for packing
13
+ 2 * math.ceil(height / 16),
14
+ 2 * math.ceil(width / 16),
15
+ device=device,
16
+ dtype=dtype,
17
+ generator=torch.Generator(device=device).manual_seed(seed),
18
+ )
19
+
20
+
21
+ def time_shift(mu: float, sigma: float, t: Tensor):
22
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
23
+
24
+
25
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
26
+ m = (y2 - y1) / (x2 - x1)
27
+ b = y1 - m * x1
28
+ return lambda x: m * x + b
29
+
30
+
31
+ def get_schedule(
32
+ num_steps: int,
33
+ image_seq_len: int,
34
+ base_shift: float = 0.5,
35
+ max_shift: float = 1.15,
36
+ shift: bool = True,
37
+ ) -> list[float]:
38
+ # extra step for zero
39
+ timesteps = torch.linspace(1, 0, num_steps + 1)
40
+
41
+ # shifting the schedule to favor high timesteps for higher signal images
42
+ if shift:
43
+ # estimate mu based on linear estimation between two points
44
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
45
+ timesteps = time_shift(mu, 1.0, timesteps)
46
+
47
+ return timesteps.tolist()
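A hypothetical end-to-end sketch of these helpers: draw the initial latent noise for a 512x512 edit and build the shifted 28-step schedule from the packed image sequence length (the 2x2 packing factor is assumed, following the Flux convention).

import torch
from sampling import get_noise, get_schedule

height, width, steps = 512, 512, 28

noise = get_noise(1, height, width, device=torch.device("cpu"),
                  dtype=torch.float32, seed=42)
# noise: [1, 16, 64, 64] -> 2 * ceil(512 / 16) = 64 per spatial side

image_seq_len = (noise.shape[-2] // 2) * (noise.shape[-1] // 2)  # 1024 packed tokens
timesteps = get_schedule(steps, image_seq_len, shift=True)
# 29 values running from 1.0 down to 0.0; larger images raise mu via
# get_lin_function, which keeps intermediate timesteps closer to 1.0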