WYBar committed
Commit 8fe62ee · 1 Parent(s): 7dfddd5

finish with token

README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: ART V1.0
- emoji:
+ emoji: 📊
  colorFrom: gray
- colorTo: red
+ colorTo: indigo
  sdk: gradio
  sdk_version: 5.20.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,684 @@
1
+ import os
2
+ import spaces
3
+
4
+ import ast
5
+ import numpy as np
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+
11
+ from PIL import Image
12
+ import xml.etree.cElementTree as ET
13
+ from io import BytesIO
14
+ import base64
15
+ import json
16
+
17
+ import gradio as gr
18
+ from functools import partial
19
+ import requests
20
+ import base64
21
+ import os
22
+ import time
23
+ import re
24
+
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ set_seed
28
+ )
29
+ from typing import List
30
+
31
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
32
+ from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
33
+ STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
34
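+ # Stopping criterion: ends generation once the most recently generated token id is in
+ # token_id_list (used below with id 128000 to terminate layout decoding).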
+ class StopAtSpecificTokenCriteria(StoppingCriteria):
35
+ def __init__(self, token_id_list: List[int] = None):
36
+ self.token_id_list = token_id_list
37
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
38
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
+ return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
40
+
41
+ def ensure_space_after_period(input_string):
42
+ # Normalize spacing so that each period is followed by exactly one space
43
+ output_string = re.sub(r'\.\s*', '. ', input_string)
44
+ return output_string
45
+
46
+ def generate_unique_filename():
47
+ # Generate a unique filename based on the current timestamp
48
+ timestamp = int(time.time() * 1000) # timestamp in milliseconds
49
+ # random_num = random.randint(1000, 9999) # random number
50
+ unique_filename = f"{timestamp}"
51
+ return unique_filename
52
+
53
+ def upload_to_github(file_path,
54
+ repo='WYBar/gradiodemo_svg',
55
+ branch='main',
56
+ token='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'):
57
+ if not os.path.isfile(file_path):
58
+ print(f"File not found: {file_path}")
59
+ return None, None
60
+ with open(file_path, 'rb') as file:
61
+ content = file.read()
62
+ encoded_content = base64.b64encode(content).decode('utf-8')
63
+ unique_filename = generate_unique_filename()
64
+ url = f"https://api.github.com/repos/{repo}/contents/{unique_filename}.svg"
65
+ headers = {
66
+ "Authorization": f"token {token}"
67
+ }
68
+ response = requests.get(url, headers=headers)
69
+
70
+ sha = None
71
+ if response.status_code == 200:
72
+ sha = response.json()['sha']
73
+ elif response.status_code == 404:
74
+ # File does not exist, so no SHA is needed
75
+ pass
76
+ else:
77
+ print(f"Failed to get file status: {response.status_code}")
78
+ # print(response.text)
79
+ return None, None
80
+
81
+ headers = {
82
+ "Authorization": f"token {token}",
83
+ "Content-Type": "application/json"
84
+ }
85
+ data = {
86
+ "message": "upload svg file",
87
+ "content": encoded_content,
88
+ "branch": branch
89
+ }
90
+
91
+ if sha:
92
+ # File exists: update it
93
+ # print('sha exists, update the old one')
94
+ data["sha"] = sha
95
+ response = requests.put(url, headers=headers, json=data)
96
+ else:
97
+ # File does not exist: create a new one
98
+ print("sha does not exist, need to create a new one")
99
+ response = requests.put(url, headers=headers, json=data)
100
+
101
+ # print(response.status_code)
102
+ # print(response.text)
103
+ if response.status_code in [200, 201]:
104
+ # print(response.json()['content']['download_url'])
105
+ return response.json()['content']['download_url'], unique_filename
106
+ else:
107
+ print(f"Upload failed with status {response.status_code}")
108
+ return None, None
109
+
110
+ def calculate_iou(box1, box2):
111
+ # Compute the intersection of the two boxes
112
+ x1 = max(box1[0], box2[0])
113
+ y1 = max(box1[1], box2[1])
114
+ x2 = min(box1[2], box2[2])
115
+ y2 = min(box1[3], box2[3])
116
+
117
+ intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
118
+
119
+ # Compute the union of the two boxes
120
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
121
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
122
+
123
+ union_area = box1_area + box2_area - intersection_area
124
+
125
+ # Compute the IoU
126
+ iou = intersection_area / union_area
127
+ return iou
128
+
129
+ def adjust_coordinates(box):
130
+ size = 32
131
+ (x1, y1, x2, y2) = box
132
+ if x1 % size != 0:
133
+ x1 = (x1 // size) * size
134
+ if x2 % size != 0:
135
+ x2 = (x2 // size + 1) * size
136
+
137
+ if y1 % size != 0:
138
+ y1 = (y1 // size) * size
139
+ if y2 % size != 0:
140
+ y2 = (y2 // size + 1) * size
141
+ return (x1, y1, x2, y2)
142
+
143
+ def adjust_validation_box(validation_box):
144
+ return [adjust_coordinates(box) for box in validation_box]
145
+
146
+ def get_list_layer_box(list_png_images):
147
+ list_layer_box = []
148
+ for img in list_png_images:
149
+ img_np = np.array(img)
150
+ alpha_channel = img_np[:, :, -1]
151
+
152
+ # Step 1: Find the non-zero indices
153
+ rows, cols = np.nonzero(alpha_channel)
154
+
155
+ if (len(rows) == 0) or (len(cols) == 0):
156
+ # If there are no non-zero indices, we can skip this layer
157
+ list_layer_box.append((0, 0, 0, 0))
158
+ continue
159
+
160
+ # Step 2: Get the minimum and maximum indices for rows and columns
161
+ min_row, max_row = rows.min().item(), rows.max().item()
162
+ min_col, max_col = cols.min().item(), cols.max().item()
163
+
164
+ # Step 3: Quantize the minimum values down to the nearest multiple of 8
165
+ quantized_min_row = (min_row // 8) * 8
166
+ quantized_min_col = (min_col // 8) * 8
167
+
168
+ # Step 4: Quantize the maximum values up to the nearest multiple of 8 outside of the max
169
+ quantized_max_row = ((max_row // 8) + 1) * 8
170
+ quantized_max_col = ((max_col // 8) + 1) * 8
171
+ list_layer_box.append(
172
+ (quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row)
173
+ )
174
+ return list_layer_box
175
+
176
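+ # Crop each layer to its bounding box and embed it as a base64-encoded PNG <image>
+ # element of a single SVG whose canvas matches the first layer.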
+ def pngs_to_svg(list_png_images):
177
+ list_layer_box = get_list_layer_box(list_png_images)
178
+ assert(len(list_png_images) == len(list_layer_box))
179
+ width, height = list_png_images[0].width, list_png_images[0].height
180
+ img_svg = ET.Element(
181
+ 'svg',
182
+ {
183
+ "width": str(width),
184
+ "height": str(height),
185
+ "xmlns": "http://www.w3.org/2000/svg",
186
+ "xmlns:svg": "http://www.w3.org/2000/svg",
187
+ "xmlns:xlink":"http://www.w3.org/1999/xlink"
188
+ }
189
+ )
190
+ for img, box in zip(list_png_images, list_layer_box):
191
+ x, y, w, h = box[0], box[1], box[2]-box[0], box[3]-box[1]
192
+ if (w == 0 or h == 0):
193
+ continue
194
+ img = img.crop((x, y, x+w, y+h))
195
+ buffer = BytesIO()
196
+ img.save(buffer, format='PNG')
197
+ img_str = base64.b64encode(buffer.getvalue())
198
+ ET.SubElement(
199
+ img_svg,
200
+ "image",
201
+ {
202
+ "x": str(x),
203
+ "y": str(y),
204
+ "width": str(w),
205
+ "height": str(h),
206
+ "xlink:href": "data:image/png;base64,"+img_str.decode('utf-8')
207
+ }
208
+ )
209
+ return ET.tostring(img_svg, encoding='utf-8').decode('utf-8')
210
+
211
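+ # NOTE: duplicate of the calculate_iou defined earlier in this file.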
+ def calculate_iou(box1, box2):
212
+ # Compute the intersection of the two boxes
213
+ x1 = max(box1[0], box2[0])
214
+ y1 = max(box1[1], box2[1])
215
+ x2 = min(box1[2], box2[2])
216
+ y2 = min(box1[3], box2[3])
217
+
218
+ intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
219
+
220
+ # Compute the union of the two boxes
221
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
222
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
223
+
224
+ union_area = box1_area + box2_area - intersection_area
225
+
226
+ # Compute the IoU
227
+ iou = intersection_area / union_area
228
+ return iou
229
+
230
+ # @spaces.GPU(enable_queue=True, duration=60)
231
+ def buildmodel(**kwargs):
232
+ from modeling_crello import CrelloModel, CrelloModelConfig
233
+ from quantizer import get_quantizer
234
+ # seed / input model / resume
235
+ resume = kwargs.get('resume', None)
236
+ seed = kwargs.get('seed', None)
237
+ input_model = kwargs.get('input_model', None)
238
+ quantizer_version = kwargs.get('quantizer_version', 'v4')
239
+ device = "cuda"
240
+
241
+ set_seed(seed)
242
+ # old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
243
+ old_tokenizer = AutoTokenizer.from_pretrained(
244
+ "WYBar/LLM_For_Layout_Planning", # repository path
245
+ subfolder="Meta-Llama-3-8B", # subfolder holding the base model
246
+ trust_remote_code=True,
247
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
248
+ )
249
+ old_vocab_size = len(old_tokenizer)
250
+ # tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True)
251
+ tokenizer = AutoTokenizer.from_pretrained(
252
+ "WYBar/LLM_For_Layout_Planning",
253
+ subfolder="checkpoint-26000", # subfolder containing the checkpoint
254
+ trust_remote_code=True,
255
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
256
+ )
257
+
258
+ quantizer = get_quantizer(
259
+ quantizer_version,
260
+ update_vocab = False,
261
+ decimal_quantize_types = kwargs.get('decimal_quantize_types'),
262
+ mask_values = kwargs['mask_values'],
263
+ width = kwargs['width'],
264
+ height = kwargs['height'],
265
+ simplify_json = False,
266
+ num_mask_tokens = 0,
267
+ mask_type = kwargs.get('mask_type'),
268
+ )
269
+ quantizer.setup_tokenizer(tokenizer)
270
+
271
+ model_args = CrelloModelConfig(
272
+ old_vocab_size = old_vocab_size,
273
+ vocab_size=len(tokenizer),
274
+ pad_token_id=tokenizer.pad_token_id,
275
+ ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
276
+ )
277
+ model_args.freeze_lm = True
278
+ model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
279
+ model_args.use_lora = False
280
+ model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
281
+ # model = CrelloModel.from_pretrained(
282
+ # resume,
283
+ # config=model_args
284
+ # ).to(device)
285
+ # model = CrelloModel.from_pretrained(
286
+ # "WYBar/LLM_For_Layout_Planning",
287
+ # subfolder="checkpoint-26000", # checkpoint directory to load
288
+ # config=model_args,
289
+ # # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
290
+ # )
291
+ model = CrelloModel(config=model_args)
292
+ print("before .to(device)")
293
+ model = model.to(device)
294
+ print("after .to(device)")
295
+ model = model.bfloat16()
296
+ model.eval()
297
+
298
+ tokenizer.add_special_tokens({"mask_token": "<mask>"})
299
+ quantizer.additional_special_tokens.add("<mask>")
300
+ added_special_tokens_list = ["<layout>", "<position>", "<wholecaption>"]
301
+ tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens_list}, replace_additional_special_tokens=False)
302
+ for token in added_special_tokens_list:
303
+ quantizer.additional_special_tokens.add(token)
304
+
305
+ return model, quantizer, tokenizer
306
+
307
+ def construction_layout():
308
+ params_dict = {
309
+ # Modify these as needed
310
+ "input_model": "WYBar/LLM_For_Layout_Planning",
311
+ "resume": "WYBar/LLM_For_Layout_Planning",
312
+
313
+ "seed": 0,
314
+ "mask_values": False,
315
+ "quantizer_version": 'v4',
316
+ "mask_type": 'cm3',
317
+ "decimal_quantize_types": [],
318
+ "num_mask_tokens": 0,
319
+ "width": 512,
320
+ "height": 512,
321
+ "device": 0,
322
+ }
323
+ device = "cuda"
324
+ # Init model
325
+ model, quantizer, tokenizer = buildmodel(**params_dict)
326
+
327
+ print('resize token embeddings to match the tokenizer', 129423)
328
+ model.lm.resize_token_embeddings(129423)
329
+ model.input_embeddings = model.lm.get_input_embeddings()
330
+ print('after token embeddings to match the tokenizer', 129423)
331
+ return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
332
+
333
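+ # Prompt the layout LM with a partial-JSON prefix, stop at token id 128000, then rebuild
+ # the JSON and scale the predicted layout values by the canvas width and height.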
+ @torch.no_grad()
334
+ @spaces.GPU(enable_queue=True, duration=60)
335
+ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
336
+ json_example = inputs
337
+ input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
338
+ inputs = tokenizer(
339
+ input_intension, return_tensors="pt"
340
+ ).to(device)
341
+
342
+ stopping_criteria = StoppingCriteriaList()
343
+ stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000]))
344
+
345
+ outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
346
+ inputs_length = inputs['input_ids'].shape[1]
347
+ outputs = outputs[:, inputs_length:]
348
+
349
+ outputs_word = tokenizer.batch_decode(outputs)[0]
350
+ split_word = outputs_word.split('}]}')[0]+"}]}"
351
+ split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word
352
+ map_dict = quantizer.construct_map_dict()
353
+
354
+ for key, value in map_dict.items():
355
+ split_word = split_word.replace(key, value)
356
+ try:
357
+ pred_json_example = json.loads(split_word)
358
+ for layer in pred_json_example["layout"]:
359
+ layer['x'] = round(int(width)*layer['x'])
360
+ layer['y'] = round(int(height)*layer['y'])
361
+ layer['width'] = round(int(width)*layer['width'])
362
+ layer['height'] = round(int(height)*layer['height'])
363
+ except Exception as e:
364
+ print(e)
365
+ pred_json_example = None
366
+ return pred_json_example
367
+
368
+ def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
369
+ def FormulateInput(intension: str):
370
+ resdict = {}
371
+ resdict["wholecaption"] = intension
372
+ resdict["layout"] = []
373
+ return resdict
374
+
375
+ rawdata = FormulateInput(intention)
376
+
377
+ if generate_method == 'v1':
378
+ max_try_time = 5
379
+ preddata = None
380
+ while preddata is None and max_try_time > 0:
381
+ preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
382
+ max_try_time -= 1
383
+ else:
384
+ print("Please input a valid generate method")
385
+ preddata = None
386
+
387
+ return preddata
388
+
389
+ # @spaces.GPU(enable_queue=True, duration=60)
390
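+ # Load the fused transformer and transparency VAE from WYBar/ART_test_weights and plug
+ # them into a FLUX.1-dev pipeline; HF_TOKEN is read from the environment.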
+ def construction():
391
+ from custom_model_mmdit import CustomFluxTransformer2DModel
392
+ from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
393
+ from custom_pipeline import CustomFluxPipelineCfg
394
+
395
+ transformer = CustomFluxTransformer2DModel.from_pretrained(
396
+ "WYBar/ART_test_weights",
397
+ subfolder="fused_transformer",
398
+ torch_dtype=torch.bfloat16,
399
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
400
+ )
401
+
402
+ transp_vae = CustomVAE.from_pretrained(
403
+ "WYBar/ART_test_weights",
404
+ subfolder="custom_vae",
405
+ torch_dtype=torch.float32,
406
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
407
+ )
408
+
409
+ token = os.environ.get("HF_TOKEN")
410
+ pipeline = CustomFluxPipelineCfg.from_pretrained(
411
+ "black-forest-labs/FLUX.1-dev",
412
+ transformer=transformer,
413
+ torch_dtype=torch.bfloat16,
414
+ token=token,
415
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
416
+ ).to("cuda")
417
+ pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
418
+
419
+ return pipeline, transp_vae
420
+
421
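+ # Run the layer-wise Flux pipeline for the given boxes and alpha-composite the RGBA
+ # layers for display in the Gradio gallery.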
+ @spaces.GPU(enable_queue=True, duration=60)
422
+ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
423
+ print(validation_box)
424
+ output, rgba_output, _, _ = pipeline(
425
+ prompt=validation_prompt,
426
+ validation_box=validation_box,
427
+ generator=generator,
428
+ height=512,
429
+ width=512,
430
+ num_layers=len(validation_box),
431
+ guidance_scale=4.0,
432
+ num_inference_steps=inference_steps,
433
+ transparent_decoder=transp_vae,
434
+ true_gs=true_gs
435
+ )
436
+ images = output.images # list of PIL, len=layers
437
+ rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
438
+
439
+ output_gradio = []
440
+ merged_pil = images[1].convert('RGBA')
441
+ for frame_idx, frame_pil in enumerate(rgba_images):
442
+ if frame_idx < 2:
443
+ frame_pil = images[frame_idx].convert('RGBA') # merged and background
444
+ else:
445
+ merged_pil = Image.alpha_composite(merged_pil, frame_pil)
446
+ output_gradio.append(frame_pil)
447
+
448
+ return output_gradio
449
+
450
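+ # Parse and validate the box string, generate the layers, write them to ./image.svg,
+ # and return both the gallery images and the SVG path.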
+ def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
451
+ generator = torch.Generator().manual_seed(seed)
452
+ try:
453
+ validation_box = ast.literal_eval(validation_box_str)
454
+ except Exception as e:
455
+ return [f"Error parsing validation_box: {e}"]
456
+ if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
457
+ return ["validation_box must be a list of tuples, each of length 4."]
458
+
459
+ validation_box = adjust_validation_box(validation_box)
460
+
461
+ result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
462
+
463
+ svg_img = pngs_to_svg(result_images[1:])
464
+
465
+ svg_file_path = './image.svg'
466
+ os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
467
+ with open(svg_file_path, 'w', encoding='utf-8') as f:
468
+ f.write(svg_img)
469
+
470
+ return result_images, svg_file_path
471
+
472
+ def main():
473
+ model, quantizer, tokenizer, width, height, device = construction_layout()
474
+
475
+ inference_partial = partial(
476
+ inference,
477
+ model=model,
478
+ quantizer=quantizer,
479
+ tokenizer=tokenizer,
480
+ width=width,
481
+ height=height,
482
+ device=device
483
+ )
484
+
485
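+ # Clean the intention text, query the layout model, convert the predicted center-format
+ # boxes to (left, top, right, bottom), and drop boxes whose IoU with an earlier box exceeds 0.65.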
+ def process_preddate(intention, temperature, top_p, generate_method='v1'):
486
+ intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
487
+ intention = ensure_space_after_period(intention)
488
+ if temperature == 0.0:
489
+ # print("looking for greedy decoding strategies, set `do_sample=False`.")
490
+ preddata = inference_partial(generate_method, intention, do_sample=False)
491
+ else:
492
+ preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
493
+ # wholecaption = preddata["wholecaption"]
494
+ layouts = preddata["layout"]
495
+ list_box = []
496
+ for i, layout in enumerate(layouts):
497
+ x, y = layout["x"], layout["y"]
498
+ width, height = layout["width"], layout["height"]
499
+ if i == 0:
500
+ list_box.append((0, 0, width, height))
501
+ list_box.append((0, 0, width, height))
502
+ else:
503
+ left = x - width // 2
504
+ top = y - height // 2
505
+ right = x + width // 2
506
+ bottom = y + height // 2
507
+ list_box.append((left, top, right, bottom))
508
+
509
+ # print(list_box)
510
+ filtered_boxes = list_box[:2]
511
+ for i in range(2, len(list_box)):
512
+ keep = True
513
+ for j in range(1, len(filtered_boxes)):
514
+ iou = calculate_iou(list_box[i], filtered_boxes[j])
515
+ if iou > 0.65:
516
+ print(list_box[i], filtered_boxes[j])
517
+ keep = False
518
+ break
519
+ if keep:
520
+ filtered_boxes.append(list_box[i])
521
+
522
+ return str(filtered_boxes), intention, str(filtered_boxes)
523
+
524
+ # def process_preddate(intention, generate_method='v1'):
525
+ # list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
526
+ # wholecaption = "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
527
+ # json_file = "/home/wyb/openseg_blob/v-yanbin/GradioDemo/LLM-For-Layout-Planning/inference_test.json"
528
+ # return wholecaption, str(list_box), json_file
529
+
530
+ pipeline, transp_vae = construction()
531
+
532
+ gradio_test_one_sample_partial = partial(
533
+ svg_test_one_sample,
534
+ pipeline=pipeline,
535
+ transp_vae=transp_vae,
536
+ )
537
+
538
+ def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
539
+ result_images = []
540
+ result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
541
+
542
+ url, unique_filename = upload_to_github(file_path=svg_file_path)
543
+ unique_filename = f'{unique_filename}'
544
+
545
+ if url is not None:
546
+ print(f"File uploaded to: {url}")
547
+ svg_editor = f"""
548
+ <iframe src="https://svgedit.netlify.app/editor/index.html?\
549
+ storagePrompt=false&url={url}" \
550
+ width="100%" height="800px"></iframe>
551
+ """
552
+ else:
553
+ print('upload_to_github FAILED!')
554
+ svg_editor = f"""
555
+ <iframe src="https://svgedit.netlify.app/editor/index.html" \
556
+ width="100%" height="800px"></iframe>
557
+ """
558
+
559
+ return result_images, svg_file_path, svg_editor
560
+
561
+ def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
562
+ # First call process_preddate
563
+ list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
564
+
565
+ # Then feed the output of process_preddate into process_svg
566
+ result_images, svg_file, svg_editor = process_svg(intention_input, list_box_output, seed, true_gs, inference_steps)
567
+
568
+ # Return the outputs of both functions
569
+ return list_box_output, result_images, svg_file, svg_editor, intention_input, list_box_output
570
+
571
+ def clear_inputs1():
572
+ return "", ""
573
+
574
+ def clear_inputs2():
575
+ return "", ""
576
+
577
+ def transfer_inputs(intention, list_box):
578
+ return intention, list_box
579
+
580
+ theme = gr.themes.Soft(
581
+ radius_size="lg",
582
+ ).set(
583
+ block_background_fill='*primary_50',
584
+ block_border_color='*primary_200',
585
+ block_border_width='1px',
586
+ block_border_width_dark='100px',
587
+ block_info_text_color='*primary_950',
588
+ block_label_border_color='*primary_200',
589
+ block_radius='*radius_lg'
590
+ )
591
+
592
+ with gr.Blocks(theme=theme) as demo:
593
+ gr.HTML("<h1 style='text-align: center;'>ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation</h1>")
594
+ gr.HTML("<h2>Anonymous Region Layout Planner</h2>")
595
+
596
+ with gr.Row():
597
+ with gr.Column():
598
+ intention_input = gr.Textbox(lines=15, placeholder="Enter intention", label="Prompt")
599
+ with gr.Row():
600
+ temperature_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=0.0)
601
+ top_p_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Top P", value=0.0)
602
+ with gr.Row():
603
+ clear_btn1 = gr.Button("Clear")
604
+ model_btn1 = gr.Button("Commit", variant='primary')
605
+ transfer_btn1 = gr.Button("Export to below")
606
+
607
+ one_click_btn = gr.Button("One Click Generate ALL", variant='primary')
608
+
609
+ with gr.Column():
610
+ list_box_output = gr.Textbox(lines=10, placeholder="Validation Box", label="Validation Box")
611
+
612
+ examples = gr.Examples(
613
+ examples=[
614
+ ['The image is a graphic design with a celebratory theme. At the top, there is a banner with the text \"Happy Anniversary\" in a bold, sans-serif font. Below this banner, there is a circular frame containing a photograph of a couple. The man has short, dark hair and is wearing a light-colored sweater, while the woman has long blonde hair and is also wearing a light-colored sweater. They are both smiling and appear to be embracing each other.Surrounding the circular frame are decorative elements such as pink flowers and green leaves, which add a festive touch to the design. Below the circular frame, there is a text that reads "Isabel & Morgan" in a cursive, elegant font, suggesting that the couple\'s names are Isabel and Morgan.At the bottom of the image, there is a banner with a message that says "Happy Anniversary! Cheers to another year of love, laughter, and cherished memories together.\" This text is in a smaller, sans-serif font and is placed against a solid background, providing a clear message of celebration and well-wishes for the couple.The overall style of the image is warm and celebratory, with a color scheme that includes shades of pink, green, and white, which contribute to a joyful and romantic atmosphere.'],
615
+ ['The image is a digital illustration with a light blue background. At the top, there is a logo consisting of a snake wrapped around a staff, which is a common symbol in healthcare. Below the logo, the text "International Nurses Day" is prominently displayed in white, with the date "12 May 20xx" in smaller font size.The central part of the image features two stylized characters. On the left, there is a female character with dark hair, wearing a white nurse\'s uniform with a cap. She is holding a clipboard and appears to be speaking or gesturing, as indicated by a speech bubble with the word "OK" in it. On the right, there is a male character with light brown hair, wearing a light blue shirt with a white collar and a white apron. He is holding a stethoscope to his ear, suggesting he is a doctor or a healthcare professional.The characters are depicted in a friendly and approachable manner, with smiles on their faces. Around them, there are small blue plus signs, which are often associated with healthcare and medical services. The overall style of the image is clean, modern, and appears to be designed to celebrate International Nurses Day.'],
616
+ ['The image features a graphic design with a festive theme. At the top, there is a decorative border with a wavy pattern. Below this border, the text "WINTER SEASON SPECIAL COOKIES" is prominently displayed in a bold, sans-serif font. The text is black with a slight shadow effect, giving it a three-dimensional appearance.In the center of the image, there are three illustrated gingerbread cookies. Each cookie has a smiling face with eyes, a nose, and a mouth, and they are colored in a warm, brown hue. The cookies are arranged in a staggered formation, with the middle cookie slightly higher than the others, creating a sense of depth.At the bottom of the image, there is a call to action that reads "ORDER.NOW" in a large, bold, sans-serif font. The text is colored in a darker shade of brown, contrasting with the lighter background. The overall style of the image suggests it is an advertisement or promotional graphic for a winter-themed cookie special.']
617
+ ],
618
+ inputs=[intention_input]
619
+ )
620
+
621
+ gr.HTML("<h2>Anonymous Region Transformer</h2>")
622
+ with gr.Row():
623
+ with gr.Column():
624
+ text_input = gr.Textbox(lines=10, placeholder="Enter prompt text", label="Prompt")
625
+ tuple_input = gr.Textbox(lines=5, placeholder="Enter list of tuples, e.g., [(1, 2, 3, 4), (5, 6, 7, 8)]", label="Validation Box")
626
+ with gr.Row():
627
+ true_gs_input=gr.Slider(minimum=3.0, maximum=5.0, step=0.1, label="true_gs", value=3.5)
628
+ inference_steps_input=gr.Slider(minimum=5, maximum=50, step=1, label="inference_steps", value=28)
629
+ with gr.Row():
630
+ seed_input = gr.Number(label="Seed", value=42)
631
+ with gr.Row():
632
+ transfer_btn2 = gr.Button("Import from above")
633
+ with gr.Row():
634
+ clear_btn2 = gr.Button("Clear")
635
+ model_btn2 = gr.Button("Commit", variant='primary')
636
+
637
+ with gr.Column():
638
+ result_images = gr.Gallery(label="Result Images", columns=5, height='auto')
639
+
640
+ gr.HTML("<h1>SVG Image</h1>")
641
+ svg_file = gr.File(label="Download SVG Image")
642
+ svg_editor = gr.HTML(label="Editable SVG Editor")
643
+
644
+ model_btn1.click(
645
+ fn=process_preddate,
646
+ inputs=[intention_input, temperature_input, top_p_input],
647
+ outputs=[list_box_output, text_input, tuple_input],
648
+ api_name="process_preddate"
649
+ )
650
+ clear_btn1.click(
651
+ fn=clear_inputs1,
652
+ inputs=[],
653
+ outputs=[intention_input, list_box_output]
654
+ )
655
+ model_btn2.click(
656
+ fn=process_svg,
657
+ inputs=[text_input, tuple_input, seed_input, true_gs_input, inference_steps_input],
658
+ outputs=[result_images, svg_file, svg_editor],
659
+ api_name="process_svg"
660
+ )
661
+ clear_btn2.click(
662
+ fn=clear_inputs2,
663
+ inputs=[],
664
+ outputs=[text_input, tuple_input]
665
+ )
666
+ transfer_btn1.click(
667
+ fn=transfer_inputs,
668
+ inputs=[intention_input, list_box_output],
669
+ outputs=[text_input, tuple_input]
670
+ )
671
+ transfer_btn2.click(
672
+ fn=transfer_inputs,
673
+ inputs=[intention_input, list_box_output],
674
+ outputs=[text_input, tuple_input]
675
+ )
676
+ one_click_btn.click(
677
+ fn=one_click_generate,
678
+ inputs=[intention_input, temperature_input, top_p_input, seed_input, true_gs_input, inference_steps_input],
679
+ outputs=[list_box_output, result_images, svg_file, svg_editor, text_input, tuple_input]
680
+ )
681
+ demo.launch()
682
+
683
+ if __name__ == "__main__":
684
+ main()
app_test.py ADDED
@@ -0,0 +1,684 @@
1
+ import os
2
+ # import spaces
3
+
4
+ import ast
5
+ import numpy as np
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+
11
+ from PIL import Image
12
+ import xml.etree.cElementTree as ET
13
+ from io import BytesIO
14
+ import base64
15
+ import json
16
+
17
+ import gradio as gr
18
+ from functools import partial
19
+ import requests
20
+ import base64
21
+ import os
22
+ import time
23
+ import re
24
+
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ set_seed
28
+ )
29
+ from typing import List
30
+
31
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
32
+ from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
33
+ STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
34
+ class StopAtSpecificTokenCriteria(StoppingCriteria):
35
+ def __init__(self, token_id_list: List[int] = None):
36
+ self.token_id_list = token_id_list
37
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
38
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
+ return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
40
+
41
+ def ensure_space_after_period(input_string):
42
+ # Normalize spacing so that each period is followed by exactly one space
43
+ output_string = re.sub(r'\.\s*', '. ', input_string)
44
+ return output_string
45
+
46
+ def generate_unique_filename():
47
+ # Generate a unique filename based on the current timestamp
48
+ timestamp = int(time.time() * 1000) # timestamp in milliseconds
49
+ # random_num = random.randint(1000, 9999) # random number
50
+ unique_filename = f"{timestamp}"
51
+ return unique_filename
52
+
53
+ def upload_to_github(file_path,
54
+ repo='WYBar/gradiodemo_svg',
55
+ branch='main',
56
+ token='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'):
57
+ if not os.path.isfile(file_path):
58
+ print(f"File not found: {file_path}")
59
+ return None, None
60
+ with open(file_path, 'rb') as file:
61
+ content = file.read()
62
+ encoded_content = base64.b64encode(content).decode('utf-8')
63
+ unique_filename = generate_unique_filename()
64
+ url = f"https://api.github.com/repos/{repo}/contents/{unique_filename}.svg"
65
+ headers = {
66
+ "Authorization": f"token {token}"
67
+ }
68
+ response = requests.get(url, headers=headers)
69
+
70
+ sha = None
71
+ if response.status_code == 200:
72
+ sha = response.json()['sha']
73
+ elif response.status_code == 404:
74
+ # File does not exist, so no SHA is needed
75
+ pass
76
+ else:
77
+ print(f"Failed to get file status: {response.status_code}")
78
+ # print(response.text)
79
+ return None, None
80
+
81
+ headers = {
82
+ "Authorization": f"token {token}",
83
+ "Content-Type": "application/json"
84
+ }
85
+ data = {
86
+ "message": "upload svg file",
87
+ "content": encoded_content,
88
+ "branch": branch
89
+ }
90
+
91
+ if sha:
92
+ # File exists: update it
93
+ # print('sha exists, update the old one')
94
+ data["sha"] = sha
95
+ response = requests.put(url, headers=headers, json=data)
96
+ else:
97
+ # File does not exist: create a new one
98
+ print("sha does not exist, need to create a new one")
99
+ response = requests.put(url, headers=headers, json=data)
100
+
101
+ # print(response.status_code)
102
+ # print(response.text)
103
+ if response.status_code in [200, 201]:
104
+ # print(response.json()['content']['download_url'])
105
+ return response.json()['content']['download_url'], unique_filename
106
+ else:
107
+ print(f"Upload failed with status {response.status_code}")
108
+ return None, None
109
+
110
+ def calculate_iou(box1, box2):
111
+ # Compute the intersection of the two boxes
112
+ x1 = max(box1[0], box2[0])
113
+ y1 = max(box1[1], box2[1])
114
+ x2 = min(box1[2], box2[2])
115
+ y2 = min(box1[3], box2[3])
116
+
117
+ intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
118
+
119
+ # Compute the union of the two boxes
120
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
121
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
122
+
123
+ union_area = box1_area + box2_area - intersection_area
124
+
125
+ # Compute the IoU
126
+ iou = intersection_area / union_area
127
+ return iou
128
+
129
+ def adjust_coordinates(box):
130
+ size = 32
131
+ (x1, y1, x2, y2) = box
132
+ if x1 % size != 0:
133
+ x1 = (x1 // size) * size
134
+ if x2 % size != 0:
135
+ x2 = (x2 // size + 1) * size
136
+
137
+ if y1 % size != 0:
138
+ y1 = (y1 // size) * size
139
+ if y2 % size != 0:
140
+ y2 = (y2 // size + 1) * size
141
+ return (x1, y1, x2, y2)
142
+
143
+ def adjust_validation_box(validation_box):
144
+ return [adjust_coordinates(box) for box in validation_box]
145
+
146
+ def get_list_layer_box(list_png_images):
147
+ list_layer_box = []
148
+ for img in list_png_images:
149
+ img_np = np.array(img)
150
+ alpha_channel = img_np[:, :, -1]
151
+
152
+ # Step 1: Find the non-zero indices
153
+ rows, cols = np.nonzero(alpha_channel)
154
+
155
+ if (len(rows) == 0) or (len(cols) == 0):
156
+ # If there are no non-zero indices, we can skip this layer
157
+ list_layer_box.append((0, 0, 0, 0))
158
+ continue
159
+
160
+ # Step 2: Get the minimum and maximum indices for rows and columns
161
+ min_row, max_row = rows.min().item(), rows.max().item()
162
+ min_col, max_col = cols.min().item(), cols.max().item()
163
+
164
+ # Step 3: Quantize the minimum values down to the nearest multiple of 8
165
+ quantized_min_row = (min_row // 8) * 8
166
+ quantized_min_col = (min_col // 8) * 8
167
+
168
+ # Step 4: Quantize the maximum values up to the nearest multiple of 8 outside of the max
169
+ quantized_max_row = ((max_row // 8) + 1) * 8
170
+ quantized_max_col = ((max_col // 8) + 1) * 8
171
+ list_layer_box.append(
172
+ (quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row)
173
+ )
174
+ return list_layer_box
175
+
176
+ def pngs_to_svg(list_png_images):
177
+ list_layer_box = get_list_layer_box(list_png_images)
178
+ assert(len(list_png_images) == len(list_layer_box))
179
+ width, height = list_png_images[0].width, list_png_images[0].height
180
+ img_svg = ET.Element(
181
+ 'svg',
182
+ {
183
+ "width": str(width),
184
+ "height": str(height),
185
+ "xmlns": "http://www.w3.org/2000/svg",
186
+ "xmlns:svg": "http://www.w3.org/2000/svg",
187
+ "xmlns:xlink":"http://www.w3.org/1999/xlink"
188
+ }
189
+ )
190
+ for img, box in zip(list_png_images, list_layer_box):
191
+ x, y, w, h = box[0], box[1], box[2]-box[0], box[3]-box[1]
192
+ if (w == 0 or h == 0):
193
+ continue
194
+ img = img.crop((x, y, x+w, y+h))
195
+ buffer = BytesIO()
196
+ img.save(buffer, format='PNG')
197
+ img_str = base64.b64encode(buffer.getvalue())
198
+ ET.SubElement(
199
+ img_svg,
200
+ "image",
201
+ {
202
+ "x": str(x),
203
+ "y": str(y),
204
+ "width": str(w),
205
+ "height": str(h),
206
+ "xlink:href": "data:image/png;base64,"+img_str.decode('utf-8')
207
+ }
208
+ )
209
+ return ET.tostring(img_svg, encoding='utf-8').decode('utf-8')
210
+
211
+ def calculate_iou(box1, box2):
212
+ # Compute the intersection of the two boxes
213
+ x1 = max(box1[0], box2[0])
214
+ y1 = max(box1[1], box2[1])
215
+ x2 = min(box1[2], box2[2])
216
+ y2 = min(box1[3], box2[3])
217
+
218
+ intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
219
+
220
+ # Compute the union of the two boxes
221
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
222
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
223
+
224
+ union_area = box1_area + box2_area - intersection_area
225
+
226
+ # Compute the IoU
227
+ iou = intersection_area / union_area
228
+ return iou
229
+
230
+ # @spaces.GPU(enable_queue=True, duration=60)
231
+ def buildmodel(**kwargs):
232
+ from modeling_crello import CrelloModel, CrelloModelConfig
233
+ from quantizer import get_quantizer
234
+ # seed / input model / resume
235
+ resume = kwargs.get('resume', None)
236
+ seed = kwargs.get('seed', None)
237
+ input_model = kwargs.get('input_model', None)
238
+ quantizer_version = kwargs.get('quantizer_version', 'v4')
239
+ device = "cuda"
240
+
241
+ set_seed(seed)
242
+ # old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
243
+ old_tokenizer = AutoTokenizer.from_pretrained(
244
+ "WYBar/LLM_For_Layout_Planning", # repository path
245
+ subfolder="Meta-Llama-3-8B", # subfolder holding the base model
246
+ trust_remote_code=True,
247
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
248
+ )
249
+ old_vocab_size = len(old_tokenizer)
250
+ # tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True)
251
+ tokenizer = AutoTokenizer.from_pretrained(
252
+ "WYBar/LLM_For_Layout_Planning",
253
+ subfolder="checkpoint-26000", # subfolder containing the checkpoint
254
+ trust_remote_code=True,
255
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
256
+ )
257
+
258
+ quantizer = get_quantizer(
259
+ quantizer_version,
260
+ update_vocab = False,
261
+ decimal_quantize_types = kwargs.get('decimal_quantize_types'),
262
+ mask_values = kwargs['mask_values'],
263
+ width = kwargs['width'],
264
+ height = kwargs['height'],
265
+ simplify_json = False,
266
+ num_mask_tokens = 0,
267
+ mask_type = kwargs.get('mask_type'),
268
+ )
269
+ quantizer.setup_tokenizer(tokenizer)
270
+
271
+ model_args = CrelloModelConfig(
272
+ old_vocab_size = old_vocab_size,
273
+ vocab_size=len(tokenizer),
274
+ pad_token_id=tokenizer.pad_token_id,
275
+ ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
276
+ )
277
+ model_args.freeze_lm = True
278
+ model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
279
+ model_args.use_lora = False
280
+ model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
281
+ # model = CrelloModel.from_pretrained(
282
+ # resume,
283
+ # config=model_args
284
+ # ).to(device)
285
+ # model = CrelloModel.from_pretrained(
286
+ # "WYBar/LLM_For_Layout_Planning",
287
+ # subfolder="checkpoint-26000", # checkpoint directory to load
288
+ # config=model_args,
289
+ # # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
290
+ # )
291
+ model = CrelloModel(config=model_args)
292
+ print("before .to(device)")
293
+ model = model.to(device)
294
+ print("after .to(device)")
295
+ model = model.bfloat16()
296
+ model.eval()
297
+
298
+ tokenizer.add_special_tokens({"mask_token": "<mask>"})
299
+ quantizer.additional_special_tokens.add("<mask>")
300
+ added_special_tokens_list = ["<layout>", "<position>", "<wholecaption>"]
301
+ tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens_list}, replace_additional_special_tokens=False)
302
+ for token in added_special_tokens_list:
303
+ quantizer.additional_special_tokens.add(token)
304
+
305
+ return model, quantizer, tokenizer
306
+
307
+ def construction_layout():
308
+ params_dict = {
309
+ # Modify these as needed
310
+ "input_model": "WYBar/LLM_For_Layout_Planning",
311
+ "resume": "WYBar/LLM_For_Layout_Planning",
312
+
313
+ "seed": 0,
314
+ "mask_values": False,
315
+ "quantizer_version": 'v4',
316
+ "mask_type": 'cm3',
317
+ "decimal_quantize_types": [],
318
+ "num_mask_tokens": 0,
319
+ "width": 512,
320
+ "height": 512,
321
+ "device": 0,
322
+ }
323
+ device = "cuda"
324
+ # Init model
325
+ model, quantizer, tokenizer = buildmodel(**params_dict)
326
+
327
+ print('resize token embeddings to match the tokenizer', 129423)
328
+ model.lm.resize_token_embeddings(129423)
329
+ model.input_embeddings = model.lm.get_input_embeddings()
330
+ print('after token embeddings to match the tokenizer', 129423)
331
+ return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
332
+
333
+ @torch.no_grad()
334
+ # @spaces.GPU(enable_queue=True, duration=60)
335
+ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
336
+ json_example = inputs
337
+ input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
338
+ inputs = tokenizer(
339
+ input_intension, return_tensors="pt"
340
+ ).to(device)
341
+
342
+ stopping_criteria = StoppingCriteriaList()
343
+ stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000]))
344
+
345
+ outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
346
+ inputs_length = inputs['input_ids'].shape[1]
347
+ outputs = outputs[:, inputs_length:]
348
+
349
+ outputs_word = tokenizer.batch_decode(outputs)[0]
350
+ split_word = outputs_word.split('}]}')[0]+"}]}"
351
+ split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word
352
+ map_dict = quantizer.construct_map_dict()
353
+
354
+ for key, value in map_dict.items():
355
+ split_word = split_word.replace(key, value)
356
+ try:
357
+ pred_json_example = json.loads(split_word)
358
+ for layer in pred_json_example["layout"]:
359
+ layer['x'] = round(int(width)*layer['x'])
360
+ layer['y'] = round(int(height)*layer['y'])
361
+ layer['width'] = round(int(width)*layer['width'])
362
+ layer['height'] = round(int(height)*layer['height'])
363
+ except Exception as e:
364
+ print(e)
365
+ pred_json_example = None
366
+ return pred_json_example
367
+
368
+ def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
369
+ def FormulateInput(intension: str):
370
+ resdict = {}
371
+ resdict["wholecaption"] = intension
372
+ resdict["layout"] = []
373
+ return resdict
374
+
375
+ rawdata = FormulateInput(intention)
376
+
377
+ if generate_method == 'v1':
378
+ max_try_time = 5
379
+ preddata = None
380
+ while preddata is None and max_try_time > 0:
381
+ preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
382
+ max_try_time -= 1
383
+ else:
384
+ print("Please input a valid generate method")
385
+ preddata = None
386
+
387
+ return preddata
388
+
389
+ # @spaces.GPU(enable_queue=True, duration=60)
390
+ def construction():
391
+ from custom_model_mmdit import CustomFluxTransformer2DModel
392
+ from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
393
+ from custom_pipeline import CustomFluxPipelineCfg
394
+
395
+ transformer = CustomFluxTransformer2DModel.from_pretrained(
396
+ "WYBar/ART_test_weights",
397
+ subfolder="fused_transformer",
398
+ torch_dtype=torch.bfloat16,
399
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
400
+ )
401
+
402
+ transp_vae = CustomVAE.from_pretrained(
403
+ "WYBar/ART_test_weights",
404
+ subfolder="custom_vae",
405
+ torch_dtype=torch.float32,
406
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
407
+ )
408
+
409
+ token = os.environ.get("HF_TOKEN")
410
+ pipeline = CustomFluxPipelineCfg.from_pretrained(
411
+ "black-forest-labs/FLUX.1-dev",
412
+ transformer=transformer,
413
+ torch_dtype=torch.bfloat16,
414
+ token=token,
415
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
416
+ ).to("cuda")
417
+ pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
418
+
419
+ return pipeline, transp_vae
420
+
421
+ # @spaces.GPU(enable_queue=True, duration=60)
422
+ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
423
+ print(validation_box)
424
+ output, rgba_output, _, _ = pipeline(
425
+ prompt=validation_prompt,
426
+ validation_box=validation_box,
427
+ generator=generator,
428
+ height=512,
429
+ width=512,
430
+ num_layers=len(validation_box),
431
+ guidance_scale=4.0,
432
+ num_inference_steps=inference_steps,
433
+ transparent_decoder=transp_vae,
434
+ true_gs=true_gs
435
+ )
436
+ images = output.images # list of PIL, len=layers
437
+ rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
438
+
439
+ output_gradio = []
440
+ merged_pil = images[1].convert('RGBA')
441
+ for frame_idx, frame_pil in enumerate(rgba_images):
442
+ if frame_idx < 2:
443
+ frame_pil = images[frame_idx].convert('RGBA') # merged and background
444
+ else:
445
+ merged_pil = Image.alpha_composite(merged_pil, frame_pil)
446
+ output_gradio.append(frame_pil)
447
+
448
+ return output_gradio
449
+
450
+ def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
451
+ generator = torch.Generator().manual_seed(seed)
452
+ try:
453
+ validation_box = ast.literal_eval(validation_box_str)
454
+ except Exception as e:
455
+ return [f"Error parsing validation_box: {e}"]
456
+ if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
457
+ return ["validation_box must be a list of tuples, each of length 4."]
458
+
459
+ validation_box = adjust_validation_box(validation_box)
460
+
461
+ result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
462
+
463
+ svg_img = pngs_to_svg(result_images[1:])
464
+
465
+ svg_file_path = './image.svg'
466
+ os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
467
+ with open(svg_file_path, 'w', encoding='utf-8') as f:
468
+ f.write(svg_img)
469
+
470
+ return result_images, svg_file_path
471
+
472
+ def main():
473
+ model, quantizer, tokenizer, width, height, device = construction_layout()
474
+
475
+ inference_partial = partial(
476
+ inference,
477
+ model=model,
478
+ quantizer=quantizer,
479
+ tokenizer=tokenizer,
480
+ width=width,
481
+ height=height,
482
+ device=device
483
+ )
484
+
485
+ def process_preddate(intention, temperature, top_p, generate_method='v1'):
486
+ intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
487
+ intention = ensure_space_after_period(intention)
488
+ if temperature == 0.0:
489
+ # print("looking for greedy decoding strategies, set `do_sample=False`.")
490
+ preddata = inference_partial(generate_method, intention, do_sample=False)
491
+ else:
492
+ preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
493
+ # wholecaption = preddata["wholecaption"]
494
+ layouts = preddata["layout"]
495
+ list_box = []
496
+ for i, layout in enumerate(layouts):
497
+ x, y = layout["x"], layout["y"]
498
+ width, height = layout["width"], layout["height"]
499
+ if i == 0:
500
+ list_box.append((0, 0, width, height))
501
+ list_box.append((0, 0, width, height))
502
+ else:
503
+ left = x - width // 2
504
+ top = y - height // 2
505
+ right = x + width // 2
506
+ bottom = y + height // 2
507
+ list_box.append((left, top, right, bottom))
508
+
509
+ # print(list_box)
510
+ filtered_boxes = list_box[:2]
511
+ for i in range(2, len(list_box)):
512
+ keep = True
513
+ for j in range(1, len(filtered_boxes)):
514
+ iou = calculate_iou(list_box[i], filtered_boxes[j])
515
+ if iou > 0.65:
516
+ print(list_box[i], filtered_boxes[j])
517
+ keep = False
518
+ break
519
+ if keep:
520
+ filtered_boxes.append(list_box[i])
521
+
522
+ return str(filtered_boxes), intention, str(filtered_boxes)
523
+
524
+ # def process_preddate(intention, generate_method='v1'):
525
+ # list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
526
+ # wholecaption = "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
527
+ # json_file = "/home/wyb/openseg_blob/v-yanbin/GradioDemo/LLM-For-Layout-Planning/inference_test.json"
528
+ # return wholecaption, str(list_box), json_file
529
+
530
+ pipeline, transp_vae = construction()
531
+
532
+ gradio_test_one_sample_partial = partial(
533
+ svg_test_one_sample,
534
+ pipeline=pipeline,
535
+ transp_vae=transp_vae,
536
+ )
537
+
538
+ def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
539
+ result_images = []
540
+ result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
541
+
542
+ url, unique_filename = upload_to_github(file_path=svg_file_path)
543
+ unique_filename = f'{unique_filename}'
544
+
545
+ if url is not None:
546
+ print(f"File uploaded to: {url}")
547
+ svg_editor = f"""
548
+ <iframe src="https://svgedit.netlify.app/editor/index.html?\
549
+ storagePrompt=false&url={url}" \
550
+ width="100%" height="800px"></iframe>
551
+ """
552
+ else:
553
+ print('upload_to_github FAILED!')
554
+ svg_editor = f"""
555
+ <iframe src="https://svgedit.netlify.app/editor/index.html" \
556
+ width="100%" height="800px"></iframe>
557
+ """
558
+
559
+ return result_images, svg_file_path, svg_editor
560
+
561
+ def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
562
+ # First call process_preddate
563
+ list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
564
+
565
+ # Then feed the output of process_preddate into process_svg
566
+ result_images, svg_file, svg_editor = process_svg(intention_input, list_box_output, seed, true_gs, inference_steps)
567
+
568
+ # Return the outputs of both functions
569
+ return list_box_output, result_images, svg_file, svg_editor, intention_input, list_box_output
570
+
571
+ def clear_inputs1():
572
+ return "", ""
573
+
574
+ def clear_inputs2():
575
+ return "", ""
576
+
577
+ def transfer_inputs(intention, list_box):
578
+ return intention, list_box
579
+
580
+ theme = gr.themes.Soft(
581
+ radius_size="lg",
582
+ ).set(
583
+ block_background_fill='*primary_50',
584
+ block_border_color='*primary_200',
585
+ block_border_width='1px',
586
+ block_border_width_dark='100px',
587
+ block_info_text_color='*primary_950',
588
+ block_label_border_color='*primary_200',
589
+ block_radius='*radius_lg'
590
+ )
591
+
592
+ with gr.Blocks(theme=theme) as demo:
593
+ gr.HTML("<h1 style='text-align: center;'>ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation</h1>")
594
+ gr.HTML("<h2>Anonymous Region Layout Planner</h2>")
595
+
596
+ with gr.Row():
597
+ with gr.Column():
598
+ intention_input = gr.Textbox(lines=15, placeholder="Enter intention", label="Prompt")
599
+ with gr.Row():
600
+ temperature_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=0.0)
601
+ top_p_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Top P", value=0.0)
602
+ with gr.Row():
603
+ clear_btn1 = gr.Button("Clear")
604
+ model_btn1 = gr.Button("Commit", variant='primary')
605
+ transfer_btn1 = gr.Button("Export to below")
606
+
607
+ one_click_btn = gr.Button("One Click Generate ALL", variant='primary')
608
+
609
+ with gr.Column():
610
+ list_box_output = gr.Textbox(lines=10, placeholder="Validation Box", label="Validation Box")
611
+
612
+ examples = gr.Examples(
613
+ examples=[
614
+ ['The image is a graphic design with a celebratory theme. At the top, there is a banner with the text \"Happy Anniversary\" in a bold, sans-serif font. Below this banner, there is a circular frame containing a photograph of a couple. The man has short, dark hair and is wearing a light-colored sweater, while the woman has long blonde hair and is also wearing a light-colored sweater. They are both smiling and appear to be embracing each other.Surrounding the circular frame are decorative elements such as pink flowers and green leaves, which add a festive touch to the design. Below the circular frame, there is a text that reads "Isabel & Morgan" in a cursive, elegant font, suggesting that the couple\'s names are Isabel and Morgan.At the bottom of the image, there is a banner with a message that says "Happy Anniversary! Cheers to another year of love, laughter, and cherished memories together.\" This text is in a smaller, sans-serif font and is placed against a solid background, providing a clear message of celebration and well-wishes for the couple.The overall style of the image is warm and celebratory, with a color scheme that includes shades of pink, green, and white, which contribute to a joyful and romantic atmosphere.'],
615
+ ['The image is a digital illustration with a light blue background. At the top, there is a logo consisting of a snake wrapped around a staff, which is a common symbol in healthcare. Below the logo, the text "International Nurses Day" is prominently displayed in white, with the date "12 May 20xx" in smaller font size.The central part of the image features two stylized characters. On the left, there is a female character with dark hair, wearing a white nurse\'s uniform with a cap. She is holding a clipboard and appears to be speaking or gesturing, as indicated by a speech bubble with the word "OK" in it. On the right, there is a male character with light brown hair, wearing a light blue shirt with a white collar and a white apron. He is holding a stethoscope to his ear, suggesting he is a doctor or a healthcare professional.The characters are depicted in a friendly and approachable manner, with smiles on their faces. Around them, there are small blue plus signs, which are often associated with healthcare and medical services. The overall style of the image is clean, modern, and appears to be designed to celebrate International Nurses Day.'],
616
+ ['The image features a graphic design with a festive theme. At the top, there is a decorative border with a wavy pattern. Below this border, the text "WINTER SEASON SPECIAL COOKIES" is prominently displayed in a bold, sans-serif font. The text is black with a slight shadow effect, giving it a three-dimensional appearance.In the center of the image, there are three illustrated gingerbread cookies. Each cookie has a smiling face with eyes, a nose, and a mouth, and they are colored in a warm, brown hue. The cookies are arranged in a staggered formation, with the middle cookie slightly higher than the others, creating a sense of depth.At the bottom of the image, there is a call to action that reads "ORDER.NOW" in a large, bold, sans-serif font. The text is colored in a darker shade of brown, contrasting with the lighter background. The overall style of the image suggests it is an advertisement or promotional graphic for a winter-themed cookie special.']
617
+ ],
618
+ inputs=[intention_input]
619
+ )
620
+
621
+ gr.HTML("<h2>Anonymous Region Transformer</h2>")
622
+ with gr.Row():
623
+ with gr.Column():
624
+ text_input = gr.Textbox(lines=10, placeholder="Enter prompt text", label="Prompt")
625
+ tuple_input = gr.Textbox(lines=5, placeholder="Enter list of tuples, e.g., [(1, 2, 3, 4), (5, 6, 7, 8)]", label="Validation Box")
626
+ with gr.Row():
627
+ true_gs_input = gr.Slider(minimum=3.0, maximum=5.0, step=0.1, label="true_gs", value=3.5)
628
+ inference_steps_input = gr.Slider(minimum=5, maximum=50, step=1, label="inference_steps", value=28)
629
+ with gr.Row():
630
+ seed_input = gr.Number(label="Seed", value=42)
631
+ with gr.Row():
632
+ transfer_btn2 = gr.Button("Import from above")
633
+ with gr.Row():
634
+ clear_btn2 = gr.Button("Clear")
635
+ model_btn2 = gr.Button("Commit", variant='primary')
636
+
637
+ with gr.Column():
638
+ result_images = gr.Gallery(label="Result Images", columns=5, height='auto')
639
+
640
+ gr.HTML("<h1>SVG Image</h1>")
641
+ svg_file = gr.File(label="Download SVG Image")
642
+ svg_editor = gr.HTML(label="Editable SVG Editor")
643
+
644
+ model_btn1.click(
645
+ fn=process_preddate,
646
+ inputs=[intention_input, temperature_input, top_p_input],
647
+ outputs=[list_box_output, text_input, tuple_input],
648
+ api_name="process_preddate"
649
+ )
650
+ clear_btn1.click(
651
+ fn=clear_inputs1,
652
+ inputs=[],
653
+ outputs=[intention_input, list_box_output]
654
+ )
655
+ model_btn2.click(
656
+ fn=process_svg,
657
+ inputs=[text_input, tuple_input, seed_input, true_gs_input, inference_steps_input],
658
+ outputs=[result_images, svg_file, svg_editor],
659
+ api_name="process_svg"
660
+ )
661
+ clear_btn2.click(
662
+ fn=clear_inputs2,
663
+ inputs=[],
664
+ outputs=[text_input, tuple_input]
665
+ )
666
+ transfer_btn1.click(
667
+ fn=transfer_inputs,
668
+ inputs=[intention_input, list_box_output],
669
+ outputs=[text_input, tuple_input]
670
+ )
671
+ transfer_btn2.click(
672
+ fn=transfer_inputs,
673
+ inputs=[intention_input, list_box_output],
674
+ outputs=[text_input, tuple_input]
675
+ )
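+ # One-click path: a single handler takes the intention plus both stages' sampling
+ # settings and returns the planned layout, generated layers, and SVG in one call.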
676
+ one_click_btn.click(
677
+ fn=one_click_generate,
678
+ inputs=[intention_input, temperature_input, top_p_input, seed_input, true_gs_input, inference_steps_input],
679
+ outputs=[list_box_output, result_images, svg_file, svg_editor, text_input, tuple_input]
680
+ )
681
+ demo.launch()
682
+
683
+ if __name__ == "__main__":
684
+ main()
config/base.py ADDED
@@ -0,0 +1,31 @@
1
+ ### Model Settings
2
+ pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
3
+ revision = None
4
+ variant = None
5
+ cache_dir = None
6
+
7
+ ### Training Settings
8
+ seed = 42
9
+ report_to = "wandb"
10
+ tracker_project_name = "multilayer"
11
+ wandb_job_name = "YOU_FORGET_TO_SET"
12
+ logging_dir = "logs"
13
+ max_train_steps = None
14
+ checkpoints_total_limit = None
15
+
16
+ # gpu
17
+ allow_tf32 = True
18
+ gradient_checkpointing = True
19
+ mixed_precision = "bf16"
20
+
21
+ ### Validation Settings
22
+ num_validation_images = 1
23
+ validation_steps = 5
24
+ validation_prompts = [
25
+ "The image features a simple, flat design with a solid pink background. On the left side, there is a stylized depiction of a decorated egg with a pattern of alternating white and light blue stripes. The egg has a smooth, oval shape and is outlined with a thin line. In the center of the image, there is a floral arrangement consisting of a large, white flower with a green center and several smaller white flowers with green centers. The flowers are connected by thin green stems and leaves, creating a small bouquet. On the right side of the image, there is another egg similar to the one on the left, with the same pattern of stripes. This egg is also outlined with a thin line and has a smooth, oval shape. The overall style of the image is clean and modern, with a limited color palette and a focus on geometric shapes and simple patterns. There are no texts or additional elements in the image.",
26
+ "The image features a cartoon-style illustration with three characters against a blue background. On the left side, there is a green, goblin-like creature with large, expressive eyes and a wide grin. It has a small body and is standing upright with its arms raised in a welcoming or excited gesture. In the center, there is a large, white, egg-shaped object that appears to be floating or resting on the surface. It has a smooth, rounded shape and is the largest object in the image. On the right side, there is a purple dinosaur with a friendly expression. It has a small head, large eyes, and a wide mouth that seems to be smiling. The dinosaur is standing on all fours and appears to be looking towards the viewer. The overall style of the image is playful and whimsical, with a clear emphasis on the characters rather than any specific background details.",
27
+ "The image features a collection of Christmas-themed objects against a solid green background. On the left side, there is a red Christmas ornament with a white pattern, resembling a traditional Christmas ball. Next to it, there is a red and white striped stocking with a small white cuff at the top. On the right side, there is a cartoon-style depiction of Santa Claus' face, with a white beard, red cheeks, and a smiling expression. The Santa face is stylized with simple lines and shapes, giving it a friendly and festive appearance. The overall style of the image is flat and graphic, with a clear focus on holiday-related items.",
28
+ "The image depicts a stylized illustration of a rocket launch. The rocket, which is the central focus of the image, is depicted in a simplified, cartoon-like style with a white body and a pointed nose cone. It is shown ascending into a dark background, which is likely meant to represent the night sky. Above the rocket, there are several small, golden stars scattered across the sky, adding a sense of motion and direction to the rocket's ascent. The stars are of varying sizes and are positioned at different heights, creating a sense of depth and distance. The overall style of the image is minimalist and modern, with a limited color palette that emphasizes the rocket and the stars against the dark background. The image does not contain any text or additional elements that would provide context or narrative beyond the depiction of the rocket launch.",
29
+ "The image features a stylized, cartoon-like depiction of a bear. The bear is predominantly pink with a lighter pink nose and a small black dot for an eye. It has two small ears and a small black line for a mouth. The bear is standing upright and appears to be holding a yellow object, possibly a piece of paper or a card, in its right paw. To the right of the bear, there is a purple background with a large, heart-shaped doodle. The overall style of the image is simplistic and child-friendly, with a limited color palette and a clear, uncluttered composition.",
30
+ "The image features three ice cream cones against a pink background. Each cone is filled with a different flavor of ice cream: the leftmost cone has chocolate ice cream, the middle cone has vanilla ice cream, and the rightmost cone has strawberry ice cream. The ice cream is topped with a drizzle of the respective flavor's syrup, and each cone is adorned with a small, round, chocolate-covered piece of candy. The image also contains text that reads 'Sprinkle Sunday Ice Cream Factory East Avenue, CA 13154' and a phone number '+799-2324-9890'. Additionally, there is a website address 'www.sprinklesunday.com'. The style of the image is illustrative and appears to be designed for advertising or promotional purposes.",
31
+ ]
config/v04sv03_lora_r64_upto50layers_bs1_lr1_prodigy_800k_wds_512_filtered_10ep_none_8gpu.py ADDED
@@ -0,0 +1,111 @@
1
+ _base_ = "./base.py"
2
+
3
+ ### path & device settings
4
+ img_tar_path = "/openseg_blob/puyifan/shared_data/CANVA_802000_resolution512max21760tokens/"
5
+ output_path_base = "/openseg_blob/zhaoym/multi_layer_sd3/work_dirs/"
6
+ cache_dir = "/openseg_blob/zhaoym/pretrained/flux"
7
+ # transformer_varient = "ashen0209/Flux-Dev2Pro"
8
+ pretrained_lora_dir = "/openseg_blob/zhaoym/sd3/work_dirs/canva500k_mix100k_sft_flux"
9
+ total_gpu_num = 8
10
+
11
+ ### wandb settings
12
+ wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}'
13
+
14
+ ### Dataset Settings
15
+ resolution = 512
16
+ dataloader_pin_memory = True
17
+ dataloader_num_workers = 16
18
+ train_batch_size = 1
19
+ dataset_cfg = dict(
20
+ img_tar_path=img_tar_path,
21
+ num_train_examples=802000,
22
+ per_gpu_batch_size=train_batch_size,
23
+ global_batch_size=(train_batch_size * total_gpu_num),
24
+ num_workers=dataloader_num_workers,
25
+ resolution=resolution,
26
+ center_crop=True,
27
+ random_flip=False,
28
+ shuffle_buffer_size=1000,
29
+ pin_memory=dataloader_pin_memory,
30
+ persistent_workers=True,
31
+ )
32
+
33
+ ### Model Settings
34
+ rank = 64
35
+ text_encoder_rank = 64
36
+ train_text_encoder = False
37
+ max_layer_num = 50 + 2
38
+ learnable_proj = True
39
+
40
+ ### Training Settings
41
+ weighting_scheme = "none"
42
+ logit_mean = 0.0
43
+ logit_std = 1.0
44
+ mode_scale = 1.29
45
+ guidance_scale = 1.0 ###IMPORTANT
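+ # (assumption) kept at 1.0 because FLUX.1-dev consumes guidance as an embedding; no classifier-free guidance is applied during training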
46
+ layer_weighting = 5.0
47
+
48
+ # steps
49
+ # train_batch_size = 1
50
+ num_train_epochs = 1
51
+ max_train_steps = None
52
+ checkpointing_steps = 2000
53
+ resume_from_checkpoint = "latest"
54
+ gradient_accumulation_steps = 1
55
+
56
+ # lr
57
+ optimizer = "prodigy"
58
+ learning_rate = 1.0
59
+ scale_lr = False
60
+ lr_scheduler = "constant"
61
+ lr_warmup_steps = 0
62
+ lr_num_cycles = 1
63
+ lr_power = 1.0
64
+
65
+ # optim
66
+ adam_beta1 = 0.9
67
+ adam_beta2 = 0.999
68
+ adam_weight_decay = 1e-3
69
+ adam_epsilon = 1e-8
70
+ prodigy_beta3 = None
71
+ prodigy_decouple = True
72
+ prodigy_use_bias_correction = True
73
+ prodigy_safeguard_warmup = True
74
+ max_grad_norm = 1.0
75
+
76
+ # logging
77
+ tracker_task_name = '{{fileBasenameNoExtension}}'
78
+ output_dir = output_path_base + "{{fileBasenameNoExtension}}"
79
+
80
+ ### Validation Settings
81
+ num_validation_images = 1
82
+ validation_steps = 2000
83
+ validation_prompts = [
84
+ 'The image features a background with a soft, pastel color gradient that transitions from pink to purple. There are abstract floral elements scattered throughout the background, with some appearing to be in full bloom and others in a more delicate, bud-like state. The flowers have a watercolor effect, with soft edges that blend into the background.\n\nCentered in the image is a quote in a serif font that reads, "You\'re free to be different." The text is black, which stands out against the lighter background. The overall style of the image is artistic and inspirational, with a motivational message that encourages individuality and self-expression. The image could be used for motivational purposes, as a background for a blog or social media post, or as part of a personal development or self-help theme.',
85
+ 'The image features a logo for a company named "Bull Head Party Adventure." The logo is stylized with a cartoon-like depiction of a bull\'s head, which is the central element of the design. The bull has prominent horns and a fierce expression, with its mouth slightly open as if it\'s snarling or roaring. The color scheme of the bull is a mix of brown and beige tones, with the horns highlighted in a lighter shade.\n\nBelow the bull\'s head, the company name is written in a bold, sans-serif font. The text is arranged in two lines, with "Bull Head" on the top line and "Party Adventure" on the bottom line. The font color matches the color of the bull, creating a cohesive look. The overall style of the image is playful and energetic, suggesting that the company may offer exciting or adventurous party experiences.',
86
+ 'The image features a festive and colorful illustration with a theme related to the Islamic holiday of Eid al-Fitr. At the center of the image is a large, ornate crescent moon with intricate patterns and decorations. Surrounding the moon are several smaller stars and crescents, also adorned with decorative elements. These smaller celestial motifs are suspended from the moon, creating a sense of depth and dimension.\n\nBelow the central moon, there is a banner with the text "Eid Mubarak" in a stylized, elegant font. The text is in a bold, dark color that stands out against the lighter background. The background itself is a gradient of light to dark green, which complements the golden and white hues of the celestial motifs.\n\nThe overall style of the image is celebratory and decorative, with a focus on the traditional symbols associated with Eid al-Fitr. The use of gold and white gives the image a luxurious and festive feel, while the green background is a color often associated with Islam. The image appears to be a digital artwork or graphic design, possibly intended for use as a greeting card or a festive decoration.',
87
+ 'The image is a festive graphic with a dark background. At the center, there is a large, bold text that reads "Happy New Year 2023" in a combination of white and gold colors. The text is surrounded by numerous white balloons with gold ribbons, giving the impression of a celebratory atmosphere. The balloons are scattered around the text, creating a sense of depth and movement. Additionally, there are small gold sparkles and confetti-like elements that add to the celebratory theme. The overall design suggests a New Year\'s celebration, with the year 2023 being the focal point.',
88
+ 'The image is a stylized illustration with a flat design aesthetic. It depicts a scene related to healthcare or medical care. In the center, there is a hospital bed with a patient lying down, appearing to be resting or possibly receiving treatment. The patient is surrounded by three individuals who seem to be healthcare professionals or caregivers. They are standing around the bed, with one on each side and one at the foot of the bed. The person at the foot of the bed is holding a clipboard, suggesting they might be taking notes or reviewing medical records.\n\nThe room has a window with curtains partially drawn, allowing some light to enter. The color palette is soft, with pastel tones dominating the scene. The text "INTERNATIONAL CANCER DAY" is prominently displayed at the top of the image, indicating that the illustration is related to this event. The overall impression is one of care and support, with a focus on the patient\'s well-being.',
89
+ 'The image features a stylized illustration of a man with a beard and a tank top, drinking from a can. The man is depicted in a simplified, cartoon-like style with a limited color palette. Above him, there is a text that reads "Happy Eating, Friends" in a bold, friendly font. Below the illustration, there is another line of text that states "Food is a Necessity That is Not Prioritized," which is also in a bold, sans-serif font. The background of the image is a gradient of light to dark blue, giving the impression of a sky or a calm, serene environment. The overall style of the image is casual and approachable, with a focus on the message conveyed by the text.',
90
+ 'The image is a digital illustration with a pastel pink background. At the top, there is a text that reads "Sending you my Easter wishes" in a simple, sans-serif font. Below this, a larger text states "May Your Heart be Happy!" in a more decorative, serif font. Underneath this main message, there is a smaller text that says "Let the miracle of the season fill you with hope and love."\n\nThe illustration features three stylized flowers with smiling faces. On the left, there is a purple flower with a yellow center. In the center, there is a blue flower with a green center. On the right, there is a pink flower with a yellow center. Each flower has a pair of eyes and a mouth, giving them a friendly appearance. The flowers are drawn with a cartoon-like style, using solid colors and simple shapes.\n\nThe overall style of the image is cheerful and whimsical, with a clear Easter theme suggested by the text and the presence of flowers, which are often associated with spring and new beginnings.',
91
+ 'The image is a vibrant and colorful graphic with a pink background. In the center, there is a photograph of a man and a woman embracing each other. The man is wearing a white shirt, and the woman is wearing a patterned top. They are both smiling and appear to be in a joyful mood.\n\nSurrounding the photograph are various elements that suggest a festive or celebratory theme. There are three hot air balloons in the background, each with a different design: one with a heart, one with a gift box, and one with a basket. These balloons are floating against a clear sky.\n\nAdditionally, there are two gift boxes with ribbons, one on the left and one on the right side of the image. These gift boxes are stylized with a glossy finish and are placed at different heights, creating a sense of depth.\n\nAt the bottom of the image, there is a large red heart, which is a common symbol associated with love and Valentine\'s Day.\n\nFinally, at the very bottom of the image, there is a text that reads "Happy Valentine\'s Day," which confirms the theme of the image as a Valentine\'s Day greeting. The text is in a playful, cursive font that matches the overall cheerful and romantic tone of the image.',
92
+ 'The image depicts a stylized illustration of two women sitting on stools, engaged in conversation. They are wearing traditional attire, with headscarves and patterned dresses. The woman on the left is wearing a brown dress with a purple pattern, while the woman on the right is wearing a purple dress with a brown pattern. Between them is a purple flower. Above the women, the text "INTERNATIONAL WOMEN\'S DAY" is written in bold, uppercase letters. The background is a soft, pastel pink, and there are abstract, swirling lines in a darker shade of pink above the women. The overall style of the image is simplistic and cartoonish, with a warm and friendly tone.',
93
+ 'The image is a digital graphic with a clean, minimalist design. It features a light blue background with a subtle floral pattern at the bottom. On the left side, there is a large, bold text that reads "Our Global Idea." The text is in a serif font and is colored in a darker shade of blue, creating a contrast against the lighter background.\n\nOn the right side, there is a smaller text in a sans-serif font that provides information about utilizing the Live Q&A feature of Canva. The text suggests using this feature to engage an audience more effectively, such as asking about their opinions on certain topics and themes. The text is in a lighter shade of blue, which matches the background, and it is enclosed within a decorative border that includes a floral motif, mirroring the design at the bottom of the image.\n\nThe overall style of the image is professional and modern, with a focus on typography and a simple color scheme. The design elements are well-balanced, with the text and decorative elements complementing each other without overwhelming the viewer.',
94
+ 'The image is a stylized illustration with a warm, peach-colored background. At the center, there is a vintage-style radio with a prominent dial and antenna. The radio is emitting a blue, star-like burst of light or energy from its top. Surrounding the radio are various objects and elements that seem to be floating or suspended in the air. These include a brown, cone-shaped object, a blue, star-like shape, and a brown, wavy, abstract shape that could be interpreted as a flower or a wave.\n\nAt the top of the image, there is text that reads "World Radio Day" in a bold, serif font. Below this, in a smaller, sans-serif font, is the date "13 February 2022." The overall style of the image is playful and cartoonish, with a clear focus on celebrating World Radio Day.',
95
+ 'The image is a graphic design of a baby shower invitation. The central focus is a cute, cartoon-style teddy bear with a friendly expression, sitting upright. The bear is colored in a soft, light brown hue. Above the bear, there is a bold text that reads "YOU\'RE INVITED" in a playful, sans-serif font. Below this, the words "BABY SHOWER" are prominently displayed in a larger, more decorative font, suggesting the theme of the event.\n\nThe background of the invitation is a soft, light pink color, which adds to the gentle and welcoming atmosphere of the design. At the bottom of the image, there is additional text providing specific details about the event. It reads "27 January, 2022 - 8:00 PM" followed by "FAUGET INDUSTRIES CAFE," indicating the date, time, and location of the baby shower.\n\nThe overall style of the image is warm, inviting, and child-friendly, with a clear focus on the theme of a baby shower celebration. The use of a teddy bear as the central image reinforces the baby-related theme. The design is simple yet effective, with a clear hierarchy of information that guides the viewer\'s attention from the top to the bottom of the invitation.',
96
+ ]
97
+
98
+ validation_boxes = [
99
+ [(0, 0, 512, 512), (0, 0, 512, 512), (368, 0, 512, 272), (0, 272, 112, 512), (160, 208, 352, 304)],
100
+ [(0, 0, 512, 512), (0, 0, 512, 512), (128, 128, 384, 304), (96, 288, 416, 336), (128, 336, 384, 368)],
101
+ [(0, 0, 512, 512), (0, 0, 512, 512), (112, 48, 400, 368), (0, 48, 96, 176), (128, 336, 384, 384), (240, 384, 384, 432)],
102
+ [(0, 0, 512, 512), (0, 0, 512, 512), (32, 32, 480, 480), (80, 176, 432, 368), (64, 176, 448, 224), (144, 96, 368, 224)],
103
+ [(0, 0, 512, 512), (0, 0, 512, 512), (0, 64, 176, 272), (0, 400, 512, 512), (16, 160, 496, 512), (224, 48, 464, 112), (208, 96, 464, 160)],
104
+ [(0, 0, 512, 512), (0, 0, 512, 512), (112, 224, 512, 512), (0, 0, 240, 160), (144, 144, 512, 512), (48, 64, 432, 208), (48, 400, 256, 448)],
105
+ [(0, 0, 512, 512), (0, 0, 512, 512), (160, 48, 352, 80), (64, 80, 448, 192), (128, 208, 384, 240), (320, 240, 512, 512), (80, 272, 368, 512), (0, 224, 192, 512)],
106
+ [(0, 0, 512, 512), (0, 0, 512, 512), (48, 0, 464, 304), (128, 144, 384, 400), (288, 288, 384, 368), (336, 304, 400, 368), (176, 432, 336, 480), (224, 400, 288, 432)],
107
+ [(0, 0, 512, 512), (0, 0, 512, 512), (32, 288, 448, 512), (144, 176, 336, 400), (224, 208, 272, 256), (160, 128, 336, 192), (192, 368, 304, 400), (368, 80, 448, 224), (48, 160, 128, 256)],
108
+ [(0, 0, 512, 512), (0, 0, 512, 512), (0, 112, 112, 240), (400, 272, 512, 416), (400, 112, 512, 240), (0, 272, 112, 400), (64, 192, 176, 320), (224, 192, 432, 320), (224, 304, 448, 368)],
109
+ [(0, 0, 512, 512), (0, 0, 512, 512), (0, 352, 512, 512), (112, 176, 368, 432), (48, 176, 128, 256), (48, 368, 128, 448), (384, 192, 480, 272), (384, 336, 432, 384), (80, 80, 432, 128), (176, 128, 336, 160)],
110
+ [(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)],
111
+ ]
custom_model_mmdit.py ADDED
@@ -0,0 +1,334 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Any, Dict, List, Optional, Union, Tuple
4
+
5
+ from accelerate.utils import set_module_tensor_to_device
6
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
7
+ from diffusers.models.normalization import AdaLayerNormContinuous
8
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
9
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock
10
+
11
+ from diffusers.configuration_utils import register_to_config
12
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
13
+
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ class CustomFluxTransformer2DModel(FluxTransformer2DModel):
19
+ """
20
+ The Transformer model introduced in Flux.
21
+
22
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
23
+
24
+ Parameters:
25
+ patch_size (`int`): Patch size to turn the input data into small patches.
26
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
27
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
28
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
29
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
30
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
31
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
32
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
33
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
34
+ """
35
+
36
+ @register_to_config
37
+ def __init__(
38
+ self,
39
+ patch_size: int = 1,
40
+ in_channels: int = 64,
41
+ num_layers: int = 19,
42
+ num_single_layers: int = 38,
43
+ attention_head_dim: int = 128,
44
+ num_attention_heads: int = 24,
45
+ joint_attention_dim: int = 4096,
46
+ pooled_projection_dim: int = 768,
47
+ guidance_embeds: bool = False,
48
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
49
+ max_layer_num: int = 10,
50
+ ):
51
+ super(FluxTransformer2DModel, self).__init__()
52
+ self.out_channels = in_channels
53
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
54
+
55
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
56
+
57
+ text_time_guidance_cls = (
58
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
59
+ )
60
+ self.time_text_embed = text_time_guidance_cls(
61
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
62
+ )
63
+
64
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
65
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
66
+
67
+ self.transformer_blocks = nn.ModuleList(
68
+ [
69
+ FluxTransformerBlock(
70
+ dim=self.inner_dim,
71
+ num_attention_heads=self.config.num_attention_heads,
72
+ attention_head_dim=self.config.attention_head_dim,
73
+ )
74
+ for i in range(self.config.num_layers)
75
+ ]
76
+ )
77
+
78
+ self.single_transformer_blocks = nn.ModuleList(
79
+ [
80
+ FluxSingleTransformerBlock(
81
+ dim=self.inner_dim,
82
+ num_attention_heads=self.config.num_attention_heads,
83
+ attention_head_dim=self.config.attention_head_dim,
84
+ )
85
+ for i in range(self.config.num_single_layers)
86
+ ]
87
+ )
88
+
89
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
90
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
91
+
92
+ self.gradient_checkpointing = False
93
+
94
+ self.max_layer_num = max_layer_num
95
+
96
+ # the following process ensures self.layer_pe is not created as a meta tensor
97
+ self.layer_pe = nn.Parameter(torch.empty(1, self.max_layer_num, 1, 1, self.inner_dim))
98
+ nn.init.trunc_normal_(self.layer_pe, mean=0.0, std=0.02, a=-2.0, b=2.0)
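+ # layer_pe is a learnable per-layer positional embedding (one slot per layer, up to
+ # max_layer_num); forward() broadcasts it over each layer's spatial grid so tokens from
+ # different layers stay distinguishable after they are concatenated into one sequence.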
99
+ # layer_pe_value = nn.init.trunc_normal_(
100
+ # nn.Parameter(torch.zeros(
101
+ # 1, self.max_layer_num, 1, 1, self.inner_dim,
102
+ # )),
103
+ # mean=0.0, std=0.02, a=-2.0, b=2.0,
104
+ # ).data.detach()
105
+ # self.layer_pe = nn.Parameter(layer_pe_value)
106
+ # set_module_tensor_to_device(
107
+ # self,
108
+ # 'layer_pe',
109
+ # device='cpu',
110
+ # value=layer_pe_value,
111
+ # dtype=layer_pe_value.dtype,
112
+ # )
113
+
114
+ @classmethod
115
+ def from_pretrained(cls, *args, **kwarg):
116
+ model = super().from_pretrained(*args, **kwarg)
117
+ for name, para in model.named_parameters():
118
+ if name != 'layer_pe':
119
+ device = para.device
120
+ break
121
+ model.layer_pe.data = model.layer_pe.data.to(device)  # Tensor.to() is not in-place; assign the moved copy back
122
+ return model
123
+
124
+ def crop_each_layer(self, hidden_states, list_layer_box):
125
+ """
126
+ hidden_states: [1, n_layers, h, w, inner_dim]
127
+ list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
128
+ """
129
+ token_list = []
130
+ for layer_idx in range(hidden_states.shape[1]):
131
+ if list_layer_box[layer_idx] is None:
132
+ continue
133
+ else:
134
+ x1, y1, x2, y2 = list_layer_box[layer_idx]
135
+ x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
136
+ layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :]
137
+ bs, h, w, c = layer_token.shape
138
+ layer_token = layer_token.reshape(bs, -1, c)
139
+ token_list.append(layer_token)
140
+ result = torch.cat(token_list, dim=1)
141
+ return result
142
+
143
+ def fill_in_processed_tokens(self, hidden_states, full_hidden_states, list_layer_box):
144
+ """
145
+ hidden_states: [1, h1xw1 + h2xw2 + ... + hlxwl , inner_dim]
146
+ full_hidden_states: [1, n_layers, h, w, inner_dim]
147
+ list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
148
+ """
149
+ used_token_len = 0
150
+ bs = hidden_states.shape[0]
151
+ for layer_idx in range(full_hidden_states.shape[1]):
152
+ if list_layer_box[layer_idx] is None:
153
+ continue
154
+ else:
155
+ x1, y1, x2, y2 = list_layer_box[layer_idx]
156
+ x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
157
+ full_hidden_states[:, layer_idx, y1:y2, x1:x2, :] = hidden_states[:, used_token_len: used_token_len + (y2-y1) * (x2-x1), :].reshape(bs, y2-y1, x2-x1, -1)
158
+ used_token_len = used_token_len + (y2-y1) * (x2-x1)
159
+ return full_hidden_states
160
+
161
+ def forward(
162
+ self,
163
+ hidden_states: torch.Tensor,
164
+ list_layer_box: List[Tuple] = None,
165
+ encoder_hidden_states: torch.Tensor = None,
166
+ pooled_projections: torch.Tensor = None,
167
+ timestep: torch.LongTensor = None,
168
+ img_ids: torch.Tensor = None,
169
+ txt_ids: torch.Tensor = None,
170
+ guidance: torch.Tensor = None,
171
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
172
+ return_dict: bool = True,
173
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
174
+ """
175
+ The [`FluxTransformer2DModel`] forward method.
176
+
177
+ Args:
178
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
179
+ Input `hidden_states`.
180
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
181
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
182
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
183
+ from the embeddings of input conditions.
184
+ timestep ( `torch.LongTensor`):
185
+ Used to indicate denoising step.
186
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
187
+ A list of tensors that if specified are added to the residuals of transformer blocks.
188
+ joint_attention_kwargs (`dict`, *optional*):
189
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
190
+ `self.processor` in
191
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
192
+ return_dict (`bool`, *optional*, defaults to `True`):
193
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
194
+ tuple.
195
+
196
+ Returns:
197
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
198
+ `tuple` where the first element is the sample tensor.
199
+ """
200
+ if joint_attention_kwargs is not None:
201
+ joint_attention_kwargs = joint_attention_kwargs.copy()
202
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
203
+ else:
204
+ lora_scale = 1.0
205
+
206
+ if USE_PEFT_BACKEND:
207
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
208
+ scale_lora_layers(self, lora_scale)
209
+ else:
210
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
211
+ logger.warning(
212
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
213
+ )
214
+
215
+ bs, n_layers, channel_latent, height, width = hidden_states.shape # [bs, n_layers, c_latent, h, w]
216
+
217
+ hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2) # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
218
+ hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
219
+ hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4) # [bs, n_layers, h/2, w/2, c_latent*4]
220
+ hidden_states = self.x_embedder(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
221
+
222
+ full_hidden_states = torch.zeros_like(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
223
+ layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim) # [1, max_n_layers, 1, 1, inner_dim]
224
+ hidden_states = hidden_states + layer_pe[:, :n_layers] # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, f, h/2, w/2, inner_dim]
225
+ hidden_states = self.crop_each_layer(hidden_states, list_layer_box) # [bs, token_len, inner_dim]
226
+
227
+ timestep = timestep.to(hidden_states.dtype) * 1000
228
+ if guidance is not None:
229
+ guidance = guidance.to(hidden_states.dtype) * 1000
230
+ else:
231
+ guidance = None
232
+ temb = (
233
+ self.time_text_embed(timestep, pooled_projections)
234
+ if guidance is None
235
+ else self.time_text_embed(timestep, guidance, pooled_projections)
236
+ )
237
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
238
+
239
+ if txt_ids.ndim == 3:
240
+ logger.warning(
241
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
242
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
243
+ )
244
+ txt_ids = txt_ids[0]
245
+ if img_ids.ndim == 3:
246
+ logger.warning(
247
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
248
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
249
+ )
250
+ img_ids = img_ids[0]
251
+ ids = torch.cat((txt_ids, img_ids), dim=0)
252
+ image_rotary_emb = self.pos_embed(ids)
253
+
254
+ for index_block, block in enumerate(self.transformer_blocks):
255
+ if self.training and self.gradient_checkpointing:
256
+
257
+ def create_custom_forward(module, return_dict=None):
258
+ def custom_forward(*inputs):
259
+ if return_dict is not None:
260
+ return module(*inputs, return_dict=return_dict)
261
+ else:
262
+ return module(*inputs)
263
+
264
+ return custom_forward
265
+
266
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
267
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
268
+ create_custom_forward(block),
269
+ hidden_states,
270
+ encoder_hidden_states,
271
+ temb,
272
+ image_rotary_emb,
273
+ **ckpt_kwargs,
274
+ )
275
+
276
+ else:
277
+ encoder_hidden_states, hidden_states = block(
278
+ hidden_states=hidden_states,
279
+ encoder_hidden_states=encoder_hidden_states,
280
+ temb=temb,
281
+ image_rotary_emb=image_rotary_emb,
282
+ )
283
+
284
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
285
+
286
+ for index_block, block in enumerate(self.single_transformer_blocks):
287
+ if self.training and self.gradient_checkpointing:
288
+
289
+ def create_custom_forward(module, return_dict=None):
290
+ def custom_forward(*inputs):
291
+ if return_dict is not None:
292
+ return module(*inputs, return_dict=return_dict)
293
+ else:
294
+ return module(*inputs)
295
+
296
+ return custom_forward
297
+
298
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
299
+ hidden_states = torch.utils.checkpoint.checkpoint(
300
+ create_custom_forward(block),
301
+ hidden_states,
302
+ temb,
303
+ image_rotary_emb,
304
+ **ckpt_kwargs,
305
+ )
306
+
307
+ else:
308
+ hidden_states = block(
309
+ hidden_states=hidden_states,
310
+ temb=temb,
311
+ image_rotary_emb=image_rotary_emb,
312
+ )
313
+
314
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
315
+
316
+ hidden_states = self.fill_in_processed_tokens(hidden_states, full_hidden_states, list_layer_box) # [bs, n_layers, h/2, w/2, inner_dim]
317
+ hidden_states = hidden_states.view(bs, -1, self.inner_dim) # [bs, n_layers * full_len, inner_dim]
318
+
319
+ hidden_states = self.norm_out(hidden_states, temb) # [bs, n_layers * full_len, inner_dim]
320
+ hidden_states = self.proj_out(hidden_states) # [bs, n_layers * full_len, c_latent*4]
321
+
322
+ # unpatchify
323
+ hidden_states = hidden_states.view(bs, n_layers, height//2, width//2, channel_latent, 2, 2) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
324
+ hidden_states = hidden_states.permute(0, 1, 4, 2, 5, 3, 6)
325
+ output = hidden_states.reshape(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
326
+
327
+ if USE_PEFT_BACKEND:
328
+ # remove `lora_scale` from each PEFT layer
329
+ unscale_lora_layers(self, lora_scale)
330
+
331
+ if not return_dict:
332
+ return (output,)
333
+
334
+ return Transformer2DModelOutput(sample=output)
custom_model_transp_vae.py ADDED
@@ -0,0 +1,331 @@
1
+ import einops
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from accelerate.utils import set_module_tensor_to_device
12
+ from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.configuration_utils import ConfigMixin
15
+ from diffusers.loaders import FromOriginalModelMixin
16
+
17
+
18
+ class MLPBlock(torchvision.ops.misc.MLP):
19
+ """Transformer MLP block."""
20
+
21
+ _version = 2
22
+
23
+ def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
24
+ super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
25
+
26
+ for m in self.modules():
27
+ if isinstance(m, nn.Linear):
28
+ nn.init.xavier_uniform_(m.weight)
29
+ if m.bias is not None:
30
+ nn.init.normal_(m.bias, std=1e-6)
31
+
32
+ def _load_from_state_dict(
33
+ self,
34
+ state_dict,
35
+ prefix,
36
+ local_metadata,
37
+ strict,
38
+ missing_keys,
39
+ unexpected_keys,
40
+ error_msgs,
41
+ ):
42
+ version = local_metadata.get("version", None)
43
+
44
+ if version is None or version < 2:
45
+ # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
46
+ for i in range(2):
47
+ for type in ["weight", "bias"]:
48
+ old_key = f"{prefix}linear_{i+1}.{type}"
49
+ new_key = f"{prefix}{3*i}.{type}"
50
+ if old_key in state_dict:
51
+ state_dict[new_key] = state_dict.pop(old_key)
52
+
53
+ super()._load_from_state_dict(
54
+ state_dict,
55
+ prefix,
56
+ local_metadata,
57
+ strict,
58
+ missing_keys,
59
+ unexpected_keys,
60
+ error_msgs,
61
+ )
62
+
63
+
64
+ class EncoderBlock(nn.Module):
65
+ """Transformer encoder block."""
66
+
67
+ def __init__(
68
+ self,
69
+ num_heads: int,
70
+ hidden_dim: int,
71
+ mlp_dim: int,
72
+ dropout: float,
73
+ attention_dropout: float,
74
+ norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
75
+ ):
76
+ super().__init__()
77
+ self.num_heads = num_heads
78
+ self.hidden_dim = hidden_dim
79
+ self.num_heads = num_heads
80
+
81
+ # Attention block
82
+ self.ln_1 = norm_layer(hidden_dim)
83
+ self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ # MLP block
87
+ self.ln_2 = norm_layer(hidden_dim)
88
+ self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
89
+
90
+ def forward(self, input: torch.Tensor, freqs_cis):
91
+ torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
92
+ B, L, C = input.shape
93
+ x = self.ln_1(input)
94
+ if freqs_cis is not None:
95
+ query = x.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2)
96
+ query = apply_rotary_emb(query, freqs_cis)
97
+ query = query.transpose(1, 2).reshape(B, L, self.hidden_dim)
98
+ x, _ = self.self_attention(query, query, x, need_weights=False)
99
+ x = self.dropout(x)
100
+ x = x + input
101
+
102
+ y = self.ln_2(x)
103
+ y = self.mlp(y)
104
+ return x + y
105
+
106
+
107
+ class Encoder(nn.Module):
108
+ """Transformer Model Encoder for sequence to sequence translation."""
109
+
110
+ def __init__(
111
+ self,
112
+ seq_length: int,
113
+ num_layers: int,
114
+ num_heads: int,
115
+ hidden_dim: int,
116
+ mlp_dim: int,
117
+ dropout: float,
118
+ attention_dropout: float,
119
+ norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
120
+ ):
121
+ super().__init__()
122
+ # Note that batch_size is on the first dim because
123
+ # we have batch_first=True in nn.MultiAttention() by default
124
+ # self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
125
+ self.dropout = nn.Dropout(dropout)
126
+ layers: OrderedDict[str, nn.Module] = OrderedDict()
127
+ for i in range(num_layers):
128
+ layers[f"encoder_layer_{i}"] = EncoderBlock(
129
+ num_heads,
130
+ hidden_dim,
131
+ mlp_dim,
132
+ dropout,
133
+ attention_dropout,
134
+ norm_layer,
135
+ )
136
+ self.layers = nn.Sequential(layers)
137
+ self.ln = norm_layer(hidden_dim)
138
+
139
+ def forward(self, input: torch.Tensor, freqs_cis):
140
+ torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
141
+ input = input # + self.pos_embedding
142
+ x = self.dropout(input)
143
+ for l in self.layers:
144
+ x = checkpoint(l, x, freqs_cis)
145
+ x = self.ln(x)
146
+ return x
147
+
148
+
149
+ class ViTEncoder(nn.Module):
150
+ def __init__(self, arch='vit-b/32'):
151
+ super().__init__()
152
+ self.arch = arch
153
+
154
+ if self.arch == 'vit-b/32':
155
+ ch = 768
156
+ layers = 12
157
+ heads = 12
158
+ elif self.arch == 'vit-h/14':
159
+ ch = 1280
160
+ layers = 32
161
+ heads = 16
162
+
163
+ self.encoder = Encoder(
164
+ seq_length=-1,
165
+ num_layers=layers,
166
+ num_heads=heads,
167
+ hidden_dim=ch,
168
+ mlp_dim=ch*4,
169
+ dropout=0.0,
170
+ attention_dropout=0.0,
171
+ )
172
+ self.fc_in = nn.Linear(16, ch)
173
+ self.fc_out = nn.Linear(ch, 256)
174
+
175
+ if self.arch == 'vit-b/32':
176
+ from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights
177
+ vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT)
178
+ elif self.arch == 'vit-h/14':
179
+ from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights
180
+ vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)
181
+
182
+ missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False)
183
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
184
+ print(f"ViT Encoder Missing keys: {missing_keys}")
185
+ print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
186
+ del vit
187
+
188
+ def forward(self, x, freqs_cis):
189
+ out = self.fc_in(x)
190
+ out = self.encoder(out, freqs_cis)
191
+ out = checkpoint(self.fc_out, out)
192
+ return out
193
+
194
+
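+ # patchify/unpatchify convert between a dense feature map and an 8x8 space-to-depth
+ # layout (64x more channels, 8x smaller height/width), so the ViT operates on patch tokens.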
195
+ def patchify(x, patch_size=8):
196
+ if len(x.shape) == 4:
197
+ bs, c, h, w = x.shape
198
+ x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size)
199
+ elif len(x.shape) == 3:
200
+ c, h, w = x.shape
201
+ x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size)
202
+ return x
203
+
204
+
205
+ def unpatchify(x, patch_size=8):
206
+ if len(x.shape) == 4:
207
+ bs, c, h, w = x.shape
208
+ x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size)
209
+ elif len(x.shape) == 3:
210
+ c, h, w = x.shape
211
+ x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size)
212
+ return x
213
+
214
+
215
+ def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding):
216
+ token_list = []
217
+ cos_list, sin_list = [], []
218
+ for layer_idx in range(hidden_states.shape[1]):
219
+ if list_layer_box[layer_idx] is None:
220
+ continue
221
+ else:
222
+ x1, y1, x2, y2 = list_layer_box[layer_idx]
223
+ x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
224
+ layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2]
225
+ c, h, w = layer_token.shape
226
+ layer_token = layer_token.reshape(c, -1)
227
+ token_list.append(layer_token)
228
+ ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype)
229
+ ids[:, 0] = use_layers[layer_idx]
230
+ image_rotary_emb = pos_embedding(ids)
231
+ pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1)
232
+ cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64))
233
+ sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64))
234
+ token_list = torch.cat(token_list, dim=1).permute(1, 0)
235
+ cos_list = torch.cat(cos_list, dim=0)
236
+ sin_list = torch.cat(sin_list, dim=0)
237
+ return token_list, (cos_list, sin_list)
238
+
239
+
240
+ def prepare_latent_image_ids(batch_size, height, width, device, dtype):
241
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
242
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
243
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
244
+
245
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
246
+
247
+ latent_image_ids = latent_image_ids.reshape(
248
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
249
+ )
250
+
251
+ return latent_image_ids.to(device=device, dtype=dtype)
252
+
253
+
254
+ class AutoencoderKLTransformerTraining(ModelMixin, ConfigMixin, FromOriginalModelMixin):
255
+ def __init__(self):
256
+ super().__init__()
257
+
258
+ self.decoder_arch = 'vit'
259
+ self.layer_embedding = 'rope'
260
+
261
+ self.decoder = ViTEncoder()
262
+ self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28))
263
+ if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
264
+ self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.max_layers, 1, 1).normal_(std=0.02), requires_grad=True)
265
+
266
+ def zero_module(module):
267
+ """
268
+ Zero out the parameters of a module and return it.
269
+ """
270
+ for p in module.parameters():
271
+ p.detach().zero_()
272
+ return module
273
+
274
+ def encode(self, z_2d, box, use_layers):
275
+ B, C, T, H, W = z_2d.shape
276
+
277
+ z, freqs_cis = [], []
278
+ for b in range(B):
279
+ _z = z_2d[b]
280
+ if 'vit' in self.decoder_arch:
281
+ _use_layers = torch.tensor(use_layers[b], device=z_2d.device)
282
+ if 'rel' in self.layer_embedding:
283
+ _use_layers[_use_layers > 2] = 2
284
+ if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
285
+ _z = _z + self.layer_embedding[:, _use_layers] # + self.pos_embedding
286
+ if 'rope' not in self.layer_embedding:
287
+ use_layers[b] = [0] * len(use_layers[b])
288
+ _z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding) ### modified
289
+ z.append(_z)
290
+ freqs_cis.append(cis)
291
+
292
+ return z, freqs_cis
293
+
294
+ def decode(self, z, freqs_cis, box, H, W):
295
+ B = len(z)
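+ # `pad` is the fill value for pixels outside a layer's box: RGB channels are 0 and the
+ # alpha channel is -1 (presumably fully transparent in the [-1, 1] normalized output).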
296
+ pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype)
297
+ pad[3, :, :] = -1
298
+ x = []
299
+ for b in range(B):
300
+ _x = []
301
+ _z = self.decoder(z[b].unsqueeze(0), freqs_cis[b]).squeeze(0)
302
+ current_index = 0
303
+ for layer_idx in range(len(box[b])):
304
+ if box[b][layer_idx] is None:
305
+ _x.append(pad.clone())
306
+ else:
307
+ x1, y1, x2, y2 = box[b][layer_idx]
308
+ x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8
309
+ token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok)
310
+ tokens = _z[current_index:current_index + token_length]
311
+ pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok)
312
+ unpatched = unpatchify(pixels)
313
+ pixels = pad.clone()
314
+ pixels[:, y1:y2, x1:x2] = unpatched
315
+ _x.append(pixels)
316
+ current_index += token_length
317
+ _x = torch.stack(_x, dim=1)
318
+ x.append(_x)
319
+ x = torch.stack(x, dim=0)
320
+ return x
321
+
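+ # z_2d holds one latent per layer ([n_layers, C, H, W]); it is reshaped to
+ # [1, C, n_layers, H, W], encoded into box-cropped patch tokens with RoPE layer/position
+ # ids, then decoded back into per-layer RGB images and alpha mattes.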
322
+ def forward(self, z_2d, box, use_layers=None):
323
+ z_2d = z_2d.transpose(0, 1).unsqueeze(0)
324
+ use_layers = use_layers or [list(range(z_2d.shape[2]))]
325
+ z, freqs_cis = self.encode(z_2d, box, use_layers)
326
+ H, W = z_2d.shape[-2:]
327
+ x_hat = self.decode(z, freqs_cis, box, H * 8, W * 8)
328
+ assert x_hat.shape[0] == 1, x_hat.shape
329
+ x_hat = einops.rearrange(x_hat[0], "c t h w -> t c h w")
330
+ x_hat_rgb, x_hat_alpha = x_hat[:, :3], x_hat[:, 3:]
331
+ return x_hat_rgb, x_hat_alpha
custom_pipeline.py ADDED
@@ -0,0 +1,845 @@
1
+ import numpy as np
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.utils.torch_utils import randn_tensor
8
+ from diffusers.utils import is_torch_xla_available, logging
9
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
10
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline
11
+
12
+ if is_torch_xla_available():
13
+ import torch_xla.core.xla_model as xm # type: ignore
14
+ XLA_AVAILABLE = True
15
+ else:
16
+ XLA_AVAILABLE = False
17
+
18
+
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
+
21
+
22
+ def _get_clip_prompt_embeds(
23
+ tokenizer,
24
+ text_encoder,
25
+ prompt: Union[str, List[str]],
26
+ num_images_per_prompt: int = 1,
27
+ device: Optional[torch.device] = None,
28
+ ):
29
+ device = device or text_encoder.device
30
+ dtype = text_encoder.dtype
31
+
32
+ prompt = [prompt] if isinstance(prompt, str) else prompt
33
+ batch_size = len(prompt)
34
+
35
+ text_inputs = tokenizer(
36
+ prompt,
37
+ padding="max_length",
38
+ max_length=text_encoder.config.max_position_embeddings,
39
+ truncation=True,
40
+ return_overflowing_tokens=False,
41
+ return_length=False,
42
+ return_tensors="pt",
43
+ )
44
+
45
+ text_input_ids = text_inputs.input_ids
46
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
47
+
48
+ # Use pooled output of CLIPTextModel
49
+ prompt_embeds = prompt_embeds.pooler_output
50
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
51
+
52
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
53
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
54
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
55
+
56
+ return prompt_embeds
57
+
58
+
59
+ def _get_t5_prompt_embeds(
60
+ tokenizer,
61
+ text_encoder,
62
+ prompt: Union[str, List[str]] = None,
63
+ num_images_per_prompt: int = 1,
64
+ max_sequence_length: int = 512,
65
+ device: Optional[torch.device] = None,
66
+ dtype: Optional[torch.dtype] = None,
67
+ ):
68
+ device = device or text_encoder.device
69
+ dtype = dtype or text_encoder.dtype
70
+
71
+ prompt = [prompt] if isinstance(prompt, str) else prompt
72
+ batch_size = len(prompt)
73
+
74
+ text_inputs = tokenizer(
75
+ prompt,
76
+ padding="max_length",
77
+ max_length=max_sequence_length,
78
+ truncation=True,
79
+ return_length=False,
80
+ return_overflowing_tokens=False,
81
+ return_tensors="pt",
82
+ )
83
+ text_input_ids = text_inputs.input_ids
84
+
85
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
86
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87
+
88
+ _, seq_len, _ = prompt_embeds.shape
89
+
90
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93
+
94
+ return prompt_embeds
95
+
96
+
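+ # Flux-style dual text encoding: CLIP contributes only the pooled vector
+ # (pooled_projections) while T5 contributes the full token sequence used as
+ # encoder_hidden_states; text_ids are zeros and merely reserve rotary-position slots.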
97
+ def encode_prompt(
98
+ tokenizers,
99
+ text_encoders,
100
+ prompt: Union[str, List[str]],
101
+ prompt_2: Union[str, List[str]] = None,
102
+ num_images_per_prompt: int = 1,
103
+ max_sequence_length: int = 512,
104
+ ):
105
+
106
+ tokenizer_1, tokenizer_2 = tokenizers
107
+ text_encoder_1, text_encoder_2 = text_encoders
108
+ device = text_encoder_1.device
109
+ dtype = text_encoder_1.dtype
110
+
111
+ prompt = [prompt] if isinstance(prompt, str) else prompt
112
+ prompt_2 = prompt_2 or prompt
113
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
114
+
115
+ # We only use the pooled prompt output from the CLIPTextModel
116
+ pooled_prompt_embeds = _get_clip_prompt_embeds(
117
+ tokenizer=tokenizer_1,
118
+ text_encoder=text_encoder_1,
119
+ prompt=prompt,
120
+ num_images_per_prompt=num_images_per_prompt,
121
+ )
122
+ prompt_embeds = _get_t5_prompt_embeds(
123
+ tokenizer=tokenizer_2,
124
+ text_encoder=text_encoder_2,
125
+ prompt=prompt_2,
126
+ num_images_per_prompt=num_images_per_prompt,
127
+ max_sequence_length=max_sequence_length,
128
+ )
129
+
130
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
131
+
132
+ return prompt_embeds, pooled_prompt_embeds, text_ids
133
+
134
+
135
+ class CustomFluxPipeline(FluxPipeline):
136
+
137
+ @staticmethod
138
+ def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
139
+
140
+ latent_image_ids_list = []
141
+ for layer_idx in range(len(list_layer_box)):
142
+ if list_layer_box[layer_idx] is None:
143
+ continue
144
+ else:
145
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
146
+ latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
147
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
148
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
149
+
150
+ x1, y1, x2, y2 = list_layer_box[layer_idx]
151
+ x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
152
+ latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
153
+
154
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
155
+ latent_image_ids = latent_image_ids.reshape(
156
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
157
+ )
158
+
159
+ latent_image_ids_list.append(latent_image_ids)
160
+
161
+ full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
162
+
163
+ return full_latent_image_ids.to(device=device, dtype=dtype)
164
+
165
+ def prepare_latents(
166
+ self,
167
+ batch_size,
168
+ num_layers,
169
+ num_channels_latents,
170
+ height,
171
+ width,
172
+ list_layer_box,
173
+ dtype,
174
+ device,
175
+ generator,
176
+ latents=None,
177
+ ):
178
+ height = 2 * (int(height) // self.vae_scale_factor)
179
+ width = 2 * (int(width) // self.vae_scale_factor)
180
+
181
+ shape = (batch_size, num_layers, num_channels_latents, height, width)
182
+
183
+ if latents is not None:
184
+ latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
185
+ return latents.to(device=device, dtype=dtype), latent_image_ids
186
+
187
+ if isinstance(generator, list) and len(generator) != batch_size:
188
+ raise ValueError(
189
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
190
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
191
+ )
192
+
193
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, f, c_latent, h, w]
194
+
195
+ latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
196
+
197
+ return latents, latent_image_ids
198
+
199
+ @torch.no_grad()
200
+ def __call__(
201
+ self,
202
+ prompt: Union[str, List[str]] = None,
203
+ prompt_2: Optional[Union[str, List[str]]] = None,
204
+ validation_box: List[tuple] = None,
205
+ height: Optional[int] = None,
206
+ width: Optional[int] = None,
207
+ num_inference_steps: int = 28,
208
+ timesteps: List[int] = None,
209
+ guidance_scale: float = 3.5,
210
+ num_images_per_prompt: Optional[int] = 1,
211
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
212
+ latents: Optional[torch.FloatTensor] = None,
213
+ prompt_embeds: Optional[torch.FloatTensor] = None,
214
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
215
+ output_type: Optional[str] = "pil",
216
+ return_dict: bool = True,
217
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
218
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
219
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
220
+ max_sequence_length: int = 512,
221
+ num_layers: int = 5,
222
+ sdxl_vae: nn.Module = None,
223
+ transparent_decoder: nn.Module = None,
224
+ ):
225
+ r"""
226
+ Function invoked when calling the pipeline for generation.
227
+
228
+ Args:
229
+ prompt (`str` or `List[str]`, *optional*):
230
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
231
+ instead.
232
+ prompt_2 (`str` or `List[str]`, *optional*):
233
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
234
+ will be used instead
235
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
236
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
237
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
238
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
239
+ num_inference_steps (`int`, *optional*, defaults to 28):
240
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
241
+ expense of slower inference.
242
+ timesteps (`List[int]`, *optional*):
243
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
244
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
245
+ passed will be used. Must be in descending order.
246
+ guidance_scale (`float`, *optional*, defaults to 3.5):
247
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
248
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
249
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
250
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
251
+ usually at the expense of lower image quality.
252
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
253
+ The number of images to generate per prompt.
254
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
255
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
256
+ to make generation deterministic.
257
+ latents (`torch.FloatTensor`, *optional*):
258
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
259
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
260
+ tensor will be generated by sampling using the supplied random `generator`.
261
+ prompt_embeds (`torch.FloatTensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263
+ provided, text embeddings will be generated from `prompt` input argument.
264
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
265
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
266
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
267
+ output_type (`str`, *optional*, defaults to `"pil"`):
268
+ The output format of the generated image. Choose between
269
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
272
+ joint_attention_kwargs (`dict`, *optional*):
273
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274
+ `self.processor` in
275
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276
+ callback_on_step_end (`Callable`, *optional*):
277
+ A callback function invoked at the end of each denoising step during inference. The function is called
278
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
279
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
280
+ `callback_on_step_end_tensor_inputs`.
281
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
282
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
283
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
284
+ `._callback_tensor_inputs` attribute of your pipeline class.
285
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
286
+
287
+ Examples:
288
+
289
+ Returns:
290
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
291
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
292
+ images.
293
+ """
294
+
295
+ height = height or self.default_sample_size * self.vae_scale_factor
296
+ width = width or self.default_sample_size * self.vae_scale_factor
297
+
298
+ # 1. Check inputs. Raise error if not correct
299
+ self.check_inputs(
300
+ prompt,
301
+ prompt_2,
302
+ height,
303
+ width,
304
+ prompt_embeds=prompt_embeds,
305
+ pooled_prompt_embeds=pooled_prompt_embeds,
306
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
307
+ max_sequence_length=max_sequence_length,
308
+ )
309
+
310
+ self._guidance_scale = guidance_scale
311
+ self._joint_attention_kwargs = joint_attention_kwargs
312
+ self._interrupt = False
313
+
314
+ # 2. Define call parameters
315
+ if prompt is not None and isinstance(prompt, str):
316
+ batch_size = 1
317
+ elif prompt is not None and isinstance(prompt, list):
318
+ batch_size = len(prompt)
319
+ else:
320
+ batch_size = prompt_embeds.shape[0]
321
+
322
+ device = self._execution_device
323
+
324
+ lora_scale = (
325
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
326
+ )
327
+ (
328
+ prompt_embeds,
329
+ pooled_prompt_embeds,
330
+ text_ids,
331
+ ) = self.encode_prompt(
332
+ prompt=prompt,
333
+ prompt_2=prompt_2,
334
+ prompt_embeds=prompt_embeds,
335
+ pooled_prompt_embeds=pooled_prompt_embeds,
336
+ device=device,
337
+ num_images_per_prompt=num_images_per_prompt,
338
+ max_sequence_length=max_sequence_length,
339
+ lora_scale=lora_scale,
340
+ )
341
+
342
+ # 4. Prepare latent variables
343
+ num_channels_latents = self.transformer.config.in_channels // 4
344
+ latents, latent_image_ids = self.prepare_latents(
345
+ batch_size * num_images_per_prompt,
346
+ num_layers,
347
+ num_channels_latents,
348
+ height,
349
+ width,
350
+ validation_box,
351
+ prompt_embeds.dtype,
352
+ device,
353
+ generator,
354
+ latents,
355
+ )
356
+
357
+ # 5. Prepare timesteps
358
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
359
+ image_seq_len = latent_image_ids.shape[0] # total number of packed image tokens across all layer boxes
360
+ mu = calculate_shift(
361
+ image_seq_len,
362
+ self.scheduler.config.base_image_seq_len,
363
+ self.scheduler.config.max_image_seq_len,
364
+ self.scheduler.config.base_shift,
365
+ self.scheduler.config.max_shift,
366
+ )
367
+ timesteps, num_inference_steps = retrieve_timesteps(
368
+ self.scheduler,
369
+ num_inference_steps,
370
+ device,
371
+ timesteps,
372
+ sigmas,
373
+ mu=mu,
374
+ )
375
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
376
+ self._num_timesteps = len(timesteps)
377
+
378
+ # handle guidance
379
+ if self.transformer.config.guidance_embeds:
380
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
381
+ guidance = guidance.expand(latents.shape[0])
382
+ else:
383
+ guidance = None
384
+
385
+ # 6. Denoising loop
386
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
387
+ for i, t in enumerate(timesteps):
388
+ if self.interrupt:
389
+ continue
390
+
391
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
392
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
393
+
394
+ noise_pred = self.transformer(
395
+ hidden_states=latents,
396
+ list_layer_box=validation_box,
397
+ timestep=timestep / 1000,
398
+ guidance=guidance,
399
+ pooled_projections=pooled_prompt_embeds,
400
+ encoder_hidden_states=prompt_embeds,
401
+ txt_ids=text_ids,
402
+ img_ids=latent_image_ids,
403
+ joint_attention_kwargs=self.joint_attention_kwargs,
404
+ return_dict=False,
405
+ )[0]
406
+
407
+ # compute the previous noisy sample x_t -> x_t-1
408
+ latents_dtype = latents.dtype
409
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
410
+
411
+ if latents.dtype != latents_dtype:
412
+ if torch.backends.mps.is_available():
413
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
414
+ latents = latents.to(latents_dtype)
415
+
416
+ if callback_on_step_end is not None:
417
+ callback_kwargs = {}
418
+ for k in callback_on_step_end_tensor_inputs:
419
+ callback_kwargs[k] = locals()[k]
420
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
421
+
422
+ latents = callback_outputs.pop("latents", latents)
423
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
424
+
425
+ # call the callback, if provided
426
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
427
+ progress_bar.update()
428
+
429
+ if XLA_AVAILABLE:
430
+ xm.mark_step()
431
+
432
+ # create a grey latent
433
+ bs, n_frames, channel_latent, height, width = latents.shape
434
+
435
+ pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
436
+ latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
437
+ latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
438
+ latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w]
439
+
440
+ # fill in the latents
441
+ for layer_idx in range(latent_grey.shape[1]):
442
+ x1, y1, x2, y2 = validation_box[layer_idx]
443
+ x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
444
+ latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
445
+ latents = latent_grey
446
+
447
+ if output_type == "latent":
448
+ image = latents
449
+
450
+ else:
451
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
452
+ latents = latents.reshape(bs * n_frames, channel_latent, height, width)
453
+ image = self.vae.decode(latents, return_dict=False)[0]
454
+ if sdxl_vae is not None:
455
+ sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
456
+ sdxl_latents = sdxl_vae.encode(image).latent_dist.sample()
457
+ transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
458
+ result_list, vis_list = transparent_decoder(sdxl_vae, sdxl_latents)
459
+ else:
460
+ result_list, vis_list = None, None
461
+ image = self.image_processor.postprocess(image, output_type=output_type)
462
+
463
+ # Offload all models
464
+ self.maybe_free_model_hooks()
465
+
466
+ if not return_dict:
467
+ return (image, result_list, vis_list)
468
+
469
+ return FluxPipelineOutput(images=image), result_list, vis_list
470
+
471
+
472
+ class CustomFluxPipelineCfg(FluxPipeline):
473
+
474
+ @staticmethod
475
+ def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
476
+
477
+ latent_image_ids_list = []
478
+ for layer_idx in range(len(list_layer_box)):
479
+ if list_layer_box[layer_idx] is None:
480
+ continue
481
+ else:
482
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
483
+ latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
484
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
485
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
486
+
487
+ x1, y1, x2, y2 = list_layer_box[layer_idx]
488
+ x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
489
+ latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
490
+
491
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
492
+ latent_image_ids = latent_image_ids.reshape(
493
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
494
+ )
495
+
496
+ latent_image_ids_list.append(latent_image_ids)
497
+
498
+ full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
499
+
500
+ return full_latent_image_ids.to(device=device, dtype=dtype)
501
+
502
+ def prepare_latents(
503
+ self,
504
+ batch_size,
505
+ num_layers,
506
+ num_channels_latents,
507
+ height,
508
+ width,
509
+ list_layer_box,
510
+ dtype,
511
+ device,
512
+ generator,
513
+ latents=None,
514
+ ):
515
+ height = 2 * (int(height) // self.vae_scale_factor)
516
+ width = 2 * (int(width) // self.vae_scale_factor)
517
+
518
+ shape = (batch_size, num_layers, num_channels_latents, height, width)
519
+
520
+ if latents is not None:
521
+ latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
522
+ return latents.to(device=device, dtype=dtype), latent_image_ids
523
+
524
+ if isinstance(generator, list) and len(generator) != batch_size:
525
+ raise ValueError(
526
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
527
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
528
+ )
529
+
530
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, n_layers, c_latent, h, w]
531
+
532
+ latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
533
+
534
+ return latents, latent_image_ids
535
+
536
+ @torch.no_grad()
537
+ def __call__(
538
+ self,
539
+ prompt: Union[str, List[str]] = None,
540
+ prompt_2: Optional[Union[str, List[str]]] = None,
541
+ validation_box: List[tuple] = None,
542
+ height: Optional[int] = None,
543
+ width: Optional[int] = None,
544
+ num_inference_steps: int = 28,
545
+ timesteps: List[int] = None,
546
+ guidance_scale: float = 3.5,
547
+ true_gs: float = 3.5,
548
+ num_images_per_prompt: Optional[int] = 1,
549
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
550
+ latents: Optional[torch.FloatTensor] = None,
551
+ prompt_embeds: Optional[torch.FloatTensor] = None,
552
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
553
+ output_type: Optional[str] = "pil",
554
+ return_dict: bool = True,
555
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
556
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
557
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
558
+ max_sequence_length: int = 512,
559
+ num_layers: int = 5,
560
+ transparent_decoder: nn.Module = None,
561
+ ):
562
+ r"""
563
+ Function invoked when calling the pipeline for generation.
564
+
565
+ Args:
566
+ prompt (`str` or `List[str]`, *optional*):
567
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
568
+ instead.
569
+ prompt_2 (`str` or `List[str]`, *optional*):
570
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
571
+ will be used instead
572
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
573
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
574
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
575
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
576
+ num_inference_steps (`int`, *optional*, defaults to 28):
577
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
578
+ expense of slower inference.
579
+ timesteps (`List[int]`, *optional*):
580
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
581
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
582
+ passed will be used. Must be in descending order.
583
+ guidance_scale (`float`, *optional*, defaults to 3.5):
584
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
585
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
586
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
587
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
588
+ usually at the expense of lower image quality.
589
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
590
+ The number of images to generate per prompt.
591
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
592
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
593
+ to make generation deterministic.
594
+ latents (`torch.FloatTensor`, *optional*):
595
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
596
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
597
+ tensor will be generated by sampling using the supplied random `generator`.
598
+ prompt_embeds (`torch.FloatTensor`, *optional*):
599
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
600
+ provided, text embeddings will be generated from `prompt` input argument.
601
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
602
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
603
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
604
+ output_type (`str`, *optional*, defaults to `"pil"`):
605
+ The output format of the generated image. Choose between
606
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
607
+ return_dict (`bool`, *optional*, defaults to `True`):
608
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
609
+ joint_attention_kwargs (`dict`, *optional*):
610
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
611
+ `self.processor` in
612
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
613
+ callback_on_step_end (`Callable`, *optional*):
614
+ A callback function invoked at the end of each denoising step during inference. The function is called
615
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
616
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
617
+ `callback_on_step_end_tensor_inputs`.
618
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
619
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
620
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
621
+ `._callback_tensor_inputs` attribute of your pipeline class.
622
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
623
+
624
+ Examples:
625
+
626
+ Returns:
627
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
628
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
629
+ images.
630
+ """
631
+
632
+ height = height or self.default_sample_size * self.vae_scale_factor
633
+ width = width or self.default_sample_size * self.vae_scale_factor
634
+
635
+ # 1. Check inputs. Raise error if not correct
636
+ self.check_inputs(
637
+ prompt,
638
+ prompt_2,
639
+ height,
640
+ width,
641
+ prompt_embeds=prompt_embeds,
642
+ pooled_prompt_embeds=pooled_prompt_embeds,
643
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
644
+ max_sequence_length=max_sequence_length,
645
+ )
646
+
647
+ self._guidance_scale = guidance_scale
648
+ self._joint_attention_kwargs = joint_attention_kwargs
649
+ self._interrupt = False
650
+
651
+ # 2. Define call parameters
652
+ if prompt is not None and isinstance(prompt, str):
653
+ batch_size = 1
654
+ elif prompt is not None and isinstance(prompt, list):
655
+ batch_size = len(prompt)
656
+ else:
657
+ batch_size = prompt_embeds.shape[0]
658
+
659
+ device = self._execution_device
660
+
661
+ lora_scale = (
662
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
663
+ )
664
+ (
665
+ prompt_embeds,
666
+ pooled_prompt_embeds,
667
+ text_ids,
668
+ ) = self.encode_prompt(
669
+ prompt=prompt,
670
+ prompt_2=prompt_2,
671
+ prompt_embeds=prompt_embeds,
672
+ pooled_prompt_embeds=pooled_prompt_embeds,
673
+ device=device,
674
+ num_images_per_prompt=num_images_per_prompt,
675
+ max_sequence_length=max_sequence_length,
676
+ lora_scale=lora_scale,
677
+ )
678
+ (
679
+ neg_prompt_embeds,
680
+ neg_pooled_prompt_embeds,
681
+ neg_text_ids,
682
+ ) = self.encode_prompt(
683
+ prompt="",
684
+ prompt_2=None,
685
+ device=device,
686
+ num_images_per_prompt=num_images_per_prompt,
687
+ max_sequence_length=max_sequence_length,
688
+ lora_scale=lora_scale,
689
+ )
690
+
691
+ # 4. Prepare latent variables
692
+ num_channels_latents = self.transformer.config.in_channels // 4
693
+ latents, latent_image_ids = self.prepare_latents(
694
+ batch_size * num_images_per_prompt,
695
+ num_layers,
696
+ num_channels_latents,
697
+ height,
698
+ width,
699
+ validation_box,
700
+ prompt_embeds.dtype,
701
+ device,
702
+ generator,
703
+ latents,
704
+ )
705
+
706
+ # 5. Prepare timesteps
707
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
708
+ image_seq_len = latent_image_ids.shape[0]
709
+ mu = calculate_shift(
710
+ image_seq_len,
711
+ self.scheduler.config.base_image_seq_len,
712
+ self.scheduler.config.max_image_seq_len,
713
+ self.scheduler.config.base_shift,
714
+ self.scheduler.config.max_shift,
715
+ )
716
+ timesteps, num_inference_steps = retrieve_timesteps(
717
+ self.scheduler,
718
+ num_inference_steps,
719
+ device,
720
+ timesteps,
721
+ sigmas,
722
+ mu=mu,
723
+ )
724
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
725
+ self._num_timesteps = len(timesteps)
726
+
727
+ # handle guidance
728
+ if self.transformer.config.guidance_embeds:
729
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
730
+ guidance = guidance.expand(latents.shape[0])
731
+ else:
732
+ guidance = None
733
+
734
+ # 6. Denoising loop
735
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
736
+ for i, t in enumerate(timesteps):
737
+ if self.interrupt:
738
+ continue
739
+
740
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
741
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
742
+
743
+ noise_pred = self.transformer(
744
+ hidden_states=latents,
745
+ list_layer_box=validation_box,
746
+ timestep=timestep / 1000,
747
+ guidance=guidance,
748
+ pooled_projections=pooled_prompt_embeds,
749
+ encoder_hidden_states=prompt_embeds,
750
+ txt_ids=text_ids,
751
+ img_ids=latent_image_ids,
752
+ joint_attention_kwargs=self.joint_attention_kwargs,
753
+ return_dict=False,
754
+ )[0]
755
+
756
+ neg_noise_pred = self.transformer(
757
+ hidden_states=latents,
758
+ list_layer_box=validation_box,
759
+ timestep=timestep / 1000,
760
+ guidance=guidance,
761
+ pooled_projections=neg_pooled_prompt_embeds,
762
+ encoder_hidden_states=neg_prompt_embeds,
763
+ txt_ids=neg_text_ids,
764
+ img_ids=latent_image_ids,
765
+ joint_attention_kwargs=self.joint_attention_kwargs,
766
+ return_dict=False,
767
+ )[0]
768
+
769
+ noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)
770
+
771
+ # compute the previous noisy sample x_t -> x_t-1
772
+ latents_dtype = latents.dtype
773
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
774
+
775
+ if latents.dtype != latents_dtype:
776
+ if torch.backends.mps.is_available():
777
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
778
+ latents = latents.to(latents_dtype)
779
+
780
+ if callback_on_step_end is not None:
781
+ callback_kwargs = {}
782
+ for k in callback_on_step_end_tensor_inputs:
783
+ callback_kwargs[k] = locals()[k]
784
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
785
+
786
+ latents = callback_outputs.pop("latents", latents)
787
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
788
+
789
+ # call the callback, if provided
790
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
791
+ progress_bar.update()
792
+
793
+ if XLA_AVAILABLE:
794
+ xm.mark_step()
795
+
796
+ # create a grey latent
797
+ bs, n_layers, channel_latent, height, width = latents.shape
798
+
799
+ pixel_grey = torch.zeros(size=(bs*n_layers, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
800
+ latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
801
+ latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
802
+ latent_grey = latent_grey.view(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
803
+
804
+ # fill in the latents
805
+ for layer_idx in range(latent_grey.shape[1]):
806
+ if validation_box[layer_idx] is None:
807
+ continue
808
+ x1, y1, x2, y2 = validation_box[layer_idx]
809
+ x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
810
+ latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
811
+ latents = latent_grey
812
+
813
+ if output_type == "latent":
814
+ image = latents
815
+
816
+ else:
817
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
818
+ latents = latents.reshape(bs * n_layers, channel_latent, height, width)
819
+ latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
820
+ image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
821
+ image = torch.cat(image_segs, dim=0)
822
+ if transparent_decoder is not None:
823
+ transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
824
+
825
+ decoded_fg, decoded_alpha = transparent_decoder(latents, [validation_box])
826
+ decoded_alpha = (decoded_alpha + 1.0) / 2.0
827
+ decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1)
828
+
829
+ decoded_fg = (decoded_fg + 1.0) / 2.0
830
+ decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)
831
+
832
+ vis_list = None
833
+ png = torch.cat([decoded_fg, decoded_alpha], dim=3)
834
+ result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
835
+ else:
836
+ result_list, vis_list = None, None
837
+ image = self.image_processor.postprocess(image, output_type=output_type)
838
+
839
+ # Offload all models
840
+ self.maybe_free_model_hooks()
841
+
842
+ if not return_dict:
843
+ return (image, result_list, vis_list, latents)
844
+
845
+ return FluxPipelineOutput(images=image), result_list, vis_list, latents
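Both pipelines extend FluxPipeline with one latent plane per layer: `validation_box` carries one `(x1, y1, x2, y2)` pixel box per layer (or `None` to skip a layer), and the boxes are divided by 16 when building the packed latent token ids, so multiples of 16 work best. A minimal invocation sketch follows; the checkpoint name, prompt, and box values are illustrative assumptions, and the loaded transformer is assumed to be a layer-aware variant that accepts the extra `list_layer_box` argument used above.

# Hedged usage sketch for CustomFluxPipelineCfg; checkpoint and boxes below are placeholders.
import torch

pipeline = CustomFluxPipelineCfg.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base weights; the Space swaps in its own layer-aware transformer
    torch_dtype=torch.bfloat16,
).to("cuda")

# One pixel box per layer (x1, y1, x2, y2); None skips that layer entirely.
validation_box = [(0, 0, 512, 512), (0, 0, 512, 512), (64, 64, 448, 192), None, None]

output, result_list, vis_list, latents = pipeline(
    prompt="a poster with a bold headline over a sunset background",
    validation_box=validation_box,
    height=512,
    width=512,
    num_inference_steps=28,
    guidance_scale=3.5,
    true_gs=3.5,
    num_layers=len(validation_box),
    generator=torch.Generator("cuda").manual_seed(0),
)
layer_images = output.images  # one decoded image per layer; result_list is populated only when a transparent decoder is passed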
modeling_crello.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM
3
+ # from transformers import BitsAndBytesConfig
4
+ from torch import nn
5
+ import os
6
+ from typing import Optional, List
7
+ import os
8
+
9
+ def kmp_preprocess(pattern):
10
+ pattern_len = len(pattern)
11
+ prefix_suffix = [0] * pattern_len
12
+ j = 0
13
+
14
+ for i in range(1, pattern_len):
15
+ while j > 0 and pattern[i] != pattern[j]:
16
+ j = prefix_suffix[j - 1]
17
+
18
+ if pattern[i] == pattern[j]:
19
+ j += 1
20
+
21
+ prefix_suffix[i] = j
22
+
23
+ return prefix_suffix
24
+
25
+ def kmp_search(text, pattern):
26
+ text_len = len(text)
27
+ pattern_len = len(pattern)
28
+ prefix_suffix = kmp_preprocess(pattern)
29
+ matches = []
30
+
31
+ j = 0
32
+ for i in range(text_len):
33
+ while j > 0 and text[i] != pattern[j]:
34
+ j = prefix_suffix[j - 1]
35
+
36
+ if text[i] == pattern[j]:
37
+ j += 1
38
+
39
+ if j == pattern_len:
40
+ matches.append(i - j + 1)
41
+ j = prefix_suffix[j - 1]
42
+
43
+ return matches
44
+
45
+ class ModelWrapper:
46
+ def __init__(self, model):
47
+ self.model = model
48
+
49
+ def __getattr__(self, name):
50
+ return getattr(self.model, name)
51
+
52
+ @torch.no_grad()
53
+ def __call__(self, pixel_values):
54
+ return self.model(pixel_values)
55
+
56
+ def eval(self):
57
+ pass
58
+
59
+ def train(self):
60
+ pass
61
+
62
+
63
+ def parameters(self):
64
+ return self.model.parameters()
65
+
66
+
67
+ class CrelloModelConfig(PretrainedConfig):
68
+ def __init__(
69
+ self,
70
+ old_vocab_size: int = 32000,
71
+ vocab_size: int = 32000,
72
+ pad_token_id: int = 2,
73
+ ignore_ids: List[int] = [],
74
+
75
+ freeze_lm: bool = True, # lm.eval()
76
+ opt_version: str = 'facebook/opt-6.7b',
77
+
78
+ task: str = 'captioning',
79
+
80
+ use_lora: bool = False,
81
+ lora_alpha: int = 32,
82
+ lora_r: int = 8,
83
+ lora_dropout: float = 0.05,
84
+ lora_target_modules: str = r'.*\.(q_proj|v_proj)',
85
+
86
+ hidden_size: int = -1,
87
+ load_in_4bit: Optional[bool] = False,
88
+
89
+ **kwargs,
90
+ ):
91
+ super().__init__(**kwargs)
92
+ assert old_vocab_size > 0, 'old_vocab_size must be positive'
93
+ assert vocab_size > 0, 'vocab_size must be positive'
94
+
95
+ self.old_vocab_size = old_vocab_size
96
+ self.vocab_size = vocab_size
97
+ self.pad_token_id = pad_token_id
98
+ self.freeze_lm = freeze_lm
99
+ self.opt_version = opt_version
100
+ self.task = task
101
+ self.use_lora = use_lora
102
+ self.lora_alpha = lora_alpha
103
+ self.lora_r = lora_r
104
+ self.lora_dropout = lora_dropout
105
+ self.lora_target_modules = lora_target_modules
106
+ self.hidden_size = hidden_size
107
+ self.load_in_4bit = load_in_4bit
108
+ self.ignore_ids = ignore_ids
109
+
110
+
111
+ class CrelloModel(PreTrainedModel):
112
+ config_class = CrelloModelConfig
113
+ supports_gradient_checkpointing = True
114
+
115
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
116
+ self.lm.gradient_checkpointing_enable()
117
+
118
+ def __init__(self, config: CrelloModelConfig): # explicitly annotate the config type
119
+ super().__init__(config)
120
+
121
+ self.pad_token_id = config.pad_token_id
122
+
123
+ self.args = config
124
+
125
+ opt_version = "WYBar/LLM_For_Layout_Planning"
126
+
127
+ print(f"Using {opt_version} for the language model.")
128
+
129
+ if 'facebook/opt' in opt_version:
130
+ self.lm = OPTForCausalLM.from_pretrained(opt_version)
131
+ word_embed_proj_dim = self.lm.config.word_embed_proj_dim
132
+ else:
133
+ if config.load_in_4bit:
134
+ print("\n would load_in_4bit")
135
+ quantization_config = None
136
+ # This means: fit the entire model on the GPU:0
137
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
138
+ device_map = {"": local_rank}
139
+ torch_dtype = torch.bfloat16
140
+ else:
141
+ print("\n wouldn't load_in_4bit")
142
+ quantization_config = None
143
+ device_map = None
144
+ torch_dtype = None
145
+
146
+ self.lm = AutoModelForCausalLM.from_pretrained(
147
+ "WYBar/LLM_For_Layout_Planning",
148
+ subfolder="Meta-Llama-3-8B",
149
+ # use_auth_token=use_auth_token,
150
+ # quantization_config=quantization_config,
151
+ # device_map=device_map,
152
+ trust_remote_code=True,
153
+ torch_dtype=torch.bfloat16,
154
+ # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
155
+ )
156
+ word_embed_proj_dim = self.lm.config.hidden_size
157
+ self.config.hidden_size = self.lm.config.hidden_size
158
+ self.opt_version = opt_version
159
+
160
+ if self.args.freeze_lm:
161
+ self.lm.eval()
162
+ print("Freezing the LM.")
163
+ # for param in self.lm.parameters():
164
+ # param.requires_grad = False
165
+ else:
166
+ print("\n no freeze lm, so to train lm")
167
+ self.lm.train()
168
+ self.lm.config.gradient_checkpointing = True
169
+
170
+ # print('resize token embeddings to match the tokenizer', config.vocab_size)
171
+ # self.lm.resize_token_embeddings(config.vocab_size)
172
+ # self.input_embeddings = self.lm.get_input_embeddings()
173
+ # print('after token embeddings to match the tokenizer', config.vocab_size)
174
+
175
+ def train(self, mode=True):
176
+ super().train(mode=mode)
177
+ # Overwrite train() to ensure frozen models remain frozen.
178
+ if self.args.freeze_lm:
179
+ self.lm.eval()
180
+
181
+ def forward(
182
+ self,
183
+ labels: torch.LongTensor,
184
+ ):
185
+ batch_size = labels.shape[0]
186
+ full_labels = labels.detach().clone()
187
+
188
+ input_embs = self.input_embeddings(labels) # (N, T, D)
189
+ input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
190
+
191
+ for ignore_id in self.config.ignore_ids:
192
+ full_labels[full_labels == ignore_id] = -100
193
+
194
+ pad_idx = []
195
+ # for each example in the batch, record its effective sequence length (max_len or the first padding position) in pad_idx
196
+ # -100 is the ignore index for cross entropy loss. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
197
+ for label in full_labels:
198
+ for k, token in enumerate(label):
199
+ # Mask out pad tokens if they exist.
200
+ if token in [self.pad_token_id]:
201
+ label[k:] = -100 # mask out this token and everything after it
202
+ pad_idx.append(k)
203
+ break
204
+ if k == len(label) - 1: # No padding found.
205
+ pad_idx.append(k + 1)
206
+ assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
207
+
208
+ output = self.lm( inputs_embeds=input_embs,
209
+ # input_ids=labels,
210
+ labels=full_labels,
211
+ output_hidden_states=True)
212
+
213
+ return output, full_labels, input_embs_norm
214
+
215
+ if __name__=="__main__":
216
+ config = CrelloModelConfig(
217
+ vocab_size=50265,
218
+ image_reg_token=50264,
219
+ image_gt_token=50263,
220
+ )
221
+ print("config: ",config)
222
+ model1 = CrelloModel(config)
223
+ print("\nmodel1: ",model1)
224
+ model1.save_pretrained('test')
225
+ model2 = CrelloModel.from_pretrained('test')
226
+ print("\nmodel2: ",model2)
227
+ # compare model1 and model2
228
+
229
+ state_dict1 = model1.state_dict()
230
+ state_dict2 = model2.state_dict()
231
+ assert set(state_dict1.keys()) == set(state_dict2.keys())
232
+ for k in state_dict1.keys():
233
+ assert torch.equal(state_dict1[k], state_dict2[k])
234
+ print('all parameters are equal')
235
+
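`CrelloModel.forward` embeds the token ids and builds `full_labels` by writing `-100` (the cross-entropy ignore index) over every id listed in `config.ignore_ids` and over everything from the first pad token onward. A small self-contained sketch of that masking rule, with made-up token ids purely for illustration:

# Sketch of the label masking performed in CrelloModel.forward (token ids below are made up).
import torch

pad_token_id = 2
ignore_ids = [7]

labels = torch.tensor([[5, 7, 9, 2, 2],    # padding starts at position 3
                       [5, 9, 9, 9, 9]])   # no padding
full_labels = labels.clone()

for ignore_id in ignore_ids:
    full_labels[full_labels == ignore_id] = -100

pad_idx = []
for label in full_labels:
    for k, token in enumerate(label):
        if token == pad_token_id:
            label[k:] = -100           # mask the pad token and everything after it
            pad_idx.append(k)
            break
        if k == len(label) - 1:        # no padding found
            pad_idx.append(k + 1)

print(full_labels)  # tensor([[   5, -100,    9, -100, -100],
                    #         [   5,    9,    9,    9,    9]])
print(pad_idx)      # [3, 5]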
quantizer.py ADDED
@@ -0,0 +1,552 @@
1
+ import torch
2
+ import numpy as np
3
+ import copy
4
+ from collections import OrderedDict
5
+ import json
6
+ from datasets import ClassLabel
7
+ import random
8
+ import math
9
+ from functools import lru_cache
10
+ from matplotlib import font_manager
11
+ from colorama import Fore, Style, init
12
+
13
+
14
+ class BaseQuantizer:
15
+ @property
16
+ def ignore_tokens(self):
17
+ if self.num_mask_tokens > 0:
18
+ if self.mask_type == 'cm3':
19
+ return [self.predict_start_token] + self.mask_tokens
20
+ elif self.mask_type == 'mask_aug':
21
+ return [self.mask_aug_token]
22
+ else:
23
+ raise ValueError(f'Invalid mask type {self.mask_type}')
24
+ else:
25
+ return []
26
+
27
+ def __init__(self, simplify_json=False, mask_all=False,
28
+ num_mask_tokens=0, mask_type='cm3', **kwargs):
29
+ self.simplify_json=simplify_json
30
+ self.io_ignore_replace_tokens = ['<split-text>']
31
+ self.mask_all = mask_all
32
+ self.num_mask_tokens = num_mask_tokens
33
+ self.mask_type = mask_type
34
+ if self.mask_type == 'mask_aug':
35
+ self.mask_aug_token = '<mask-aug>'
36
+ elif self.mask_type == 'cm3':
37
+ self.predict_start_token = '<pred-start>'
38
+ else:
39
+ raise ValueError(f'Invalid mask type {self.mask_type}')
40
+
41
+ def get_additional_mask_tokens(self):
42
+ if self.mask_type == 'cm3': # two configurations: 1. ['<pred-start>'] plus '<mask-%d>' tokens, one per self.num_mask_tokens; 2. ['<mask-aug>']
43
+ self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
44
+ return [self.predict_start_token] + self.mask_tokens
45
+ elif self.mask_type == 'mask_aug':
46
+ return [self.mask_aug_token]
47
+ else:
48
+ raise ValueError(f'Invalid mask type {self.mask_type}')
49
+
50
+ def dump2json(self, json_example):
51
+ if self.simplify_json: # serialize the dict to a str; with simplify_json, drop extra spaces/newlines and strip the double quotes around special tokens
52
+ content = json.dumps(json_example, separators=(',',':'))
53
+ for token in self.additional_special_tokens:
54
+ content = content.replace(f'"{token}"', token)
55
+ else:
56
+ content = json.dumps(json_example)
57
+ return content
58
+
59
+ def load_json(self, content): # parse the str back into JSON
60
+ replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens) # sirui change
61
+ if self.simplify_json:
62
+ for token in replace_tokens: # with simplify_json, re-wrap each special token in double quotes
63
+ content = content.replace(token, f'"{token}"')
64
+ return json.loads(content)
65
+
66
+ def apply_masking(self,
67
+ json_example,
68
+ mask_all=None,
69
+ return_meta=False,
70
+ target_keys=['width', 'height', 'left', 'top'],
71
+ target_element_types=None
72
+ ):
73
+ if mask_all is None:
74
+ mask_all = self.mask_all
75
+ json_example = copy.deepcopy(json_example)
76
+ target_keys = set(target_keys)
77
+ target_tokens = []
78
+ for shape_i, shape in enumerate(json_example['layers']['textlayer']):
79
+ # element_type = self.general_dequantize(shape['type'],'type',to_float=False)
80
+ # if target_element_types is not None:
81
+ # if element_type not in target_element_types:
82
+ # continue
83
+ for key_i, key in enumerate(shape.keys()):
84
+ if key in target_keys:
85
+ target_tokens.append((shape_i, key_i, key, shape[key]))
86
+ if not mask_all:
87
+ target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
88
+ if len(target_tokens) > target_num_mask_tokens:
89
+ random.shuffle(target_tokens)
90
+ target_tokens = target_tokens[:target_num_mask_tokens]
91
+ # sort by shape_i and key_i
92
+ target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
93
+ else:
94
+ if len(target_tokens) > self.num_mask_tokens:
95
+ # keep only the last num_mask_tokens entries
96
+ target_tokens = target_tokens[-self.num_mask_tokens:]
97
+
98
+ tuples = []
99
+ meta_infos = []
100
+ for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
101
+ if self.mask_type == 'cm3':
102
+ mask_token = self.mask_tokens[mask_i]
103
+ elif self.mask_type == 'mask_aug':
104
+ mask_token = self.mask_aug_token
105
+ else:
106
+ raise ValueError(f'Invalid mask type {self.mask_type}')
107
+ # <one-1><decimal0-1><decimal1-2>
108
+ if '<' in value:
109
+ num_token = value.count('<')
110
+ else:
111
+ num_token = value.count(' ')
112
+ json_example['layers']['textlayer'][shape_i][key] = mask_token
113
+ tuples.append((mask_token, value, num_token))
114
+ meta_infos.append((shape_i,key))
115
+ if return_meta:
116
+ return json_example, tuples, meta_infos
117
+ else:
118
+ return json_example, tuples
119
+
120
+ def make_prediction_postfix(self, tuples):
121
+ postfix = self.predict_start_token
122
+ for mask_token, value, num_token in tuples:
123
+ postfix = postfix+ f'{mask_token}{value}'
124
+ return postfix
125
+
126
+ # specs={
127
+ # "width":"size",
128
+ # "height":"size",
129
+ # "left":"pos",
130
+ # "top":"pos",
131
+ # "x":"pos", # center x
132
+ # "y":"pos", # center y
133
+ # "opacity":"opacity",
134
+ # "color":"color",
135
+ # "angle":"angle",
136
+ # "font_size":"font_size",
137
+ # 'ratio':'ratio',
138
+ # 'letter_spacing': 'spacing',
139
+ # 'textlen': 'textlen'
140
+ # }
141
+
142
+ specs={
143
+ "width":"size",
144
+ "height":"size",
145
+ "x":"pos", # center x
146
+ "y":"pos", # center y
147
+ "color":"color",
148
+ "font":"font"
149
+ }
150
+
151
+ # TODO change min_max_bins
152
+ # min_max_bins = {
153
+ # 'size':(0,2,256),
154
+ # 'pos':(-1,1,256),
155
+ # # 'opacity':(0,1,8),
156
+ # 'opacity':(0,255,8),
157
+ # 'color':(0,255,32),
158
+ # 'angle':(0,2*np.pi,64),
159
+ # 'font_size':(2,200,100),
160
+ # 'spacing': (0,1,40),
161
+ # 'textlen': (1,20,20)
162
+ # }
163
+ min_max_bins = {
164
+ 'size': (0,1,256),
165
+ 'pos': (0,1,256),
166
+ 'color': (0,137,138),
167
+ 'font': (0,511,512)
168
+ }
169
+
170
+ import numpy as np
171
+
172
+ # pre and post refer to powers of ten for the integer and fractional parts; the arguments give the number of digits in each
173
+ def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
174
+ pre_keys = ['one', 'ten', 'hundred', 'thousand']
175
+ pre_multiplers = [1, 10, 100, 1000]
176
+ assert pre_decimal <= len(pre_keys)
177
+ pre_keys = pre_keys[:pre_decimal][::-1]
178
+ pre_multiplers = pre_multiplers[:pre_decimal][::-1]
179
+
180
+ post_keys = [f'decimal{x}' for x in range(post_decimal)]
181
+ post_multiplers = [10 ** -(x+1) for x in range(post_decimal)]
182
+
183
+ keys = pre_keys + post_keys
184
+ multiplers = pre_multiplers + post_multiplers
185
+ return keys, multiplers
186
+
187
+ class DecimalQuantizer:
188
+ def __init__(self, max_pre_decimal=3, max_post_decimal=2):
189
+ self.max_pre_decimal = max_pre_decimal
190
+ self.max_post_decimal = max_post_decimal
191
+ self.keys, self.multiplers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
192
+ self.symbols = {
193
+ -1: '<symbol-1>',
194
+ 1: '<symbol-0>',
195
+ }
196
+
197
+ def get_vocab(self):
198
+ special_tokens = [*self.symbols.values()] # ['<symbol-1>', '<symbol-0>']
199
+ for key in self.keys: # ['one', 'ten', 'hundred', 'thousand'] + ['decimal0', 'decimal1]
200
+ special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
201
+ return special_tokens
202
+
203
+ def check_valid(self, token):
204
+ prefix = token.lstrip('<').split('-')[0] # '<symbol-1>' -> 'symbol-1>' -> ['symbol', '1>']
205
+ if prefix =='symbol' or prefix in self.keys:
206
+ return True
207
+ else:
208
+ return False
209
+
210
+ # keep two digits after the decimal point
211
+ def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False): # 100.00
212
+ if pre_decimal is None:
213
+ pre_decimal = self.max_pre_decimal
214
+ if post_decimal is None:
215
+ post_decimal = self.max_post_decimal
216
+
217
+ assert pre_decimal <= self.max_pre_decimal
218
+ assert post_decimal <= self.max_post_decimal
219
+
220
+ keys, multiplers = get_keys_and_multipliers(pre_decimal, post_decimal)
221
+
222
+ symbol = int(np.sign(val)) # np.sign returns a float (1.0, -1.0 or 0.0) for positive, negative and zero
223
+ if symbol == 0: # collapse to two classes: >= 0 and < 0
224
+ symbol = 1
225
+ val = round(abs(val), post_decimal) # round the absolute value of val to post_decimal decimal places
226
+
227
+ tokens = []
228
+ if need_symbol: # self.symbols = {-1: '<symbol-1>', 1: '<symbol-0>',}
229
+ symbol_type = self.symbols[symbol]
230
+ tokens.append(symbol_type)
231
+ else:
232
+ assert symbol >= 0
233
+
234
+ for key, multipler in zip(keys, multiplers):
235
+ # extract each digit of the given value val and emit a token such as '<one-7>'
236
+ v = math.floor(val / multipler)
237
+ if v > 9:
238
+ raise ValueError(f'Invalid value {val} for {pre_decimal} pre_decimal and {post_decimal} post_decimal')
239
+ val = val - v * multipler
240
+ tokens.append(f'<{key}-{v}>')
241
+
242
+ # emit one token per digit of val; if need_symbol is True, a leading symbol-0 / symbol-1 marks >= 0 vs < 0
243
+ return ''.join(tokens)
244
+
245
+ def parse_token(self, token):
246
+ # <hundred-1> -> hundred, 1
247
+ key, val = token[1:-1].split('-')
248
+ return key, int(val)
249
+
250
+ def decode(self, tokens_str): # split tokens_str on '>', re-append the '>', and turn it into a list of tokens
251
+ tokens = tokens_str.split('>')
252
+ tokens = [x+'>' for x in tokens if x != '']
253
+ if tokens[0].startswith('<symbol'):
254
+ symbol_type = tokens[0]
255
+ tokens = tokens[1:]
256
+ inv_map = {v: k for k, v in self.symbols.items()} # invert the original dict (swap keys and values)
257
+ symbol = inv_map[symbol_type]
258
+ else:
259
+ symbol = 1
260
+
261
+ accumulater = 0
262
+ for token in tokens:
263
+ key, val = self.parse_token(token)
264
+ multipler_index = self.keys.index(key)
265
+ multipler = self.multiplers[multipler_index]
266
+ actual_val = val * multipler
267
+ # print(key, val, multipler, actual_val)
268
+ accumulater += actual_val
269
+ accumulater = accumulater * symbol
270
+
271
+ # reconstruct the original signed value, with precision controlled by the pre/post_decimal digit counts
272
+ return accumulater
273
+
274
+ # min_max_bins = {
275
+ # 'size': (0,1,256),
276
+ # 'pos': (0,1,256),
277
+ # 'color': (0,137,138),
278
+ # 'font': (0,511,512)
279
+ # }
280
+ pre_post_decimals={
281
+ 'size': {
282
+ 'pre_decimal': 1,
283
+ 'post_decimal': 2,
284
+ 'need_symbol': False
285
+ },
286
+ 'pos': {
287
+ 'pre_decimal': 1,
288
+ 'post_decimal': 2,
289
+ 'need_symbol': True
290
+ },
291
+ 'opacity': {
292
+ 'pre_decimal': 1,
293
+ 'post_decimal': 1,
294
+ 'need_symbol': False
295
+ },
296
+ 'color':{
297
+ 'pre_decimal': 3,
298
+ 'post_decimal': 0,
299
+ 'need_symbol': False
300
+ },
301
+ 'angle':{
302
+ 'pre_decimal': 1,
303
+ 'post_decimal': 2,
304
+ 'need_symbol': False
305
+ },
306
+ 'font_size':{
307
+ 'pre_decimal': 3,
308
+ 'post_decimal': 0,
309
+ 'need_symbol': False
310
+ },
311
+ }
312
+
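A quick round-trip sketch of the `DecimalQuantizer` defined above (illustrative only; the inputs are chosen so the digit extraction is not affected by floating-point rounding, and the token strings assume the digit settings shown):

# Round-trip sketch for DecimalQuantizer with the default digit settings.
q = DecimalQuantizer(max_pre_decimal=3, max_post_decimal=2)

tokens = q(137.5)
print(tokens)            # '<hundred-1><ten-3><one-7><decimal0-5><decimal1-0>'
print(q.decode(tokens))  # 137.5

# With need_symbol=True a leading <symbol-*> token records the sign.
signed = q(-0.5, pre_decimal=1, post_decimal=2, need_symbol=True)
print(signed)            # '<symbol-1><one-0><decimal0-5><decimal1-0>'
print(q.decode(signed))  # -0.5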
313
+ class QuantizerV4(BaseQuantizer):
314
+ def __init__(self, quant=True,
315
+ decimal_quantize_types = [],
316
+ decimal_quantize_kwargs = {'max_pre_decimal':3, 'max_post_decimal':2},
317
+ mask_values=False,
318
+ **kwargs):
319
+ super().__init__(**kwargs)
320
+ self.min = min
321
+ self.max = max
322
+ self.quant = quant
323
+ self.mask_values = mask_values
324
+ self.text_split_token = '<split-text>'
325
+ self.decimal_quantize_types = decimal_quantize_types
326
+ self.decimal_quantize = len(decimal_quantize_types) > 0
327
+ if len(decimal_quantize_types) > 0:
328
+ print('decimal quantize types', decimal_quantize_types)
329
+ self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
330
+ else:
331
+ self.decimal_quantizer = None
332
+
333
+ self.set_min_max_bins(min_max_bins)
334
+ # min_max_bins = {
335
+ # 'size': (0,1,256),
336
+ # 'pos': (0,1,256),
337
+ # 'color': (0,137,138),
338
+ # 'font': (0,511,512)
339
+ # }
340
+ self.width = kwargs.get('width', 1456)
341
+ self.height = kwargs.get('height', 1457)
342
+ self.width = int(self.width)
343
+ self.height = int(self.height)
344
+
345
+ def set_min_max_bins(self, min_max_bins): # check that n_bins is even, then add 1 to it
346
+ min_max_bins = copy.deepcopy(min_max_bins)
347
+ # adjust the bins to plus one
348
+ for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
349
+ assert n_bins % 2 == 0 # must be even
350
+ min_max_bins[type_name] = (min_val, max_val, n_bins+1)
351
+ self.min_max_bins = min_max_bins
352
+
353
+ def setup_tokenizer(self, tokenizer):
354
+ # this function builds additional_special_tokens: 1. '<split-text>' 2. decimal-quantizer tokens such as <one-1>, <symbol-1> 3. QuantizerV4 bin tokens such as <size-255> 4. self.get_additional_mask_tokens()
355
+ # and then registers them via tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
356
+ additional_special_tokens = [self.text_split_token] # self.text_split_token = '<split-text>'
357
+ if self.decimal_quantize:
358
+ special_tokens = self.decimal_quantizer.get_vocab() # <one-1> <symbol-1>
359
+ self.io_ignore_replace_tokens += special_tokens # io_ignore_replace_tokens is initialized to ['<split-text>'] in BaseQuantizer
360
+ additional_special_tokens += special_tokens
361
+ # the order must be preserved, other wise the tokenizer will be wrong
362
+ rest_types = [key for key in self.min_max_bins.keys() if key not in self.decimal_quantize_types]
363
+ for type_name in rest_types:
364
+ min_val, max_val, n_bins = self.min_max_bins[type_name]
365
+ additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)] # <size-256>
366
+
367
+ if self.num_mask_tokens > 0:
368
+ additional_special_tokens.extend(self.get_additional_mask_tokens())
369
+
370
+ print('additional_special_tokens', additional_special_tokens)
371
+
372
+ tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
373
+ self.additional_special_tokens = set(additional_special_tokens)
374
+ return tokenizer
375
+
376
+ @lru_cache(maxsize=128) # cache return values for speed; maxsize=128 keeps results for up to 128 distinct inputs
377
+ def get_bins(self, real_type): # real_type: size, pos, font, color
378
+ # return the minimum, the maximum, and an evenly spaced array of bin edges
379
+ min_val, max_val, n_bins = self.min_max_bins[real_type]
380
+ return min_val, max_val, np.linspace(min_val, max_val, n_bins)
381
+
382
+ def quantize(self, x, type): # (0.25, 'y') -> (<size-50>)
383
+ if not self.quant:
384
+ return x
385
+ """Quantize a float array x into n_bins discrete values."""
386
+ real_type = specs[type] # x, y, width, height, color, font -> size, pos, font, color
387
+ min_val, max_val, bins = self.get_bins(real_type)
388
+ x = np.clip(float(x), min_val, max_val) # clamp x into [min_val, max_val]
389
+ if self.decimal_quantize and real_type in self.decimal_quantize_types:
390
+ return self.decimal_quantizer(x, **pre_post_decimals[real_type])
391
+ val = np.digitize(x, bins) - 1 # val is an integer index into the bins array
392
+ n_bins = len(bins)
393
+ assert val >= 0 and val < n_bins
394
+ return f'<{real_type}-{val}>' # <size-255>
395
+
396
+ def dequantize(self, x): # (<size-255> -> 0.99?)
397
+ # <pos-1>->1
398
+ val = x.split('-')[1].strip('>')
399
+ # <pos-1>->pos
400
+ real_type = x.split('-')[0][1:]
401
+ if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
402
+ return self.decimal_quantizer.decode(x)
403
+ min_val, max_val, bins = self.get_bins(real_type)
404
+ return bins[int(val)]
405
+
406
+ def construct_map_dict(self):
407
+ map_dict = {}
408
+ for i in range(self.min_max_bins['size'][2]): # 'size': (0, 1, 256),
409
+ name = "<size-%d>" % i
410
+ value = self.dequantize(name)
411
+ map_dict[name] = str(value) # 255 -> 0.99?
412
+ for i in range(self.min_max_bins['pos'][2]):
413
+ name = "<pos-%d>" % i
414
+ value = self.dequantize(name)
415
+ map_dict[name] = str(value)
416
+ return map_dict
417
+
418
+ def postprocess_colorandfont(self, json_example):
419
+ # wrap the regex-matched font/color tokens in double quotes
420
+ import re
421
+ json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
422
+ json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
423
+ return json_example
424
+
425
+ def to_str(self, x, type):
426
+ feature = self.get_feature(type)
427
+ return feature.int2str(x)
428
+
429
+ def convert2layout(self, example): # convert raw layout values into quantized tokens such as <size-255>
430
+ new_example = OrderedDict()
431
+ new_example['wholecaption'] = example['wholecaption']
432
+ new_layout = []
433
+ for meta_layer in example['layout']:
434
+ new_layout.append({
435
+ "layer": meta_layer["layer"],
436
+ "x": self.quantize(meta_layer["x"]/self.width, 'x'),
437
+ "y": self.quantize(meta_layer["y"]/self.height, 'y'),
438
+ "width": self.quantize(meta_layer["width"]/self.width, 'width'),
439
+ "height": self.quantize(meta_layer["height"]/self.height, 'height')
440
+ })
441
+ new_example['layout'] = new_layout
442
+ return new_example
443
+
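A hedged sketch of the expected input/output shape of convert2layout; `quantizer` denotes an already constructed instance, and the concrete numbers assume a 512x512 canvas with 256 bins per axis, neither of which is read from this file:

example = {
    "wholecaption": "a summer sale poster",
    "layout": [
        {"layer": "text_0", "x": 128, "y": 64, "width": 256, "height": 96},
    ],
}
# with quantizer.width == quantizer.height == 512 and 256 bins per axis:
quantized = quantizer.convert2layout(example)
# quantized['layout'][0] ≈ {'layer': 'text_0', 'x': '<pos-63>', 'y': '<pos-31>',
#                           'width': '<size-127>', 'height': '<size-47>'}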
+     def apply_masking(self,
+                       json_example,
+                       mask_all=None,
+                       return_meta=False,
+                       # target_keys=['width', 'height', 'left', 'top'],  # unused
+                       # target_element_types=None,  # unused
+                       mask_values=True
+                       ):
+         if mask_all is None:
+             mask_all = self.mask_all
+
+         json_example = copy.deepcopy(json_example)
+
+         # this block replaces selected values in the json with <mask-i> tokens, capping the number of masks
+         # at self.num_mask_tokens and, depending on the arguments, choosing them randomly;
+         # it records (mask_token, value, num_token) triples, where num_token = value.count('<')
+         target_tokens = []
+         if self.mask_values and mask_values:
+             target_tokens.append((-1, -1, 'globalcaption', json_example['globalcaption']))
+             target_tokens.append((-1, -1, 'canvas_width', json_example['canvas_width']))
+             target_tokens.append((-1, -1, 'canvas_height', json_example['canvas_height']))
+             target_tokens.append((-1, -1, 'category', json_example['category']))
+             target_tokens.append((-1, -1, 'keywords', json_example['keywords']))
+             target_tokens.append((-1, -1, 'bgcaption', json_example['layers']['bglayer']['bgcaption']))
+             target_tokens.append((-1, -1, 'flag', json_example['layers']['objlayer']['flag']))
+             target_tokens.append((-1, -1, 'objcaption', json_example['layers']['objlayer']['objcaption']))
+         for layer_i, textlayer in enumerate(json_example['layers']['textlayer']):
+             target_tokens.append((layer_i, -1, 'text', json_example['layers']['textlayer'][textlayer]))
+         if not mask_all:  # draw target_num_mask_tokens at random, upper-bounded by self.num_mask_tokens
+             target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
+             if len(target_tokens) > target_num_mask_tokens:
+                 random.shuffle(target_tokens)
+                 target_tokens = target_tokens[:target_num_mask_tokens]
+                 # sort by shape_i and key_i
+                 target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
+         else:  # use a fixed number, num_mask_tokens
+             if len(target_tokens) > self.num_mask_tokens:
+                 # keep the last few entries
+                 target_tokens = target_tokens[-self.num_mask_tokens:]
+
+         tuples = []
+         meta_infos = []
+         layer_list = ['heading', 'subheading', 'body']
+         for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
+             if self.mask_type == 'cm3':
+                 mask_token = self.mask_tokens[mask_i]
+             elif self.mask_type == 'mask_aug':
+                 mask_token = self.mask_aug_token
+             else:
+                 raise ValueError(f'Invalid mask type {self.mask_type}')
+             # a value may itself be a token sequence such as <one-1><decimal0-1><decimal1-2>
+             if '<' in value:
+                 num_token = value.count('<')
+             else:
+                 num_token = value.count(' ') + 1
+             if shape_i == -1:
+                 if key in ['bgcaption']:
+                     json_example['layers']['bglayer']['bgcaption'] = mask_token
+                 elif key in ['objcaption']:
+                     json_example['layers']['objlayer']['objcaption'] = mask_token
+                 elif key in ['flag']:
+                     json_example['layers']['objlayer']['flag'] = mask_token
+                 else:
+                     json_example[key] = mask_token
+             else:
+                 curlayer = layer_list[shape_i]
+                 json_example['layers']['textlayer'][curlayer] = mask_token
+             tuples.append((mask_token, value, num_token))
+             meta_infos.append((shape_i, key))
+         if return_meta:
+             return json_example, tuples, meta_infos
+         else:
+             return json_example, tuples
+
+
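To make the masking behaviour concrete, a hedged sketch of one call; the field values and the instance configuration (self.mask_values True, mask_type == 'cm3', num_mask_tokens >= 11) are illustrative assumptions, not values taken from this repository:

example = {
    "globalcaption": "summer sale banner",
    "canvas_width": "<size-255>",
    "canvas_height": "<size-255>",
    "category": "promotion",
    "keywords": "sale summer beach",
    "layers": {
        "bglayer": {"bgcaption": "a sunny beach background"},
        "objlayer": {"flag": "true", "objcaption": "a beach ball"},
        "textlayer": {"heading": "Big Summer Sale", "subheading": "Up to 50% off", "body": "Only this week"},
    },
}
masked, tuples = quantizer.apply_masking(example, mask_all=True)
# with enough mask tokens every listed value is replaced in order, e.g.
#   masked['globalcaption'] == '<mask-0>'
#   masked['layers']['textlayer']['heading'] == '<mask-8>'
# and each replacement is recorded as (mask_token, original_value, num_token), e.g.
#   ('<mask-8>', 'Big Summer Sale', 3)   # 3 = number of whitespace-separated words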
+ # unused: originally used for rendering
+ def is_font_exists(font_name):
+     font_list = font_manager.findSystemFonts()
+     # print("\nfont_list: ", font_list)
+     for font in font_list:
+         if font_name.lower() in font.lower():
+             return True
+     return False
+
+ def print_info(msg):
+     print(Fore.GREEN + "[INFO] " + msg)
+
+ def print_warning(msg):
+     print(Fore.YELLOW + "[WARNING] " + msg)
+
+ def print_error(msg):
+     print(Fore.RED + "[ERROR] " + msg)
+
+ def load_feature(path):
+     with open(path) as f:
+         content = f.read()
+     content = json.loads(content)
+     names = [content[str(i)] for i in range(len(content))]
+     return ClassLabel(num_classes=len(names), names=names)
+
+ def get_quantizer(version='v1', update_vocab=False, **kwargs):
+     # if kwargs.pop('separate_alpha', False):  # unused
+     #     kwargs['n_visual_tokens'] *= 2
+     if version == 'v4':
+         quantizer = QuantizerV4(**kwargs)
+     else:
+         raise NotImplementedError
+
+     return quantizer
+
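Putting the pieces together, a hedged sketch of how a caller might obtain a quantizer and prepare a tokenizer. The keyword arguments and the setup-method name are assumptions, since only the method body (not its signature) appears in this diff; note also that the default version='v1' raises NotImplementedError, so callers must pass version='v4'.

from transformers import AutoTokenizer

quantizer = get_quantizer(
    version='v4',        # the only implemented branch
    quant=True,          # assumed QuantizerV4 keyword arguments
    num_mask_tokens=0,
)
tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder checkpoint
tokenizer = quantizer.setup_tokenizer(tokenizer)   # hypothetical name for the special-token method shown above
map_dict = quantizer.construct_map_dict()          # later used to turn generated tokens back into numbers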
requirements.txt ADDED
@@ -0,0 +1,46 @@
+ # Core frameworks
+ torch==2.4.0              # keep in line with the conda environment (installed from PyPI)
+ torchvision==0.19.0       # keep in line with the conda environment (installed from PyPI)
+
+ # Hugging Face ecosystem
+ transformers==4.44.0      # was 4.39.1 -> conda actually installs 4.44.0
+ diffusers==0.31.0         # keep in line with the conda environment
+ accelerate==0.34.2        # was 0.27.2 -> conda actually installs 0.34.2
+ peft==0.12.0              # was a git commit -> conda actually installs 0.12.0 (PyPI)
+ datasets==2.20.0          # keep in line with the conda environment
+
+ # Toolchain
+ deepspeed==0.15.4         # was 0.14.2 -> conda actually installs 0.15.4
+ # bitsandbytes==0.44.1    # was 0.43.0 -> conda actually installs 0.44.1
+ protobuf==3.20.0          # was 3.20.3 -> conda actually installs 3.20.0 (tensorboard compatibility needs verification)
+ tensorboard==2.18.0       # newly pinned explicit version (conda actually installs 2.18.0)
+ tensorboardx==2.6.2.2     # newly pinned explicit version (conda actually installs 2.6.2.2)
+ webdataset==0.2.100       # newly pinned explicit version (conda actually installs 0.2.100)
+
+ # Training utilities
+ warmup_scheduler==0.3     # newly pinned explicit version (conda actually installs 0.3)
+ torchmetrics==1.6.0       # newly pinned explicit version (conda actually installs 1.6.0)
+ open_clip_torch==2.29.0   # newly pinned explicit version (conda actually installs 2.29.0)
+ evaluate==0.4.3           # newly pinned explicit version (conda actually installs 0.4.3)
+ bert_score==0.3.13        # newly pinned explicit version (conda actually installs 0.3.13)
+ einops==0.8.0             # keep in line with the conda environment
+ wandb==0.17.7             # keep in line with the conda environment
+
+ # Image processing
+ matplotlib==3.9.2         # newly pinned explicit version (conda actually installs 3.9.2)
+ opencv-python==4.10.0.84  # newly pinned explicit version (conda actually installs 4.10.0.84)
+ clean-fid==0.1.35         # newly pinned explicit version (conda actually installs 0.1.35)
+ skia-python==87.6         # newly pinned explicit version (conda actually installs 87.6)
+
+ # Deployment and interfaces
+ # gradio==5.5.0           # newly pinned explicit version (conda actually installs 5.5.0)
+ langchain>=0.0.139        # keep the constraint (conda actually installs 0.3.7, which satisfies it)
+ tiktoken==0.8.0           # newly pinned explicit version (conda actually installs 0.8.0)
+
+ # System tools
+ ninja==1.11.1.1           # newly pinned explicit version (conda actually installs 1.11.1.1)
+ pynvml==11.5.3            # newly pinned explicit version (conda actually installs 11.5.3)
+ colorama==0.4.6           # newly pinned explicit version (conda actually installs 0.4.6)
+ click>=8.0.4,<9           # keep the constraint (conda actually installs 8.1.7, which satisfies it)
+
+ sentencepiece