finish with token
Browse files- README.md +2 -2
- app.py +684 -0
- app_test.py +684 -0
- config/base.py +31 -0
- config/v04sv03_lora_r64_upto50layers_bs1_lr1_prodigy_800k_wds_512_filtered_10ep_none_8gpu.py +111 -0
- custom_model_mmdit.py +334 -0
- custom_model_transp_vae.py +331 -0
- custom_pipeline.py +845 -0
- modeling_crello.py +235 -0
- quantizer.py +552 -0
- requirements.txt +46 -0
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: ART V1.0
|
3 |
-
emoji:
|
4 |
colorFrom: gray
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.20.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: ART V1.0
|
3 |
+
emoji: 📊
|
4 |
colorFrom: gray
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.20.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import spaces
|
3 |
+
|
4 |
+
import ast
|
5 |
+
import numpy as np
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
import xml.etree.cElementTree as ET
|
13 |
+
from io import BytesIO
|
14 |
+
import base64
|
15 |
+
import json
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from functools import partial
|
19 |
+
import requests
|
20 |
+
import base64
|
21 |
+
import os
|
22 |
+
import time
|
23 |
+
import re
|
24 |
+
|
25 |
+
from transformers import (
|
26 |
+
AutoTokenizer,
|
27 |
+
set_seed
|
28 |
+
)
|
29 |
+
from typing import List
|
30 |
+
|
31 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
32 |
+
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
|
33 |
+
STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
|
34 |
+
class StopAtSpecificTokenCriteria(StoppingCriteria):
|
35 |
+
def __init__(self, token_id_list: List[int] = None):
|
36 |
+
self.token_id_list = token_id_list
|
37 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
38 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
39 |
+
return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
|
40 |
+
|
41 |
+
def ensure_space_after_period(input_string):
|
42 |
+
# 去除多余的空格
|
43 |
+
output_string = re.sub(r'\.\s*', '. ', input_string)
|
44 |
+
return output_string
|
45 |
+
|
46 |
+
def generate_unique_filename():
|
47 |
+
# 生成一个基于时间戳和随机数的唯一文件名
|
48 |
+
timestamp = int(time.time() * 1000) # 时间戳,毫秒级
|
49 |
+
# random_num = random.randint(1000, 9999) # 随机数
|
50 |
+
unique_filename = f"{timestamp}"
|
51 |
+
return unique_filename
|
52 |
+
|
53 |
+
def upload_to_github(file_path,
|
54 |
+
repo='WYBar/gradiodemo_svg',
|
55 |
+
branch='main',
|
56 |
+
token='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'):
|
57 |
+
if not os.path.isfile(file_path):
|
58 |
+
print(f"File not found: {file_path}")
|
59 |
+
return None
|
60 |
+
with open(file_path, 'rb') as file:
|
61 |
+
content = file.read()
|
62 |
+
encoded_content = base64.b64encode(content).decode('utf-8')
|
63 |
+
unique_filename = generate_unique_filename()
|
64 |
+
url = f"https://api.github.com/repos/{repo}/contents/{unique_filename}.svg"
|
65 |
+
headers = {
|
66 |
+
"Authorization": f"token {token}"
|
67 |
+
}
|
68 |
+
response = requests.get(url, headers=headers)
|
69 |
+
|
70 |
+
sha = None
|
71 |
+
if response.status_code == 200:
|
72 |
+
sha = response.json()['sha']
|
73 |
+
elif response.status_code == 404:
|
74 |
+
# 文件不存在,不需要SHA
|
75 |
+
pass
|
76 |
+
else:
|
77 |
+
print(f"Failed to get file status: {response.status_code}")
|
78 |
+
# print(response.text)
|
79 |
+
return None
|
80 |
+
|
81 |
+
headers = {
|
82 |
+
"Authorization": f"token {token}",
|
83 |
+
"Content-Type": "application/json"
|
84 |
+
}
|
85 |
+
data = {
|
86 |
+
"message": "upload svg file",
|
87 |
+
"content": encoded_content,
|
88 |
+
"branch": branch
|
89 |
+
}
|
90 |
+
|
91 |
+
if sha:
|
92 |
+
# 文件存在,更新文件
|
93 |
+
# print('sha exists, update the old one')
|
94 |
+
data["sha"] = sha
|
95 |
+
response = requests.put(url, headers=headers, json=data)
|
96 |
+
else:
|
97 |
+
# 文件不存在,创建新文件
|
98 |
+
print("sha not exist, need to create a new one")
|
99 |
+
response = requests.put(url, headers=headers, json=data)
|
100 |
+
|
101 |
+
# print(response.status_code)
|
102 |
+
# print(response.text)
|
103 |
+
if response.status_code in [200, 201]:
|
104 |
+
# print(response.json()['content']['download_url'])
|
105 |
+
return response.json()['content']['download_url'], unique_filename
|
106 |
+
else:
|
107 |
+
print("None")
|
108 |
+
return None
|
109 |
+
|
110 |
+
def calculate_iou(box1, box2):
|
111 |
+
# 计算两个框的交集
|
112 |
+
x1 = max(box1[0], box2[0])
|
113 |
+
y1 = max(box1[1], box2[1])
|
114 |
+
x2 = min(box1[2], box2[2])
|
115 |
+
y2 = min(box1[3], box2[3])
|
116 |
+
|
117 |
+
intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
|
118 |
+
|
119 |
+
# 计算两个框的并集
|
120 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
121 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
122 |
+
|
123 |
+
union_area = box1_area + box2_area - intersection_area
|
124 |
+
|
125 |
+
# 计算IOU
|
126 |
+
iou = intersection_area / union_area
|
127 |
+
return iou
|
128 |
+
|
129 |
+
def adjust_coordinates(box):
|
130 |
+
size = 32
|
131 |
+
(x1, y1, x2, y2) = box
|
132 |
+
if x1 % size != 0:
|
133 |
+
x1 = (x1 // size) * size
|
134 |
+
if x2 % size != 0:
|
135 |
+
x2 = (x2 // size + 1) * size
|
136 |
+
|
137 |
+
if y1 % size != 0:
|
138 |
+
y1 = (y1 // size) * size
|
139 |
+
if y2 % size != 0:
|
140 |
+
y2 = (y2 // size + 1) * size
|
141 |
+
return (x1, y1, x2, y2)
|
142 |
+
|
143 |
+
def adjust_validation_box(validation_box):
|
144 |
+
return [adjust_coordinates(box) for box in validation_box]
|
145 |
+
|
146 |
+
def get_list_layer_box(list_png_images):
|
147 |
+
list_layer_box = []
|
148 |
+
for img in list_png_images:
|
149 |
+
img_np = np.array(img)
|
150 |
+
alpha_channel = img_np[:, :, -1]
|
151 |
+
|
152 |
+
# Step 1: Find the non-zero indices
|
153 |
+
rows, cols = np.nonzero(alpha_channel)
|
154 |
+
|
155 |
+
if (len(rows) == 0) or (len(cols) == 0):
|
156 |
+
# If there are no non-zero indices, we can skip this layer
|
157 |
+
list_layer_box.append((0, 0, 0, 0))
|
158 |
+
continue
|
159 |
+
|
160 |
+
# Step 2: Get the minimum and maximum indices for rows and columns
|
161 |
+
min_row, max_row = rows.min().item(), rows.max().item()
|
162 |
+
min_col, max_col = cols.min().item(), cols.max().item()
|
163 |
+
|
164 |
+
# Step 3: Quantize the minimum values down to the nearest multiple of 8
|
165 |
+
quantized_min_row = (min_row // 8) * 8
|
166 |
+
quantized_min_col = (min_col // 8) * 8
|
167 |
+
|
168 |
+
# Step 4: Quantize the maximum values up to the nearest multiple of 8 outside of the max
|
169 |
+
quantized_max_row = ((max_row // 8) + 1) * 8
|
170 |
+
quantized_max_col = ((max_col // 8) + 1) * 8
|
171 |
+
list_layer_box.append(
|
172 |
+
(quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row)
|
173 |
+
)
|
174 |
+
return list_layer_box
|
175 |
+
|
176 |
+
def pngs_to_svg(list_png_images):
|
177 |
+
list_layer_box = get_list_layer_box(list_png_images)
|
178 |
+
assert(len(list_png_images) == len(list_layer_box))
|
179 |
+
width, height = list_png_images[0].width, list_png_images[0].height
|
180 |
+
img_svg = ET.Element(
|
181 |
+
'svg',
|
182 |
+
{
|
183 |
+
"width": str(width),
|
184 |
+
"height": str(height),
|
185 |
+
"xmlns": "http://www.w3.org/2000/svg",
|
186 |
+
"xmlns:svg": "http://www.w3.org/2000/svg",
|
187 |
+
"xmlns:xlink":"http://www.w3.org/1999/xlink"
|
188 |
+
}
|
189 |
+
)
|
190 |
+
for img, box in zip(list_png_images, list_layer_box):
|
191 |
+
x, y, w, h = box[0], box[1], box[2]-box[0], box[3]-box[1]
|
192 |
+
if (w == 0 or h == 0):
|
193 |
+
continue
|
194 |
+
img = img.crop((x, y, x+w, y+h))
|
195 |
+
buffer = BytesIO()
|
196 |
+
img.save(buffer, format='PNG')
|
197 |
+
img_str = base64.b64encode(buffer.getvalue())
|
198 |
+
ET.SubElement(
|
199 |
+
img_svg,
|
200 |
+
"image",
|
201 |
+
{
|
202 |
+
"x": str(x),
|
203 |
+
"y": str(y),
|
204 |
+
"width": str(w),
|
205 |
+
"height": str(h),
|
206 |
+
"xlink:href": "data:image/png;base64,"+img_str.decode('utf-8')
|
207 |
+
}
|
208 |
+
)
|
209 |
+
return ET.tostring(img_svg, encoding='utf-8').decode('utf-8')
|
210 |
+
|
211 |
+
def calculate_iou(box1, box2):
|
212 |
+
# 计算两个框的交集
|
213 |
+
x1 = max(box1[0], box2[0])
|
214 |
+
y1 = max(box1[1], box2[1])
|
215 |
+
x2 = min(box1[2], box2[2])
|
216 |
+
y2 = min(box1[3], box2[3])
|
217 |
+
|
218 |
+
intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
|
219 |
+
|
220 |
+
# 计算两个框的并集
|
221 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
222 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
223 |
+
|
224 |
+
union_area = box1_area + box2_area - intersection_area
|
225 |
+
|
226 |
+
# 计算IOU
|
227 |
+
iou = intersection_area / union_area
|
228 |
+
return iou
|
229 |
+
|
230 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
231 |
+
def buildmodel(**kwargs):
|
232 |
+
from modeling_crello import CrelloModel, CrelloModelConfig
|
233 |
+
from quantizer import get_quantizer
|
234 |
+
# seed / input model / resume
|
235 |
+
resume = kwargs.get('resume', None)
|
236 |
+
seed = kwargs.get('seed', None)
|
237 |
+
input_model = kwargs.get('input_model', None)
|
238 |
+
quantizer_version = kwargs.get('quantizer_version', 'v4')
|
239 |
+
device = "cuda"
|
240 |
+
|
241 |
+
set_seed(seed)
|
242 |
+
# old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
|
243 |
+
old_tokenizer = AutoTokenizer.from_pretrained(
|
244 |
+
"WYBar/LLM_For_Layout_Planning", # 仓库路径
|
245 |
+
subfolder="Meta-Llama-3-8B", # 子目录对应模型文件夹
|
246 |
+
trust_remote_code=True,
|
247 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
248 |
+
)
|
249 |
+
old_vocab_size = len(old_tokenizer)
|
250 |
+
# tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True)
|
251 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
252 |
+
"WYBar/LLM_For_Layout_Planning",
|
253 |
+
subfolder="checkpoint-26000", # 检查点所在子目录
|
254 |
+
trust_remote_code=True,
|
255 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
256 |
+
)
|
257 |
+
|
258 |
+
quantizer = get_quantizer(
|
259 |
+
quantizer_version,
|
260 |
+
update_vocab = False,
|
261 |
+
decimal_quantize_types = kwargs.get('decimal_quantize_types'),
|
262 |
+
mask_values = kwargs['mask_values'],
|
263 |
+
width = kwargs['width'],
|
264 |
+
height = kwargs['height'],
|
265 |
+
simplify_json = False,
|
266 |
+
num_mask_tokens = 0,
|
267 |
+
mask_type = kwargs.get('mask_type'),
|
268 |
+
)
|
269 |
+
quantizer.setup_tokenizer(tokenizer)
|
270 |
+
|
271 |
+
model_args = CrelloModelConfig(
|
272 |
+
old_vocab_size = old_vocab_size,
|
273 |
+
vocab_size=len(tokenizer),
|
274 |
+
pad_token_id=tokenizer.pad_token_id,
|
275 |
+
ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
|
276 |
+
)
|
277 |
+
model_args.freeze_lm = True
|
278 |
+
model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
|
279 |
+
model_args.use_lora = False
|
280 |
+
model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
|
281 |
+
# model = CrelloModel.from_pretrained(
|
282 |
+
# resume,
|
283 |
+
# config=model_args
|
284 |
+
# ).to(device)
|
285 |
+
# model = CrelloModel.from_pretrained(
|
286 |
+
# "WYBar/LLM_For_Layout_Planning",
|
287 |
+
# subfolder="checkpoint-26000", # 加载检查点目录
|
288 |
+
# config=model_args,
|
289 |
+
# # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
290 |
+
# )
|
291 |
+
model = CrelloModel(config=model_args)
|
292 |
+
print("before .to(device)")
|
293 |
+
model = model.to(device)
|
294 |
+
print("after .to(device)")
|
295 |
+
model = model.bfloat16()
|
296 |
+
model.eval()
|
297 |
+
|
298 |
+
tokenizer.add_special_tokens({"mask_token": "<mask>"})
|
299 |
+
quantizer.additional_special_tokens.add("<mask>")
|
300 |
+
added_special_tokens_list = ["<layout>", "<position>", "<wholecaption>"]
|
301 |
+
tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens_list}, replace_additional_special_tokens=False)
|
302 |
+
for token in added_special_tokens_list:
|
303 |
+
quantizer.additional_special_tokens.add(token)
|
304 |
+
|
305 |
+
return model, quantizer, tokenizer
|
306 |
+
|
307 |
+
def construction_layout():
|
308 |
+
params_dict = {
|
309 |
+
# 需要修改
|
310 |
+
"input_model": "WYBar/LLM_For_Layout_Planning",
|
311 |
+
"resume": "WYBar/LLM_For_Layout_Planning",
|
312 |
+
|
313 |
+
"seed": 0,
|
314 |
+
"mask_values": False,
|
315 |
+
"quantizer_version": 'v4',
|
316 |
+
"mask_type": 'cm3',
|
317 |
+
"decimal_quantize_types": [],
|
318 |
+
"num_mask_tokens": 0,
|
319 |
+
"width": 512,
|
320 |
+
"height": 512,
|
321 |
+
"device": 0,
|
322 |
+
}
|
323 |
+
device = "cuda"
|
324 |
+
# Init model
|
325 |
+
model, quantizer, tokenizer = buildmodel(**params_dict)
|
326 |
+
|
327 |
+
print('resize token embeddings to match the tokenizer', 129423)
|
328 |
+
model.lm.resize_token_embeddings(129423)
|
329 |
+
model.input_embeddings = model.lm.get_input_embeddings()
|
330 |
+
print('after token embeddings to match the tokenizer', 129423)
|
331 |
+
return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
|
332 |
+
|
333 |
+
@torch.no_grad()
|
334 |
+
@spaces.GPU(enable_queue=True, duration=60)
|
335 |
+
def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
|
336 |
+
json_example = inputs
|
337 |
+
input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
|
338 |
+
inputs = tokenizer(
|
339 |
+
input_intension, return_tensors="pt"
|
340 |
+
).to(device)
|
341 |
+
|
342 |
+
stopping_criteria = StoppingCriteriaList()
|
343 |
+
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000]))
|
344 |
+
|
345 |
+
outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
|
346 |
+
inputs_length = inputs['input_ids'].shape[1]
|
347 |
+
outputs = outputs[:, inputs_length:]
|
348 |
+
|
349 |
+
outputs_word = tokenizer.batch_decode(outputs)[0]
|
350 |
+
split_word = outputs_word.split('}]}')[0]+"}]}"
|
351 |
+
split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word
|
352 |
+
map_dict = quantizer.construct_map_dict()
|
353 |
+
|
354 |
+
for key ,value in map_dict.items():
|
355 |
+
split_word = split_word.replace(key, value)
|
356 |
+
try:
|
357 |
+
pred_json_example = json.loads(split_word)
|
358 |
+
for layer in pred_json_example["layout"]:
|
359 |
+
layer['x'] = round(int(width)*layer['x'])
|
360 |
+
layer['y'] = round(int(height)*layer['y'])
|
361 |
+
layer['width'] = round(int(width)*layer['width'])
|
362 |
+
layer['height'] = round(int(height)*layer['height'])
|
363 |
+
except Exception as e:
|
364 |
+
print(e)
|
365 |
+
pred_json_example = None
|
366 |
+
return pred_json_example
|
367 |
+
|
368 |
+
def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
|
369 |
+
def FormulateInput(intension: str):
|
370 |
+
resdict = {}
|
371 |
+
resdict["wholecaption"] = intension
|
372 |
+
resdict["layout"] = []
|
373 |
+
return resdict
|
374 |
+
|
375 |
+
rawdata = FormulateInput(intention)
|
376 |
+
|
377 |
+
if generate_method == 'v1':
|
378 |
+
max_try_time = 5
|
379 |
+
preddata = None
|
380 |
+
while preddata is None and max_try_time > 0:
|
381 |
+
preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
|
382 |
+
max_try_time -= 1
|
383 |
+
else:
|
384 |
+
print("Please input correct generate method")
|
385 |
+
preddata = None
|
386 |
+
|
387 |
+
return preddata
|
388 |
+
|
389 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
390 |
+
def construction():
|
391 |
+
from custom_model_mmdit import CustomFluxTransformer2DModel
|
392 |
+
from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
|
393 |
+
from custom_pipeline import CustomFluxPipelineCfg
|
394 |
+
|
395 |
+
transformer = CustomFluxTransformer2DModel.from_pretrained(
|
396 |
+
"WYBar/ART_test_weights",
|
397 |
+
subfolder="fused_transformer",
|
398 |
+
torch_dtype=torch.bfloat16,
|
399 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
400 |
+
)
|
401 |
+
|
402 |
+
transp_vae = CustomVAE.from_pretrained(
|
403 |
+
"WYBar/ART_test_weights",
|
404 |
+
subfolder="custom_vae",
|
405 |
+
torch_dtype=torch.float32,
|
406 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
407 |
+
)
|
408 |
+
|
409 |
+
token = os.environ.get("HF_TOKEN")
|
410 |
+
pipeline = CustomFluxPipelineCfg.from_pretrained(
|
411 |
+
"black-forest-labs/FLUX.1-dev",
|
412 |
+
transformer=transformer,
|
413 |
+
torch_dtype=torch.bfloat16,
|
414 |
+
token=token,
|
415 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
416 |
+
).to("cuda")
|
417 |
+
pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
|
418 |
+
|
419 |
+
return pipeline, transp_vae
|
420 |
+
|
421 |
+
@spaces.GPU(enable_queue=True, duration=60)
|
422 |
+
def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
|
423 |
+
print(validation_box)
|
424 |
+
output, rgba_output, _, _ = pipeline(
|
425 |
+
prompt=validation_prompt,
|
426 |
+
validation_box=validation_box,
|
427 |
+
generator=generator,
|
428 |
+
height=512,
|
429 |
+
width=512,
|
430 |
+
num_layers=len(validation_box),
|
431 |
+
guidance_scale=4.0,
|
432 |
+
num_inference_steps=inference_steps,
|
433 |
+
transparent_decoder=transp_vae,
|
434 |
+
true_gs=true_gs
|
435 |
+
)
|
436 |
+
images = output.images # list of PIL, len=layers
|
437 |
+
rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
|
438 |
+
|
439 |
+
output_gradio = []
|
440 |
+
merged_pil = images[1].convert('RGBA')
|
441 |
+
for frame_idx, frame_pil in enumerate(rgba_images):
|
442 |
+
if frame_idx < 2:
|
443 |
+
frame_pil = images[frame_idx].convert('RGBA') # merged and background
|
444 |
+
else:
|
445 |
+
merged_pil = Image.alpha_composite(merged_pil, frame_pil)
|
446 |
+
output_gradio.append(frame_pil)
|
447 |
+
|
448 |
+
return output_gradio
|
449 |
+
|
450 |
+
def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
|
451 |
+
generator = torch.Generator().manual_seed(seed)
|
452 |
+
try:
|
453 |
+
validation_box = ast.literal_eval(validation_box_str)
|
454 |
+
except Exception as e:
|
455 |
+
return [f"Error parsing validation_box: {e}"]
|
456 |
+
if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
|
457 |
+
return ["validation_box must be a list of tuples, each of length 4."]
|
458 |
+
|
459 |
+
validation_box = adjust_validation_box(validation_box)
|
460 |
+
|
461 |
+
result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
|
462 |
+
|
463 |
+
svg_img = pngs_to_svg(result_images[1:])
|
464 |
+
|
465 |
+
svg_file_path = './image.svg'
|
466 |
+
os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
|
467 |
+
with open(svg_file_path, 'w', encoding='utf-8') as f:
|
468 |
+
f.write(svg_img)
|
469 |
+
|
470 |
+
return result_images, svg_file_path
|
471 |
+
|
472 |
+
def main():
|
473 |
+
model, quantizer, tokenizer, width, height, device = construction_layout()
|
474 |
+
|
475 |
+
inference_partial = partial(
|
476 |
+
inference,
|
477 |
+
model=model,
|
478 |
+
quantizer=quantizer,
|
479 |
+
tokenizer=tokenizer,
|
480 |
+
width=width,
|
481 |
+
height=height,
|
482 |
+
device=device
|
483 |
+
)
|
484 |
+
|
485 |
+
def process_preddate(intention, temperature, top_p, generate_method='v1'):
|
486 |
+
intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
|
487 |
+
intention = ensure_space_after_period(intention)
|
488 |
+
if temperature == 0.0:
|
489 |
+
# print("looking for greedy decoding strategies, set `do_sample=False`.")
|
490 |
+
preddata = inference_partial(generate_method, intention, do_sample=False)
|
491 |
+
else:
|
492 |
+
preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
|
493 |
+
# wholecaption = preddata["wholecaption"]
|
494 |
+
layouts = preddata["layout"]
|
495 |
+
list_box = []
|
496 |
+
for i, layout in enumerate(layouts):
|
497 |
+
x, y = layout["x"], layout["y"]
|
498 |
+
width, height = layout["width"], layout["height"]
|
499 |
+
if i == 0:
|
500 |
+
list_box.append((0, 0, width, height))
|
501 |
+
list_box.append((0, 0, width, height))
|
502 |
+
else:
|
503 |
+
left = x - width // 2
|
504 |
+
top = y - height // 2
|
505 |
+
right = x + width // 2
|
506 |
+
bottom = y + height // 2
|
507 |
+
list_box.append((left, top, right, bottom))
|
508 |
+
|
509 |
+
# print(list_box)
|
510 |
+
filtered_boxes = list_box[:2]
|
511 |
+
for i in range(2, len(list_box)):
|
512 |
+
keep = True
|
513 |
+
for j in range(1, len(filtered_boxes)):
|
514 |
+
iou = calculate_iou(list_box[i], filtered_boxes[j])
|
515 |
+
if iou > 0.65:
|
516 |
+
print(list_box[i], filtered_boxes[j])
|
517 |
+
keep = False
|
518 |
+
break
|
519 |
+
if keep:
|
520 |
+
filtered_boxes.append(list_box[i])
|
521 |
+
|
522 |
+
return str(filtered_boxes), intention, str(filtered_boxes)
|
523 |
+
|
524 |
+
# def process_preddate(intention, generate_method='v1'):
|
525 |
+
# list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
|
526 |
+
# wholecaption = "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
|
527 |
+
# json_file = "/home/wyb/openseg_blob/v-yanbin/GradioDemo/LLM-For-Layout-Planning/inference_test.json"
|
528 |
+
# return wholecaption, str(list_box), json_file
|
529 |
+
|
530 |
+
pipeline, transp_vae = construction()
|
531 |
+
|
532 |
+
gradio_test_one_sample_partial = partial(
|
533 |
+
svg_test_one_sample,
|
534 |
+
pipeline=pipeline,
|
535 |
+
transp_vae=transp_vae,
|
536 |
+
)
|
537 |
+
|
538 |
+
def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
|
539 |
+
result_images = []
|
540 |
+
result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
|
541 |
+
|
542 |
+
url, unique_filename = upload_to_github(file_path=svg_file_path)
|
543 |
+
unique_filename = f'{unique_filename}'
|
544 |
+
|
545 |
+
if url != None:
|
546 |
+
print(f"File uploaded to: {url}")
|
547 |
+
svg_editor = f"""
|
548 |
+
<iframe src="https://svgedit.netlify.app/editor/index.html?\
|
549 |
+
storagePrompt=false&url={url}" \
|
550 |
+
width="100%", height="800px"></iframe>
|
551 |
+
"""
|
552 |
+
else:
|
553 |
+
print('upload_to_github FAILED!')
|
554 |
+
svg_editor = f"""
|
555 |
+
<iframe src="https://svgedit.netlify.app/editor/index.html" \
|
556 |
+
width="100%", height="800px"></iframe>
|
557 |
+
"""
|
558 |
+
|
559 |
+
return result_images, svg_file_path, svg_editor
|
560 |
+
|
561 |
+
def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
|
562 |
+
# 首先调用process_preddate
|
563 |
+
list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
|
564 |
+
|
565 |
+
# 然后将process_preddate的输出作为process_svg的输入
|
566 |
+
result_images, svg_file, svg_editor = process_svg(intention_input, list_box_output, seed, true_gs, inference_steps)
|
567 |
+
|
568 |
+
# 返回两个函数的输出
|
569 |
+
return list_box_output, result_images, svg_file, svg_editor, intention_input, list_box_output
|
570 |
+
|
571 |
+
def clear_inputs1():
|
572 |
+
return "", ""
|
573 |
+
|
574 |
+
def clear_inputs2():
|
575 |
+
return "", ""
|
576 |
+
|
577 |
+
def transfer_inputs(intention, list_box):
|
578 |
+
return intention, list_box
|
579 |
+
|
580 |
+
theme = gr.themes.Soft(
|
581 |
+
radius_size="lg",
|
582 |
+
).set(
|
583 |
+
block_background_fill='*primary_50',
|
584 |
+
block_border_color='*primary_200',
|
585 |
+
block_border_width='1px',
|
586 |
+
block_border_width_dark='100px',
|
587 |
+
block_info_text_color='*primary_950',
|
588 |
+
block_label_border_color='*primary_200',
|
589 |
+
block_radius='*radius_lg'
|
590 |
+
)
|
591 |
+
|
592 |
+
with gr.Blocks(theme=theme) as demo:
|
593 |
+
gr.HTML("<h1 style='text-align: center;'>ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation</h1>")
|
594 |
+
gr.HTML("<h2>Anonymous Region Layout Planner</h2>")
|
595 |
+
|
596 |
+
with gr.Row():
|
597 |
+
with gr.Column():
|
598 |
+
intention_input = gr.Textbox(lines=15, placeholder="Enter intention", label="Prompt")
|
599 |
+
with gr.Row():
|
600 |
+
temperature_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=0.0)
|
601 |
+
top_p_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Top P", value=0.0)
|
602 |
+
with gr.Row():
|
603 |
+
clear_btn1 = gr.Button("Clear")
|
604 |
+
model_btn1 = gr.Button("Commit", variant='primary')
|
605 |
+
transfer_btn1 = gr.Button("Export to below")
|
606 |
+
|
607 |
+
one_click_btn = gr.Button("One Click Generate ALL", variant='primary')
|
608 |
+
|
609 |
+
with gr.Column():
|
610 |
+
list_box_output = gr.Textbox(lines=10, placeholder="Validation Box", label="Validation Box")
|
611 |
+
|
612 |
+
examples = gr.Examples(
|
613 |
+
examples=[
|
614 |
+
['The image is a graphic design with a celebratory theme. At the top, there is a banner with the text \"Happy Anniversary\" in a bold, sans-serif font. Below this banner, there is a circular frame containing a photograph of a couple. The man has short, dark hair and is wearing a light-colored sweater, while the woman has long blonde hair and is also wearing a light-colored sweater. They are both smiling and appear to be embracing each other.Surrounding the circular frame are decorative elements such as pink flowers and green leaves, which add a festive touch to the design. Below the circular frame, there is a text that reads "Isabel & Morgan" in a cursive, elegant font, suggesting that the couple\'s names are Isabel and Morgan.At the bottom of the image, there is a banner with a message that says "Happy Anniversary! Cheers to another year of love, laughter, and cherished memories together.\" This text is in a smaller, sans-serif font and is placed against a solid background, providing a clear message of celebration and well-wishes for the couple.The overall style of the image is warm and celebratory, with a color scheme that includes shades of pink, green, and white, which contribute to a joyful and romantic atmosphere.'],
|
615 |
+
['The image is a digital illustration with a light blue background. At the top, there is a logo consisting of a snake wrapped around a staff, which is a common symbol in healthcare. Below the logo, the text "International Nurses Day" is prominently displayed in white, with the date "12 May 20xx" in smaller font size.The central part of the image features two stylized characters. On the left, there is a female character with dark hair, wearing a white nurse\'s uniform with a cap. She is holding a clipboard and appears to be speaking or gesturing, as indicated by a speech bubble with the word "OK" in it. On the right, there is a male character with light brown hair, wearing a light blue shirt with a white collar and a white apron. He is holding a stethoscope to his ear, suggesting he is a doctor or a healthcare professional.The characters are depicted in a friendly and approachable manner, with smiles on their faces. Around them, there are small blue plus signs, which are often associated with healthcare and medical services. The overall style of the image is clean, modern, and appears to be designed to celebrate International Nurses Day.'],
|
616 |
+
['The image features a graphic design with a festive theme. At the top, there is a decorative border with a wavy pattern. Below this border, the text "WINTER SEASON SPECIAL COOKIES" is prominently displayed in a bold, sans-serif font. The text is black with a slight shadow effect, giving it a three-dimensional appearance.In the center of the image, there are three illustrated gingerbread cookies. Each cookie has a smiling face with eyes, a nose, and a mouth, and they are colored in a warm, brown hue. The cookies are arranged in a staggered formation, with the middle cookie slightly higher than the others, creating a sense of depth.At the bottom of the image, there is a call to action that reads "ORDER.NOW" in a large, bold, sans-serif font. The text is colored in a darker shade of brown, contrasting with the lighter background. The overall style of the image suggests it is an advertisement or promotional graphic for a winter-themed cookie special.']
|
617 |
+
],
|
618 |
+
inputs=[intention_input]
|
619 |
+
)
|
620 |
+
|
621 |
+
gr.HTML("<h2>Anonymous Region Transformer</h2>")
|
622 |
+
with gr.Row():
|
623 |
+
with gr.Column():
|
624 |
+
text_input = gr.Textbox(lines=10, placeholder="Enter prompt text", label="Prompt")
|
625 |
+
tuple_input = gr.Textbox(lines=5, placeholder="Enter list of tuples, e.g., [(1, 2, 3, 4), (5, 6, 7, 8)]", label="Validation Box")
|
626 |
+
with gr.Row():
|
627 |
+
true_gs_input=gr.Slider(minimum=3.0, maximum=5.0, step=0.1, label="true_gs", value=3.5)
|
628 |
+
inference_steps_input=gr.Slider(minimum=5, maximum=50, step=1, label="inference_steps", value=28)
|
629 |
+
with gr.Row():
|
630 |
+
seed_input = gr.Number(label="Seed", value=42)
|
631 |
+
with gr.Row():
|
632 |
+
transfer_btn2 = gr.Button("Import from above")
|
633 |
+
with gr.Row():
|
634 |
+
clear_btn2 = gr.Button("Clear")
|
635 |
+
model_btn2 = gr.Button("Commit", variant='primary')
|
636 |
+
|
637 |
+
with gr.Column():
|
638 |
+
result_images = gr.Gallery(label="Result Images", columns=5, height='auto')
|
639 |
+
|
640 |
+
gr.HTML("<h1>SVG Image</h1>")
|
641 |
+
svg_file = gr.File(label="Download SVG Image")
|
642 |
+
svg_editor = gr.HTML(label="Editable SVG Editor")
|
643 |
+
|
644 |
+
model_btn1.click(
|
645 |
+
fn=process_preddate,
|
646 |
+
inputs=[intention_input, temperature_input, top_p_input],
|
647 |
+
outputs=[list_box_output, text_input, tuple_input],
|
648 |
+
api_name="process_preddate"
|
649 |
+
)
|
650 |
+
clear_btn1.click(
|
651 |
+
fn=clear_inputs1,
|
652 |
+
inputs=[],
|
653 |
+
outputs=[intention_input, list_box_output]
|
654 |
+
)
|
655 |
+
model_btn2.click(
|
656 |
+
fn=process_svg,
|
657 |
+
inputs=[text_input, tuple_input, seed_input, true_gs_input, inference_steps_input],
|
658 |
+
outputs=[result_images, svg_file, svg_editor],
|
659 |
+
api_name="process_svg"
|
660 |
+
)
|
661 |
+
clear_btn2.click(
|
662 |
+
fn=clear_inputs2,
|
663 |
+
inputs=[],
|
664 |
+
outputs=[text_input, tuple_input]
|
665 |
+
)
|
666 |
+
transfer_btn1.click(
|
667 |
+
fn=transfer_inputs,
|
668 |
+
inputs=[intention_input, list_box_output],
|
669 |
+
outputs=[text_input, tuple_input]
|
670 |
+
)
|
671 |
+
transfer_btn2.click(
|
672 |
+
fn=transfer_inputs,
|
673 |
+
inputs=[intention_input, list_box_output],
|
674 |
+
outputs=[text_input, tuple_input]
|
675 |
+
)
|
676 |
+
one_click_btn.click(
|
677 |
+
fn=one_click_generate,
|
678 |
+
inputs=[intention_input, temperature_input, top_p_input, seed_input, true_gs_input, inference_steps_input],
|
679 |
+
outputs=[list_box_output, result_images, svg_file, svg_editor, text_input, tuple_input]
|
680 |
+
)
|
681 |
+
demo.launch()
|
682 |
+
|
683 |
+
if __name__ == "__main__":
|
684 |
+
main()
|
app_test.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
# import spaces
|
3 |
+
|
4 |
+
import ast
|
5 |
+
import numpy as np
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
import xml.etree.cElementTree as ET
|
13 |
+
from io import BytesIO
|
14 |
+
import base64
|
15 |
+
import json
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from functools import partial
|
19 |
+
import requests
|
20 |
+
import base64
|
21 |
+
import os
|
22 |
+
import time
|
23 |
+
import re
|
24 |
+
|
25 |
+
from transformers import (
|
26 |
+
AutoTokenizer,
|
27 |
+
set_seed
|
28 |
+
)
|
29 |
+
from typing import List
|
30 |
+
|
31 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
32 |
+
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
|
33 |
+
STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
|
34 |
+
class StopAtSpecificTokenCriteria(StoppingCriteria):
|
35 |
+
def __init__(self, token_id_list: List[int] = None):
|
36 |
+
self.token_id_list = token_id_list
|
37 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
38 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
39 |
+
return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
|
40 |
+
|
41 |
+
def ensure_space_after_period(input_string):
|
42 |
+
# 去除多余的空格
|
43 |
+
output_string = re.sub(r'\.\s*', '. ', input_string)
|
44 |
+
return output_string
|
45 |
+
|
46 |
+
def generate_unique_filename():
|
47 |
+
# 生成一个基于时间戳和随机数的唯一文件名
|
48 |
+
timestamp = int(time.time() * 1000) # 时间戳,毫秒级
|
49 |
+
# random_num = random.randint(1000, 9999) # 随机数
|
50 |
+
unique_filename = f"{timestamp}"
|
51 |
+
return unique_filename
|
52 |
+
|
53 |
+
def upload_to_github(file_path,
|
54 |
+
repo='WYBar/gradiodemo_svg',
|
55 |
+
branch='main',
|
56 |
+
token='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'):
|
57 |
+
if not os.path.isfile(file_path):
|
58 |
+
print(f"File not found: {file_path}")
|
59 |
+
return None
|
60 |
+
with open(file_path, 'rb') as file:
|
61 |
+
content = file.read()
|
62 |
+
encoded_content = base64.b64encode(content).decode('utf-8')
|
63 |
+
unique_filename = generate_unique_filename()
|
64 |
+
url = f"https://api.github.com/repos/{repo}/contents/{unique_filename}.svg"
|
65 |
+
headers = {
|
66 |
+
"Authorization": f"token {token}"
|
67 |
+
}
|
68 |
+
response = requests.get(url, headers=headers)
|
69 |
+
|
70 |
+
sha = None
|
71 |
+
if response.status_code == 200:
|
72 |
+
sha = response.json()['sha']
|
73 |
+
elif response.status_code == 404:
|
74 |
+
# 文件不存在,不需要SHA
|
75 |
+
pass
|
76 |
+
else:
|
77 |
+
print(f"Failed to get file status: {response.status_code}")
|
78 |
+
# print(response.text)
|
79 |
+
return None
|
80 |
+
|
81 |
+
headers = {
|
82 |
+
"Authorization": f"token {token}",
|
83 |
+
"Content-Type": "application/json"
|
84 |
+
}
|
85 |
+
data = {
|
86 |
+
"message": "upload svg file",
|
87 |
+
"content": encoded_content,
|
88 |
+
"branch": branch
|
89 |
+
}
|
90 |
+
|
91 |
+
if sha:
|
92 |
+
# 文件存在,更新文件
|
93 |
+
# print('sha exists, update the old one')
|
94 |
+
data["sha"] = sha
|
95 |
+
response = requests.put(url, headers=headers, json=data)
|
96 |
+
else:
|
97 |
+
# 文件不存在,创建新文件
|
98 |
+
print("sha not exist, need to create a new one")
|
99 |
+
response = requests.put(url, headers=headers, json=data)
|
100 |
+
|
101 |
+
# print(response.status_code)
|
102 |
+
# print(response.text)
|
103 |
+
if response.status_code in [200, 201]:
|
104 |
+
# print(response.json()['content']['download_url'])
|
105 |
+
return response.json()['content']['download_url'], unique_filename
|
106 |
+
else:
|
107 |
+
print("None")
|
108 |
+
return None
|
109 |
+
|
110 |
+
def calculate_iou(box1, box2):
|
111 |
+
# 计算两个框的交集
|
112 |
+
x1 = max(box1[0], box2[0])
|
113 |
+
y1 = max(box1[1], box2[1])
|
114 |
+
x2 = min(box1[2], box2[2])
|
115 |
+
y2 = min(box1[3], box2[3])
|
116 |
+
|
117 |
+
intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
|
118 |
+
|
119 |
+
# 计算两个框的并集
|
120 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
121 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
122 |
+
|
123 |
+
union_area = box1_area + box2_area - intersection_area
|
124 |
+
|
125 |
+
# 计算IOU
|
126 |
+
iou = intersection_area / union_area
|
127 |
+
return iou
|
128 |
+
|
129 |
+
def adjust_coordinates(box):
|
130 |
+
size = 32
|
131 |
+
(x1, y1, x2, y2) = box
|
132 |
+
if x1 % size != 0:
|
133 |
+
x1 = (x1 // size) * size
|
134 |
+
if x2 % size != 0:
|
135 |
+
x2 = (x2 // size + 1) * size
|
136 |
+
|
137 |
+
if y1 % size != 0:
|
138 |
+
y1 = (y1 // size) * size
|
139 |
+
if y2 % size != 0:
|
140 |
+
y2 = (y2 // size + 1) * size
|
141 |
+
return (x1, y1, x2, y2)
|
142 |
+
|
143 |
+
def adjust_validation_box(validation_box):
|
144 |
+
return [adjust_coordinates(box) for box in validation_box]
|
145 |
+
|
146 |
+
def get_list_layer_box(list_png_images):
|
147 |
+
list_layer_box = []
|
148 |
+
for img in list_png_images:
|
149 |
+
img_np = np.array(img)
|
150 |
+
alpha_channel = img_np[:, :, -1]
|
151 |
+
|
152 |
+
# Step 1: Find the non-zero indices
|
153 |
+
rows, cols = np.nonzero(alpha_channel)
|
154 |
+
|
155 |
+
if (len(rows) == 0) or (len(cols) == 0):
|
156 |
+
# If there are no non-zero indices, we can skip this layer
|
157 |
+
list_layer_box.append((0, 0, 0, 0))
|
158 |
+
continue
|
159 |
+
|
160 |
+
# Step 2: Get the minimum and maximum indices for rows and columns
|
161 |
+
min_row, max_row = rows.min().item(), rows.max().item()
|
162 |
+
min_col, max_col = cols.min().item(), cols.max().item()
|
163 |
+
|
164 |
+
# Step 3: Quantize the minimum values down to the nearest multiple of 8
|
165 |
+
quantized_min_row = (min_row // 8) * 8
|
166 |
+
quantized_min_col = (min_col // 8) * 8
|
167 |
+
|
168 |
+
# Step 4: Quantize the maximum values up to the nearest multiple of 8 outside of the max
|
169 |
+
quantized_max_row = ((max_row // 8) + 1) * 8
|
170 |
+
quantized_max_col = ((max_col // 8) + 1) * 8
|
171 |
+
list_layer_box.append(
|
172 |
+
(quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row)
|
173 |
+
)
|
174 |
+
return list_layer_box
|
175 |
+
|
176 |
+
def pngs_to_svg(list_png_images):
|
177 |
+
list_layer_box = get_list_layer_box(list_png_images)
|
178 |
+
assert(len(list_png_images) == len(list_layer_box))
|
179 |
+
width, height = list_png_images[0].width, list_png_images[0].height
|
180 |
+
img_svg = ET.Element(
|
181 |
+
'svg',
|
182 |
+
{
|
183 |
+
"width": str(width),
|
184 |
+
"height": str(height),
|
185 |
+
"xmlns": "http://www.w3.org/2000/svg",
|
186 |
+
"xmlns:svg": "http://www.w3.org/2000/svg",
|
187 |
+
"xmlns:xlink":"http://www.w3.org/1999/xlink"
|
188 |
+
}
|
189 |
+
)
|
190 |
+
for img, box in zip(list_png_images, list_layer_box):
|
191 |
+
x, y, w, h = box[0], box[1], box[2]-box[0], box[3]-box[1]
|
192 |
+
if (w == 0 or h == 0):
|
193 |
+
continue
|
194 |
+
img = img.crop((x, y, x+w, y+h))
|
195 |
+
buffer = BytesIO()
|
196 |
+
img.save(buffer, format='PNG')
|
197 |
+
img_str = base64.b64encode(buffer.getvalue())
|
198 |
+
ET.SubElement(
|
199 |
+
img_svg,
|
200 |
+
"image",
|
201 |
+
{
|
202 |
+
"x": str(x),
|
203 |
+
"y": str(y),
|
204 |
+
"width": str(w),
|
205 |
+
"height": str(h),
|
206 |
+
"xlink:href": "data:image/png;base64,"+img_str.decode('utf-8')
|
207 |
+
}
|
208 |
+
)
|
209 |
+
return ET.tostring(img_svg, encoding='utf-8').decode('utf-8')
|
210 |
+
|
211 |
+
def calculate_iou(box1, box2):
|
212 |
+
# 计算两个框的交集
|
213 |
+
x1 = max(box1[0], box2[0])
|
214 |
+
y1 = max(box1[1], box2[1])
|
215 |
+
x2 = min(box1[2], box2[2])
|
216 |
+
y2 = min(box1[3], box2[3])
|
217 |
+
|
218 |
+
intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
|
219 |
+
|
220 |
+
# 计算两个框的并集
|
221 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
222 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
223 |
+
|
224 |
+
union_area = box1_area + box2_area - intersection_area
|
225 |
+
|
226 |
+
# 计算IOU
|
227 |
+
iou = intersection_area / union_area
|
228 |
+
return iou
|
229 |
+
|
230 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
231 |
+
def buildmodel(**kwargs):
|
232 |
+
from modeling_crello import CrelloModel, CrelloModelConfig
|
233 |
+
from quantizer import get_quantizer
|
234 |
+
# seed / input model / resume
|
235 |
+
resume = kwargs.get('resume', None)
|
236 |
+
seed = kwargs.get('seed', None)
|
237 |
+
input_model = kwargs.get('input_model', None)
|
238 |
+
quantizer_version = kwargs.get('quantizer_version', 'v4')
|
239 |
+
device = "cuda"
|
240 |
+
|
241 |
+
set_seed(seed)
|
242 |
+
# old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
|
243 |
+
old_tokenizer = AutoTokenizer.from_pretrained(
|
244 |
+
"WYBar/LLM_For_Layout_Planning", # 仓库路径
|
245 |
+
subfolder="Meta-Llama-3-8B", # 子目录对应模型文件夹
|
246 |
+
trust_remote_code=True,
|
247 |
+
cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
248 |
+
)
|
249 |
+
old_vocab_size = len(old_tokenizer)
|
250 |
+
# tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True)
|
251 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
252 |
+
"WYBar/LLM_For_Layout_Planning",
|
253 |
+
subfolder="checkpoint-26000", # 检查点所在子目录
|
254 |
+
trust_remote_code=True,
|
255 |
+
cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
256 |
+
)
|
257 |
+
|
258 |
+
quantizer = get_quantizer(
|
259 |
+
quantizer_version,
|
260 |
+
update_vocab = False,
|
261 |
+
decimal_quantize_types = kwargs.get('decimal_quantize_types'),
|
262 |
+
mask_values = kwargs['mask_values'],
|
263 |
+
width = kwargs['width'],
|
264 |
+
height = kwargs['height'],
|
265 |
+
simplify_json = False,
|
266 |
+
num_mask_tokens = 0,
|
267 |
+
mask_type = kwargs.get('mask_type'),
|
268 |
+
)
|
269 |
+
quantizer.setup_tokenizer(tokenizer)
|
270 |
+
|
271 |
+
model_args = CrelloModelConfig(
|
272 |
+
old_vocab_size = old_vocab_size,
|
273 |
+
vocab_size=len(tokenizer),
|
274 |
+
pad_token_id=tokenizer.pad_token_id,
|
275 |
+
ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
|
276 |
+
)
|
277 |
+
model_args.freeze_lm = True
|
278 |
+
model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
|
279 |
+
model_args.use_lora = False
|
280 |
+
model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
|
281 |
+
# model = CrelloModel.from_pretrained(
|
282 |
+
# resume,
|
283 |
+
# config=model_args
|
284 |
+
# ).to(device)
|
285 |
+
# model = CrelloModel.from_pretrained(
|
286 |
+
# "WYBar/LLM_For_Layout_Planning",
|
287 |
+
# subfolder="checkpoint-26000", # 加载检查点目录
|
288 |
+
# config=model_args,
|
289 |
+
# # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
290 |
+
# )
|
291 |
+
model = CrelloModel(config=model_args)
|
292 |
+
print("before .to(device)")
|
293 |
+
model = model.to(device)
|
294 |
+
print("after .to(device)")
|
295 |
+
model = model.bfloat16()
|
296 |
+
model.eval()
|
297 |
+
|
298 |
+
tokenizer.add_special_tokens({"mask_token": "<mask>"})
|
299 |
+
quantizer.additional_special_tokens.add("<mask>")
|
300 |
+
added_special_tokens_list = ["<layout>", "<position>", "<wholecaption>"]
|
301 |
+
tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens_list}, replace_additional_special_tokens=False)
|
302 |
+
for token in added_special_tokens_list:
|
303 |
+
quantizer.additional_special_tokens.add(token)
|
304 |
+
|
305 |
+
return model, quantizer, tokenizer
|
306 |
+
|
307 |
+
def construction_layout():
|
308 |
+
params_dict = {
|
309 |
+
# 需要修改
|
310 |
+
"input_model": "WYBar/LLM_For_Layout_Planning",
|
311 |
+
"resume": "WYBar/LLM_For_Layout_Planning",
|
312 |
+
|
313 |
+
"seed": 0,
|
314 |
+
"mask_values": False,
|
315 |
+
"quantizer_version": 'v4',
|
316 |
+
"mask_type": 'cm3',
|
317 |
+
"decimal_quantize_types": [],
|
318 |
+
"num_mask_tokens": 0,
|
319 |
+
"width": 512,
|
320 |
+
"height": 512,
|
321 |
+
"device": 0,
|
322 |
+
}
|
323 |
+
device = "cuda"
|
324 |
+
# Init model
|
325 |
+
model, quantizer, tokenizer = buildmodel(**params_dict)
|
326 |
+
|
327 |
+
print('resize token embeddings to match the tokenizer', 129423)
|
328 |
+
model.lm.resize_token_embeddings(129423)
|
329 |
+
model.input_embeddings = model.lm.get_input_embeddings()
|
330 |
+
print('after token embeddings to match the tokenizer', 129423)
|
331 |
+
return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
|
332 |
+
|
333 |
+
@torch.no_grad()
|
334 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
335 |
+
def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
|
336 |
+
json_example = inputs
|
337 |
+
input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
|
338 |
+
inputs = tokenizer(
|
339 |
+
input_intension, return_tensors="pt"
|
340 |
+
).to(device)
|
341 |
+
|
342 |
+
stopping_criteria = StoppingCriteriaList()
|
343 |
+
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000]))
|
344 |
+
|
345 |
+
outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
|
346 |
+
inputs_length = inputs['input_ids'].shape[1]
|
347 |
+
outputs = outputs[:, inputs_length:]
|
348 |
+
|
349 |
+
outputs_word = tokenizer.batch_decode(outputs)[0]
|
350 |
+
split_word = outputs_word.split('}]}')[0]+"}]}"
|
351 |
+
split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word
|
352 |
+
map_dict = quantizer.construct_map_dict()
|
353 |
+
|
354 |
+
for key ,value in map_dict.items():
|
355 |
+
split_word = split_word.replace(key, value)
|
356 |
+
try:
|
357 |
+
pred_json_example = json.loads(split_word)
|
358 |
+
for layer in pred_json_example["layout"]:
|
359 |
+
layer['x'] = round(int(width)*layer['x'])
|
360 |
+
layer['y'] = round(int(height)*layer['y'])
|
361 |
+
layer['width'] = round(int(width)*layer['width'])
|
362 |
+
layer['height'] = round(int(height)*layer['height'])
|
363 |
+
except Exception as e:
|
364 |
+
print(e)
|
365 |
+
pred_json_example = None
|
366 |
+
return pred_json_example
|
367 |
+
|
368 |
+
def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
|
369 |
+
def FormulateInput(intension: str):
|
370 |
+
resdict = {}
|
371 |
+
resdict["wholecaption"] = intension
|
372 |
+
resdict["layout"] = []
|
373 |
+
return resdict
|
374 |
+
|
375 |
+
rawdata = FormulateInput(intention)
|
376 |
+
|
377 |
+
if generate_method == 'v1':
|
378 |
+
max_try_time = 5
|
379 |
+
preddata = None
|
380 |
+
while preddata is None and max_try_time > 0:
|
381 |
+
preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
|
382 |
+
max_try_time -= 1
|
383 |
+
else:
|
384 |
+
print("Please input correct generate method")
|
385 |
+
preddata = None
|
386 |
+
|
387 |
+
return preddata
|
388 |
+
|
389 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
390 |
+
def construction():
|
391 |
+
from custom_model_mmdit import CustomFluxTransformer2DModel
|
392 |
+
from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
|
393 |
+
from custom_pipeline import CustomFluxPipelineCfg
|
394 |
+
|
395 |
+
transformer = CustomFluxTransformer2DModel.from_pretrained(
|
396 |
+
"WYBar/ART_test_weights",
|
397 |
+
subfolder="fused_transformer",
|
398 |
+
torch_dtype=torch.bfloat16,
|
399 |
+
cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
400 |
+
)
|
401 |
+
|
402 |
+
transp_vae = CustomVAE.from_pretrained(
|
403 |
+
"WYBar/ART_test_weights",
|
404 |
+
subfolder="custom_vae",
|
405 |
+
torch_dtype=torch.float32,
|
406 |
+
cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
407 |
+
)
|
408 |
+
|
409 |
+
token = os.environ.get("HF_TOKEN")
|
410 |
+
pipeline = CustomFluxPipelineCfg.from_pretrained(
|
411 |
+
"black-forest-labs/FLUX.1-dev",
|
412 |
+
transformer=transformer,
|
413 |
+
torch_dtype=torch.bfloat16,
|
414 |
+
token=token,
|
415 |
+
cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
|
416 |
+
).to("cuda")
|
417 |
+
pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
|
418 |
+
|
419 |
+
return pipeline, transp_vae
|
420 |
+
|
421 |
+
# @spaces.GPU(enable_queue=True, duration=60)
|
422 |
+
def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
|
423 |
+
print(validation_box)
|
424 |
+
output, rgba_output, _, _ = pipeline(
|
425 |
+
prompt=validation_prompt,
|
426 |
+
validation_box=validation_box,
|
427 |
+
generator=generator,
|
428 |
+
height=512,
|
429 |
+
width=512,
|
430 |
+
num_layers=len(validation_box),
|
431 |
+
guidance_scale=4.0,
|
432 |
+
num_inference_steps=inference_steps,
|
433 |
+
transparent_decoder=transp_vae,
|
434 |
+
true_gs=true_gs
|
435 |
+
)
|
436 |
+
images = output.images # list of PIL, len=layers
|
437 |
+
rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
|
438 |
+
|
439 |
+
output_gradio = []
|
440 |
+
merged_pil = images[1].convert('RGBA')
|
441 |
+
for frame_idx, frame_pil in enumerate(rgba_images):
|
442 |
+
if frame_idx < 2:
|
443 |
+
frame_pil = images[frame_idx].convert('RGBA') # merged and background
|
444 |
+
else:
|
445 |
+
merged_pil = Image.alpha_composite(merged_pil, frame_pil)
|
446 |
+
output_gradio.append(frame_pil)
|
447 |
+
|
448 |
+
return output_gradio
|
449 |
+
|
450 |
+
def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
|
451 |
+
generator = torch.Generator().manual_seed(seed)
|
452 |
+
try:
|
453 |
+
validation_box = ast.literal_eval(validation_box_str)
|
454 |
+
except Exception as e:
|
455 |
+
return [f"Error parsing validation_box: {e}"]
|
456 |
+
if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
|
457 |
+
return ["validation_box must be a list of tuples, each of length 4."]
|
458 |
+
|
459 |
+
validation_box = adjust_validation_box(validation_box)
|
460 |
+
|
461 |
+
result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
|
462 |
+
|
463 |
+
svg_img = pngs_to_svg(result_images[1:])
|
464 |
+
|
465 |
+
svg_file_path = './image.svg'
|
466 |
+
os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
|
467 |
+
with open(svg_file_path, 'w', encoding='utf-8') as f:
|
468 |
+
f.write(svg_img)
|
469 |
+
|
470 |
+
return result_images, svg_file_path
|
471 |
+
|
472 |
+
def main():
|
473 |
+
model, quantizer, tokenizer, width, height, device = construction_layout()
|
474 |
+
|
475 |
+
inference_partial = partial(
|
476 |
+
inference,
|
477 |
+
model=model,
|
478 |
+
quantizer=quantizer,
|
479 |
+
tokenizer=tokenizer,
|
480 |
+
width=width,
|
481 |
+
height=height,
|
482 |
+
device=device
|
483 |
+
)
|
484 |
+
|
485 |
+
def process_preddate(intention, temperature, top_p, generate_method='v1'):
|
486 |
+
intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
|
487 |
+
intention = ensure_space_after_period(intention)
|
488 |
+
if temperature == 0.0:
|
489 |
+
# print("looking for greedy decoding strategies, set `do_sample=False`.")
|
490 |
+
preddata = inference_partial(generate_method, intention, do_sample=False)
|
491 |
+
else:
|
492 |
+
preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
|
493 |
+
# wholecaption = preddata["wholecaption"]
|
494 |
+
layouts = preddata["layout"]
|
495 |
+
list_box = []
|
496 |
+
for i, layout in enumerate(layouts):
|
497 |
+
x, y = layout["x"], layout["y"]
|
498 |
+
width, height = layout["width"], layout["height"]
|
499 |
+
if i == 0:
|
500 |
+
list_box.append((0, 0, width, height))
|
501 |
+
list_box.append((0, 0, width, height))
|
502 |
+
else:
|
503 |
+
left = x - width // 2
|
504 |
+
top = y - height // 2
|
505 |
+
right = x + width // 2
|
506 |
+
bottom = y + height // 2
|
507 |
+
list_box.append((left, top, right, bottom))
|
508 |
+
|
509 |
+
# print(list_box)
|
510 |
+
filtered_boxes = list_box[:2]
|
511 |
+
for i in range(2, len(list_box)):
|
512 |
+
keep = True
|
513 |
+
for j in range(1, len(filtered_boxes)):
|
514 |
+
iou = calculate_iou(list_box[i], filtered_boxes[j])
|
515 |
+
if iou > 0.65:
|
516 |
+
print(list_box[i], filtered_boxes[j])
|
517 |
+
keep = False
|
518 |
+
break
|
519 |
+
if keep:
|
520 |
+
filtered_boxes.append(list_box[i])
|
521 |
+
|
522 |
+
return str(filtered_boxes), intention, str(filtered_boxes)
|
523 |
+
|
524 |
+
# def process_preddate(intention, generate_method='v1'):
|
525 |
+
# list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
|
526 |
+
# wholecaption = "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
|
527 |
+
# json_file = "/home/wyb/openseg_blob/v-yanbin/GradioDemo/LLM-For-Layout-Planning/inference_test.json"
|
528 |
+
# return wholecaption, str(list_box), json_file
|
529 |
+
|
530 |
+
pipeline, transp_vae = construction()
|
531 |
+
|
532 |
+
gradio_test_one_sample_partial = partial(
|
533 |
+
svg_test_one_sample,
|
534 |
+
pipeline=pipeline,
|
535 |
+
transp_vae=transp_vae,
|
536 |
+
)
|
537 |
+
|
538 |
+
def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
|
539 |
+
result_images = []
|
540 |
+
result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
|
541 |
+
|
542 |
+
url, unique_filename = upload_to_github(file_path=svg_file_path)
|
543 |
+
unique_filename = f'{unique_filename}'
|
544 |
+
|
545 |
+
if url != None:
|
546 |
+
print(f"File uploaded to: {url}")
|
547 |
+
svg_editor = f"""
|
548 |
+
<iframe src="https://svgedit.netlify.app/editor/index.html?\
|
549 |
+
storagePrompt=false&url={url}" \
|
550 |
+
width="100%", height="800px"></iframe>
|
551 |
+
"""
|
552 |
+
else:
|
553 |
+
print('upload_to_github FAILED!')
|
554 |
+
svg_editor = f"""
|
555 |
+
<iframe src="https://svgedit.netlify.app/editor/index.html" \
|
556 |
+
width="100%", height="800px"></iframe>
|
557 |
+
"""
|
558 |
+
|
559 |
+
return result_images, svg_file_path, svg_editor
|
560 |
+
|
561 |
+
def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
|
562 |
+
# 首先调用process_preddate
|
563 |
+
list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
|
564 |
+
|
565 |
+
# 然后将process_preddate的输出作为process_svg的输入
|
566 |
+
result_images, svg_file, svg_editor = process_svg(intention_input, list_box_output, seed, true_gs, inference_steps)
|
567 |
+
|
568 |
+
# 返回两个函数的输出
|
569 |
+
return list_box_output, result_images, svg_file, svg_editor, intention_input, list_box_output
|
570 |
+
|
571 |
+
def clear_inputs1():
|
572 |
+
return "", ""
|
573 |
+
|
574 |
+
def clear_inputs2():
|
575 |
+
return "", ""
|
576 |
+
|
577 |
+
def transfer_inputs(intention, list_box):
|
578 |
+
return intention, list_box
|
579 |
+
|
580 |
+
theme = gr.themes.Soft(
|
581 |
+
radius_size="lg",
|
582 |
+
).set(
|
583 |
+
block_background_fill='*primary_50',
|
584 |
+
block_border_color='*primary_200',
|
585 |
+
block_border_width='1px',
|
586 |
+
block_border_width_dark='100px',
|
587 |
+
block_info_text_color='*primary_950',
|
588 |
+
block_label_border_color='*primary_200',
|
589 |
+
block_radius='*radius_lg'
|
590 |
+
)
|
591 |
+
|
592 |
+
with gr.Blocks(theme=theme) as demo:
|
593 |
+
gr.HTML("<h1 style='text-align: center;'>ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation</h1>")
|
594 |
+
gr.HTML("<h2>Anonymous Region Layout Planner</h2>")
|
595 |
+
|
596 |
+
with gr.Row():
|
597 |
+
with gr.Column():
|
598 |
+
intention_input = gr.Textbox(lines=15, placeholder="Enter intention", label="Prompt")
|
599 |
+
with gr.Row():
|
600 |
+
temperature_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=0.0)
|
601 |
+
top_p_input=gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Top P", value=0.0)
|
602 |
+
with gr.Row():
|
603 |
+
clear_btn1 = gr.Button("Clear")
|
604 |
+
model_btn1 = gr.Button("Commit", variant='primary')
|
605 |
+
transfer_btn1 = gr.Button("Export to below")
|
606 |
+
|
607 |
+
one_click_btn = gr.Button("One Click Generate ALL", variant='primary')
|
608 |
+
|
609 |
+
with gr.Column():
|
610 |
+
list_box_output = gr.Textbox(lines=10, placeholder="Validation Box", label="Validation Box")
|
611 |
+
|
612 |
+
examples = gr.Examples(
|
613 |
+
examples=[
|
614 |
+
['The image is a graphic design with a celebratory theme. At the top, there is a banner with the text \"Happy Anniversary\" in a bold, sans-serif font. Below this banner, there is a circular frame containing a photograph of a couple. The man has short, dark hair and is wearing a light-colored sweater, while the woman has long blonde hair and is also wearing a light-colored sweater. They are both smiling and appear to be embracing each other.Surrounding the circular frame are decorative elements such as pink flowers and green leaves, which add a festive touch to the design. Below the circular frame, there is a text that reads "Isabel & Morgan" in a cursive, elegant font, suggesting that the couple\'s names are Isabel and Morgan.At the bottom of the image, there is a banner with a message that says "Happy Anniversary! Cheers to another year of love, laughter, and cherished memories together.\" This text is in a smaller, sans-serif font and is placed against a solid background, providing a clear message of celebration and well-wishes for the couple.The overall style of the image is warm and celebratory, with a color scheme that includes shades of pink, green, and white, which contribute to a joyful and romantic atmosphere.'],
|
615 |
+
['The image is a digital illustration with a light blue background. At the top, there is a logo consisting of a snake wrapped around a staff, which is a common symbol in healthcare. Below the logo, the text "International Nurses Day" is prominently displayed in white, with the date "12 May 20xx" in smaller font size.The central part of the image features two stylized characters. On the left, there is a female character with dark hair, wearing a white nurse\'s uniform with a cap. She is holding a clipboard and appears to be speaking or gesturing, as indicated by a speech bubble with the word "OK" in it. On the right, there is a male character with light brown hair, wearing a light blue shirt with a white collar and a white apron. He is holding a stethoscope to his ear, suggesting he is a doctor or a healthcare professional.The characters are depicted in a friendly and approachable manner, with smiles on their faces. Around them, there are small blue plus signs, which are often associated with healthcare and medical services. The overall style of the image is clean, modern, and appears to be designed to celebrate International Nurses Day.'],
|
616 |
+
['The image features a graphic design with a festive theme. At the top, there is a decorative border with a wavy pattern. Below this border, the text "WINTER SEASON SPECIAL COOKIES" is prominently displayed in a bold, sans-serif font. The text is black with a slight shadow effect, giving it a three-dimensional appearance.In the center of the image, there are three illustrated gingerbread cookies. Each cookie has a smiling face with eyes, a nose, and a mouth, and they are colored in a warm, brown hue. The cookies are arranged in a staggered formation, with the middle cookie slightly higher than the others, creating a sense of depth.At the bottom of the image, there is a call to action that reads "ORDER.NOW" in a large, bold, sans-serif font. The text is colored in a darker shade of brown, contrasting with the lighter background. The overall style of the image suggests it is an advertisement or promotional graphic for a winter-themed cookie special.']
|
617 |
+
],
|
618 |
+
inputs=[intention_input]
|
619 |
+
)
|
620 |
+
|
621 |
+
gr.HTML("<h2>Anonymous Region Transformer</h2>")
|
622 |
+
with gr.Row():
|
623 |
+
with gr.Column():
|
624 |
+
text_input = gr.Textbox(lines=10, placeholder="Enter prompt text", label="Prompt")
|
625 |
+
tuple_input = gr.Textbox(lines=5, placeholder="Enter list of tuples, e.g., [(1, 2, 3, 4), (5, 6, 7, 8)]", label="Validation Box")
|
626 |
+
with gr.Row():
|
627 |
+
true_gs_input=gr.Slider(minimum=3.0, maximum=5.0, step=0.1, label="true_gs", value=3.5)
|
628 |
+
inference_steps_input=gr.Slider(minimum=5, maximum=50, step=1, label="inference_steps", value=28)
|
629 |
+
with gr.Row():
|
630 |
+
seed_input = gr.Number(label="Seed", value=42)
|
631 |
+
with gr.Row():
|
632 |
+
transfer_btn2 = gr.Button("Import from above")
|
633 |
+
with gr.Row():
|
634 |
+
clear_btn2 = gr.Button("Clear")
|
635 |
+
model_btn2 = gr.Button("Commit", variant='primary')
|
636 |
+
|
637 |
+
with gr.Column():
|
638 |
+
result_images = gr.Gallery(label="Result Images", columns=5, height='auto')
|
639 |
+
|
640 |
+
gr.HTML("<h1>SVG Image</h1>")
|
641 |
+
svg_file = gr.File(label="Download SVG Image")
|
642 |
+
svg_editor = gr.HTML(label="Editable SVG Editor")
|
643 |
+
|
644 |
+
model_btn1.click(
|
645 |
+
fn=process_preddate,
|
646 |
+
inputs=[intention_input, temperature_input, top_p_input],
|
647 |
+
outputs=[list_box_output, text_input, tuple_input],
|
648 |
+
api_name="process_preddate"
|
649 |
+
)
|
650 |
+
clear_btn1.click(
|
651 |
+
fn=clear_inputs1,
|
652 |
+
inputs=[],
|
653 |
+
outputs=[intention_input, list_box_output]
|
654 |
+
)
|
655 |
+
model_btn2.click(
|
656 |
+
fn=process_svg,
|
657 |
+
inputs=[text_input, tuple_input, seed_input, true_gs_input, inference_steps_input],
|
658 |
+
outputs=[result_images, svg_file, svg_editor],
|
659 |
+
api_name="process_svg"
|
660 |
+
)
|
661 |
+
clear_btn2.click(
|
662 |
+
fn=clear_inputs2,
|
663 |
+
inputs=[],
|
664 |
+
outputs=[text_input, tuple_input]
|
665 |
+
)
|
666 |
+
transfer_btn1.click(
|
667 |
+
fn=transfer_inputs,
|
668 |
+
inputs=[intention_input, list_box_output],
|
669 |
+
outputs=[text_input, tuple_input]
|
670 |
+
)
|
671 |
+
transfer_btn2.click(
|
672 |
+
fn=transfer_inputs,
|
673 |
+
inputs=[intention_input, list_box_output],
|
674 |
+
outputs=[text_input, tuple_input]
|
675 |
+
)
|
676 |
+
one_click_btn.click(
|
677 |
+
fn=one_click_generate,
|
678 |
+
inputs=[intention_input, temperature_input, top_p_input, seed_input, true_gs_input, inference_steps_input],
|
679 |
+
outputs=[list_box_output, result_images, svg_file, svg_editor, text_input, tuple_input]
|
680 |
+
)
|
681 |
+
demo.launch()
|
682 |
+
|
683 |
+
if __name__ == "__main__":
|
684 |
+
main()
|
config/base.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Model Settings
|
2 |
+
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
3 |
+
revision = None
|
4 |
+
variant = None
|
5 |
+
cache_dir = None
|
6 |
+
|
7 |
+
### Training Settings
|
8 |
+
seed = 42
|
9 |
+
report_to = "wandb"
|
10 |
+
tracker_project_name = "multilayer"
|
11 |
+
wandb_job_name = "YOU_FORGET_TO_SET"
|
12 |
+
logging_dir = "logs"
|
13 |
+
max_train_steps = None
|
14 |
+
checkpoints_total_limit = None
|
15 |
+
|
16 |
+
# gpu
|
17 |
+
allow_tf32 = True
|
18 |
+
gradient_checkpointing = True
|
19 |
+
mixed_precision = "bf16"
|
20 |
+
|
21 |
+
### Validation Settings
|
22 |
+
num_validation_images = 1
|
23 |
+
validation_steps = 5
|
24 |
+
validation_prompts = [
|
25 |
+
"The image features a simple, flat design with a solid pink background. On the left side, there is a stylized depiction of a decorated egg with a pattern of alternating white and light blue stripes. The egg has a smooth, oval shape and is outlined with a thin line. In the center of the image, there is a floral arrangement consisting of a large, white flower with a green center and several smaller white flowers with green centers. The flowers are connected by thin green stems and leaves, creating a small bouquet. On the right side of the image, there is another egg similar to the one on the left, with the same pattern of stripes. This egg is also outlined with a thin line and has a smooth, oval shape. The overall style of the image is clean and modern, with a limited color palette and a focus on geometric shapes and simple patterns. There are no texts or additional elements in the image.",
|
26 |
+
"The image features a cartoon-style illustration with three characters against a blue background. On the left side, there is a green, goblin-like creature with large, expressive eyes and a wide grin. It has a small body and is standing upright with its arms raised in a welcoming or excited gesture. In the center, there is a large, white, egg-shaped object that appears to be floating or resting on the surface. It has a smooth, rounded shape and is the largest object in the image. On the right side, there is a purple dinosaur with a friendly expression. It has a small head, large eyes, and a wide mouth that seems to be smiling. The dinosaur is standing on all fours and appears to be looking towards the viewer. The overall style of the image is playful and whimsical, with a clear emphasis on the characters rather than any specific background details.",
|
27 |
+
"The image features a collection of Christmas-themed objects against a solid green background. On the left side, there is a red Christmas ornament with a white pattern, resembling a traditional Christmas ball. Next to it, there is a red and white striped stocking with a small white cuff at the top. On the right side, there is a cartoon-style depiction of Santa Claus' face, with a white beard, red cheeks, and a smiling expression. The Santa face is stylized with simple lines and shapes, giving it a friendly and festive appearance. The overall style of the image is flat and graphic, with a clear focus on holiday-related items.",
|
28 |
+
"The image depicts a stylized illustration of a rocket launch. The rocket, which is the central focus of the image, is depicted in a simplified, cartoon-like style with a white body and a pointed nose cone. It is shown ascending into a dark background, which is likely meant to represent the night sky. Above the rocket, there are several small, golden stars scattered across the sky, adding a sense of motion and direction to the rocket's ascent. The stars are of varying sizes and are positioned at different heights, creating a sense of depth and distance. The overall style of the image is minimalist and modern, with a limited color palette that emphasizes the rocket and the stars against the dark background. The image does not contain any text or additional elements that would provide context or narrative beyond the depiction of the rocket launch.",
|
29 |
+
"The image features a stylized, cartoon-like depiction of a bear. The bear is predominantly pink with a lighter pink nose and a small black dot for an eye. It has two small ears and a small black line for a mouth. The bear is standing upright and appears to be holding a yellow object, possibly a piece of paper or a card, in its right paw. To the right of the bear, there is a purple background with a large, heart-shaped doodle. The overall style of the image is simplistic and child-friendly, with a limited color palette and a clear, uncluttered composition.",
|
30 |
+
"The image features three ice cream cones against a pink background. Each cone is filled with a different flavor of ice cream: the leftmost cone has chocolate ice cream, the middle cone has vanilla ice cream, and the rightmost cone has strawberry ice cream. The ice cream is topped with a drizzle of the respective flavor's syrup, and each cone is adorned with a small, round, chocolate-covered piece of candy. The image also contains text that reads 'Sprinkle Sunday Ice Cream Factory East Avenue, CA 13154' and a phone number '+799-2324-9890'. Additionally, there is a website address 'www.sprinklesunday.com'. The style of the image is illustrative and appears to be designed for advertising or promotional purposes.",
|
31 |
+
]
|
config/v04sv03_lora_r64_upto50layers_bs1_lr1_prodigy_800k_wds_512_filtered_10ep_none_8gpu.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = "./base.py"
|
2 |
+
|
3 |
+
### path & device settings
|
4 |
+
img_tar_path = "/openseg_blob/puyifan/shared_data/CANVA_802000_resolution512max21760tokens/"
|
5 |
+
output_path_base = "/openseg_blob/zhaoym/multi_layer_sd3/work_dirs/"
|
6 |
+
cache_dir = "/openseg_blob/zhaoym/pretrained/flux"
|
7 |
+
# transformer_varient = "ashen0209/Flux-Dev2Pro"
|
8 |
+
pretrained_lora_dir = "/openseg_blob/zhaoym/sd3/work_dirs/canva500k_mix100k_sft_flux"
|
9 |
+
total_gpu_num = 8
|
10 |
+
|
11 |
+
### wandb settings
|
12 |
+
wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}'
|
13 |
+
|
14 |
+
### Dataset Settings
|
15 |
+
resolution = 512
|
16 |
+
dataloader_pin_memory = True
|
17 |
+
dataloader_num_workers = 16
|
18 |
+
train_batch_size = 1
|
19 |
+
dataset_cfg = dict(
|
20 |
+
img_tar_path=img_tar_path,
|
21 |
+
num_train_examples=802000,
|
22 |
+
per_gpu_batch_size=train_batch_size,
|
23 |
+
global_batch_size=(train_batch_size * total_gpu_num),
|
24 |
+
num_workers=dataloader_num_workers,
|
25 |
+
resolution=resolution,
|
26 |
+
center_crop=True,
|
27 |
+
random_flip=False,
|
28 |
+
shuffle_buffer_size=1000,
|
29 |
+
pin_memory=dataloader_pin_memory,
|
30 |
+
persistent_workers=True,
|
31 |
+
)
|
32 |
+
|
33 |
+
### Model Settings
|
34 |
+
rank = 64
|
35 |
+
text_encoder_rank = 64
|
36 |
+
train_text_encoder = False
|
37 |
+
max_layer_num = 50 + 2
|
38 |
+
learnable_proj = True
|
39 |
+
|
40 |
+
### Training Settings
|
41 |
+
weighting_scheme = "none"
|
42 |
+
logit_mean = 0.0
|
43 |
+
logit_std = 1.0
|
44 |
+
mode_scale = 1.29
|
45 |
+
guidance_scale = 1.0 ###IMPORTANT
|
46 |
+
layer_weighting = 5.0
|
47 |
+
|
48 |
+
# steps
|
49 |
+
# train_batch_size = 1
|
50 |
+
num_train_epochs = 1
|
51 |
+
max_train_steps = None
|
52 |
+
checkpointing_steps = 2000
|
53 |
+
resume_from_checkpoint = "latest"
|
54 |
+
gradient_accumulation_steps = 1
|
55 |
+
|
56 |
+
# lr
|
57 |
+
optimizer = "prodigy"
|
58 |
+
learning_rate = 1.0
|
59 |
+
scale_lr = False
|
60 |
+
lr_scheduler = "constant"
|
61 |
+
lr_warmup_steps = 0
|
62 |
+
lr_num_cycles = 1
|
63 |
+
lr_power = 1.0
|
64 |
+
|
65 |
+
# optim
|
66 |
+
adam_beta1 = 0.9
|
67 |
+
adam_beta2 = 0.999
|
68 |
+
adam_weight_decay = 1e-3
|
69 |
+
adam_epsilon = 1e-8
|
70 |
+
prodigy_beta3 = None
|
71 |
+
prodigy_decouple = True
|
72 |
+
prodigy_use_bias_correction = True
|
73 |
+
prodigy_safeguard_warmup = True
|
74 |
+
max_grad_norm = 1.0
|
75 |
+
|
76 |
+
# logging
|
77 |
+
tracker_task_name = '{{fileBasenameNoExtension}}'
|
78 |
+
output_dir = output_path_base + "{{fileBasenameNoExtension}}"
|
79 |
+
|
80 |
+
### Validation Settings
|
81 |
+
num_validation_images = 1
|
82 |
+
validation_steps = 2000
|
83 |
+
validation_prompts = [
|
84 |
+
'The image features a background with a soft, pastel color gradient that transitions from pink to purple. There are abstract floral elements scattered throughout the background, with some appearing to be in full bloom and others in a more delicate, bud-like state. The flowers have a watercolor effect, with soft edges that blend into the background.\n\nCentered in the image is a quote in a serif font that reads, "You\'re free to be different." The text is black, which stands out against the lighter background. The overall style of the image is artistic and inspirational, with a motivational message that encourages individuality and self-expression. The image could be used for motivational purposes, as a background for a blog or social media post, or as part of a personal development or self-help theme.',
|
85 |
+
'The image features a logo for a company named "Bull Head Party Adventure." The logo is stylized with a cartoon-like depiction of a bull\'s head, which is the central element of the design. The bull has prominent horns and a fierce expression, with its mouth slightly open as if it\'s snarling or roaring. The color scheme of the bull is a mix of brown and beige tones, with the horns highlighted in a lighter shade.\n\nBelow the bull\'s head, the company name is written in a bold, sans-serif font. The text is arranged in two lines, with "Bull Head" on the top line and "Party Adventure" on the bottom line. The font color matches the color of the bull, creating a cohesive look. The overall style of the image is playful and energetic, suggesting that the company may offer exciting or adventurous party experiences.',
|
86 |
+
'The image features a festive and colorful illustration with a theme related to the Islamic holiday of Eid al-Fitr. At the center of the image is a large, ornate crescent moon with intricate patterns and decorations. Surrounding the moon are several smaller stars and crescents, also adorned with decorative elements. These smaller celestial motifs are suspended from the moon, creating a sense of depth and dimension.\n\nBelow the central moon, there is a banner with the text "Eid Mubarak" in a stylized, elegant font. The text is in a bold, dark color that stands out against the lighter background. The background itself is a gradient of light to dark green, which complements the golden and white hues of the celestial motifs.\n\nThe overall style of the image is celebratory and decorative, with a focus on the traditional symbols associated with Eid al-Fitr. The use of gold and white gives the image a luxurious and festive feel, while the green background is a color often associated with Islam. The image appears to be a digital artwork or graphic design, possibly intended for use as a greeting card or a festive decoration.',
|
87 |
+
'The image is a festive graphic with a dark background. At the center, there is a large, bold text that reads "Happy New Year 2023" in a combination of white and gold colors. The text is surrounded by numerous white balloons with gold ribbons, giving the impression of a celebratory atmosphere. The balloons are scattered around the text, creating a sense of depth and movement. Additionally, there are small gold sparkles and confetti-like elements that add to the celebratory theme. The overall design suggests a New Year\'s celebration, with the year 2023 being the focal point.',
|
88 |
+
'The image is a stylized illustration with a flat design aesthetic. It depicts a scene related to healthcare or medical care. In the center, there is a hospital bed with a patient lying down, appearing to be resting or possibly receiving treatment. The patient is surrounded by three individuals who seem to be healthcare professionals or caregivers. They are standing around the bed, with one on each side and one at the foot of the bed. The person at the foot of the bed is holding a clipboard, suggesting they might be taking notes or reviewing medical records.\n\nThe room has a window with curtains partially drawn, allowing some light to enter. The color palette is soft, with pastel tones dominating the scene. The text "INTERNATIONAL CANCER DAY" is prominently displayed at the top of the image, indicating that the illustration is related to this event. The overall impression is one of care and support, with a focus on the patient\'s well-being.',
|
89 |
+
'The image features a stylized illustration of a man with a beard and a tank top, drinking from a can. The man is depicted in a simplified, cartoon-like style with a limited color palette. Above him, there is a text that reads "Happy Eating, Friends" in a bold, friendly font. Below the illustration, there is another line of text that states "Food is a Necessity That is Not Prioritized," which is also in a bold, sans-serif font. The background of the image is a gradient of light to dark blue, giving the impression of a sky or a calm, serene environment. The overall style of the image is casual and approachable, with a focus on the message conveyed by the text.',
|
90 |
+
'The image is a digital illustration with a pastel pink background. At the top, there is a text that reads "Sending you my Easter wishes" in a simple, sans-serif font. Below this, a larger text states "May Your Heart be Happy!" in a more decorative, serif font. Underneath this main message, there is a smaller text that says "Let the miracle of the season fill you with hope and love."\n\nThe illustration features three stylized flowers with smiling faces. On the left, there is a purple flower with a yellow center. In the center, there is a blue flower with a green center. On the right, there is a pink flower with a yellow center. Each flower has a pair of eyes and a mouth, giving them a friendly appearance. The flowers are drawn with a cartoon-like style, using solid colors and simple shapes.\n\nThe overall style of the image is cheerful and whimsical, with a clear Easter theme suggested by the text and the presence of flowers, which are often associated with spring and new beginnings.',
|
91 |
+
'The image is a vibrant and colorful graphic with a pink background. In the center, there is a photograph of a man and a woman embracing each other. The man is wearing a white shirt, and the woman is wearing a patterned top. They are both smiling and appear to be in a joyful mood.\n\nSurrounding the photograph are various elements that suggest a festive or celebratory theme. There are three hot air balloons in the background, each with a different design: one with a heart, one with a gift box, and one with a basket. These balloons are floating against a clear sky.\n\nAdditionally, there are two gift boxes with ribbons, one on the left and one on the right side of the image. These gift boxes are stylized with a glossy finish and are placed at different heights, creating a sense of depth.\n\nAt the bottom of the image, there is a large red heart, which is a common symbol associated with love and Valentine\'s Day.\n\nFinally, at the very bottom of the image, there is a text that reads "Happy Valentine\'s Day," which confirms the theme of the image as a Valentine\'s Day greeting. The text is in a playful, cursive font that matches the overall cheerful and romantic tone of the image.',
|
92 |
+
'The image depicts a stylized illustration of two women sitting on stools, engaged in conversation. They are wearing traditional attire, with headscarves and patterned dresses. The woman on the left is wearing a brown dress with a purple pattern, while the woman on the right is wearing a purple dress with a brown pattern. Between them is a purple flower. Above the women, the text "INTERNATIONAL WOMEN\'S DAY" is written in bold, uppercase letters. The background is a soft, pastel pink, and there are abstract, swirling lines in a darker shade of pink above the women. The overall style of the image is simplistic and cartoonish, with a warm and friendly tone.',
|
93 |
+
'The image is a digital graphic with a clean, minimalist design. It features a light blue background with a subtle floral pattern at the bottom. On the left side, there is a large, bold text that reads "Our Global Idea." The text is in a serif font and is colored in a darker shade of blue, creating a contrast against the lighter background.\n\nOn the right side, there is a smaller text in a sans-serif font that provides information about utilizing the Live Q&A feature of Canva. The text suggests using this feature to engage an audience more effectively, such as asking about their opinions on certain topics and themes. The text is in a lighter shade of blue, which matches the background, and it is enclosed within a decorative border that includes a floral motif, mirroring the design at the bottom of the image.\n\nThe overall style of the image is professional and modern, with a focus on typography and a simple color scheme. The design elements are well-balanced, with the text and decorative elements complementing each other without overwhelming the viewer.',
|
94 |
+
'The image is a stylized illustration with a warm, peach-colored background. At the center, there is a vintage-style radio with a prominent dial and antenna. The radio is emitting a blue, star-like burst of light or energy from its top. Surrounding the radio are various objects and elements that seem to be floating or suspended in the air. These include a brown, cone-shaped object, a blue, star-like shape, and a brown, wavy, abstract shape that could be interpreted as a flower or a wave.\n\nAt the top of the image, there is text that reads "World Radio Day" in a bold, serif font. Below this, in a smaller, sans-serif font, is the date "13 February 2022." The overall style of the image is playful and cartoonish, with a clear focus on celebrating World Radio Day.',
|
95 |
+
'The image is a graphic design of a baby shower invitation. The central focus is a cute, cartoon-style teddy bear with a friendly expression, sitting upright. The bear is colored in a soft, light brown hue. Above the bear, there is a bold text that reads "YOU\'RE INVITED" in a playful, sans-serif font. Below this, the words "BABY SHOWER" are prominently displayed in a larger, more decorative font, suggesting the theme of the event.\n\nThe background of the invitation is a soft, light pink color, which adds to the gentle and welcoming atmosphere of the design. At the bottom of the image, there is additional text providing specific details about the event. It reads "27 January, 2022 - 8:00 PM" followed by "FAUGET INDUSTRIES CAFE," indicating the date, time, and location of the baby shower.\n\nThe overall style of the image is warm, inviting, and child-friendly, with a clear focus on the theme of a baby shower celebration. The use of a teddy bear as the central image reinforces the baby-related theme. The design is simple yet effective, with a clear hierarchy of information that guides the viewer\'s attention from the top to the bottom of the invitation.',
|
96 |
+
]
|
97 |
+
|
98 |
+
validation_boxes = [
|
99 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (368, 0, 512, 272), (0, 272, 112, 512), (160, 208, 352, 304)],
|
100 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (128, 128, 384, 304), (96, 288, 416, 336), (128, 336, 384, 368)],
|
101 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (112, 48, 400, 368), (0, 48, 96, 176), (128, 336, 384, 384), (240, 384, 384, 432)],
|
102 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (32, 32, 480, 480), (80, 176, 432, 368), (64, 176, 448, 224), (144, 96, 368, 224)],
|
103 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (0, 64, 176, 272), (0, 400, 512, 512), (16, 160, 496, 512), (224, 48, 464, 112), (208, 96, 464, 160)],
|
104 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (112, 224, 512, 512), (0, 0, 240, 160), (144, 144, 512, 512), (48, 64, 432, 208), (48, 400, 256, 448)],
|
105 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (160, 48, 352, 80), (64, 80, 448, 192), (128, 208, 384, 240), (320, 240, 512, 512), (80, 272, 368, 512), (0, 224, 192, 512)],
|
106 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (48, 0, 464, 304), (128, 144, 384, 400), (288, 288, 384, 368), (336, 304, 400, 368), (176, 432, 336, 480), (224, 400, 288, 432)],
|
107 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (32, 288, 448, 512), (144, 176, 336, 400), (224, 208, 272, 256), (160, 128, 336, 192), (192, 368, 304, 400), (368, 80, 448, 224), (48, 160, 128, 256)],
|
108 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (0, 112, 112, 240), (400, 272, 512, 416), (400, 112, 512, 240), (0, 272, 112, 400), (64, 192, 176, 320), (224, 192, 432, 320), (224, 304, 448, 368)],
|
109 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (0, 352, 512, 512), (112, 176, 368, 432), (48, 176, 128, 256), (48, 368, 128, 448), (384, 192, 480, 272), (384, 336, 432, 384), (80, 80, 432, 128), (176, 128, 336, 160)],
|
110 |
+
[(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)],
|
111 |
+
]
|
custom_model_mmdit.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
4 |
+
|
5 |
+
from accelerate.utils import set_module_tensor_to_device
|
6 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
7 |
+
from diffusers.models.normalization import AdaLayerNormContinuous
|
8 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
9 |
+
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock
|
10 |
+
|
11 |
+
from diffusers.configuration_utils import register_to_config
|
12 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
13 |
+
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
16 |
+
|
17 |
+
|
18 |
+
class CustomFluxTransformer2DModel(FluxTransformer2DModel):
|
19 |
+
"""
|
20 |
+
The Transformer model introduced in Flux.
|
21 |
+
|
22 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
26 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
27 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
28 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
29 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
30 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
31 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
32 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
33 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
34 |
+
"""
|
35 |
+
|
36 |
+
@register_to_config
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
patch_size: int = 1,
|
40 |
+
in_channels: int = 64,
|
41 |
+
num_layers: int = 19,
|
42 |
+
num_single_layers: int = 38,
|
43 |
+
attention_head_dim: int = 128,
|
44 |
+
num_attention_heads: int = 24,
|
45 |
+
joint_attention_dim: int = 4096,
|
46 |
+
pooled_projection_dim: int = 768,
|
47 |
+
guidance_embeds: bool = False,
|
48 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
49 |
+
max_layer_num: int = 10,
|
50 |
+
):
|
51 |
+
super(FluxTransformer2DModel, self).__init__()
|
52 |
+
self.out_channels = in_channels
|
53 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
54 |
+
|
55 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
56 |
+
|
57 |
+
text_time_guidance_cls = (
|
58 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
59 |
+
)
|
60 |
+
self.time_text_embed = text_time_guidance_cls(
|
61 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
62 |
+
)
|
63 |
+
|
64 |
+
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
65 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
66 |
+
|
67 |
+
self.transformer_blocks = nn.ModuleList(
|
68 |
+
[
|
69 |
+
FluxTransformerBlock(
|
70 |
+
dim=self.inner_dim,
|
71 |
+
num_attention_heads=self.config.num_attention_heads,
|
72 |
+
attention_head_dim=self.config.attention_head_dim,
|
73 |
+
)
|
74 |
+
for i in range(self.config.num_layers)
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
self.single_transformer_blocks = nn.ModuleList(
|
79 |
+
[
|
80 |
+
FluxSingleTransformerBlock(
|
81 |
+
dim=self.inner_dim,
|
82 |
+
num_attention_heads=self.config.num_attention_heads,
|
83 |
+
attention_head_dim=self.config.attention_head_dim,
|
84 |
+
)
|
85 |
+
for i in range(self.config.num_single_layers)
|
86 |
+
]
|
87 |
+
)
|
88 |
+
|
89 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
90 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
91 |
+
|
92 |
+
self.gradient_checkpointing = False
|
93 |
+
|
94 |
+
self.max_layer_num = max_layer_num
|
95 |
+
|
96 |
+
# the following process ensures self.layer_pe is not created as a meta tensor
|
97 |
+
self.layer_pe = nn.Parameter(torch.empty(1, self.max_layer_num, 1, 1, self.inner_dim))
|
98 |
+
nn.init.trunc_normal_(self.layer_pe, mean=0.0, std=0.02, a=-2.0, b=2.0)
|
99 |
+
# layer_pe_value = nn.init.trunc_normal_(
|
100 |
+
# nn.Parameter(torch.zeros(
|
101 |
+
# 1, self.max_layer_num, 1, 1, self.inner_dim,
|
102 |
+
# )),
|
103 |
+
# mean=0.0, std=0.02, a=-2.0, b=2.0,
|
104 |
+
# ).data.detach()
|
105 |
+
# self.layer_pe = nn.Parameter(layer_pe_value)
|
106 |
+
# set_module_tensor_to_device(
|
107 |
+
# self,
|
108 |
+
# 'layer_pe',
|
109 |
+
# device='cpu',
|
110 |
+
# value=layer_pe_value,
|
111 |
+
# dtype=layer_pe_value.dtype,
|
112 |
+
# )
|
113 |
+
|
114 |
+
@classmethod
|
115 |
+
def from_pretrained(cls, *args, **kwarg):
|
116 |
+
model = super().from_pretrained(*args, **kwarg)
|
117 |
+
for name, para in model.named_parameters():
|
118 |
+
if name != 'layer_pe':
|
119 |
+
device = para.device
|
120 |
+
break
|
121 |
+
model.layer_pe.to(device)
|
122 |
+
return model
|
123 |
+
|
124 |
+
def crop_each_layer(self, hidden_states, list_layer_box):
|
125 |
+
"""
|
126 |
+
hidden_states: [1, n_layers, h, w, inner_dim]
|
127 |
+
list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
|
128 |
+
"""
|
129 |
+
token_list = []
|
130 |
+
for layer_idx in range(hidden_states.shape[1]):
|
131 |
+
if list_layer_box[layer_idx] == None:
|
132 |
+
continue
|
133 |
+
else:
|
134 |
+
x1, y1, x2, y2 = list_layer_box[layer_idx]
|
135 |
+
x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
|
136 |
+
layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :]
|
137 |
+
bs, h, w, c = layer_token.shape
|
138 |
+
layer_token = layer_token.reshape(bs, -1, c)
|
139 |
+
token_list.append(layer_token)
|
140 |
+
result = torch.cat(token_list, dim=1)
|
141 |
+
return result
|
142 |
+
|
143 |
+
def fill_in_processed_tokens(self, hidden_states, full_hidden_states, list_layer_box):
|
144 |
+
"""
|
145 |
+
hidden_states: [1, h1xw1 + h2xw2 + ... + hlxwl , inner_dim]
|
146 |
+
full_hidden_states: [1, n_layers, h, w, inner_dim]
|
147 |
+
list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
|
148 |
+
"""
|
149 |
+
used_token_len = 0
|
150 |
+
bs = hidden_states.shape[0]
|
151 |
+
for layer_idx in range(full_hidden_states.shape[1]):
|
152 |
+
if list_layer_box[layer_idx] == None:
|
153 |
+
continue
|
154 |
+
else:
|
155 |
+
x1, y1, x2, y2 = list_layer_box[layer_idx]
|
156 |
+
x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
|
157 |
+
full_hidden_states[:, layer_idx, y1:y2, x1:x2, :] = hidden_states[:, used_token_len: used_token_len + (y2-y1) * (x2-x1), :].reshape(bs, y2-y1, x2-x1, -1)
|
158 |
+
used_token_len = used_token_len + (y2-y1) * (x2-x1)
|
159 |
+
return full_hidden_states
|
160 |
+
|
161 |
+
def forward(
|
162 |
+
self,
|
163 |
+
hidden_states: torch.Tensor,
|
164 |
+
list_layer_box: List[Tuple] = None,
|
165 |
+
encoder_hidden_states: torch.Tensor = None,
|
166 |
+
pooled_projections: torch.Tensor = None,
|
167 |
+
timestep: torch.LongTensor = None,
|
168 |
+
img_ids: torch.Tensor = None,
|
169 |
+
txt_ids: torch.Tensor = None,
|
170 |
+
guidance: torch.Tensor = None,
|
171 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
172 |
+
return_dict: bool = True,
|
173 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
174 |
+
"""
|
175 |
+
The [`FluxTransformer2DModel`] forward method.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
179 |
+
Input `hidden_states`.
|
180 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
181 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
182 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
183 |
+
from the embeddings of input conditions.
|
184 |
+
timestep ( `torch.LongTensor`):
|
185 |
+
Used to indicate denoising step.
|
186 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
187 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
188 |
+
joint_attention_kwargs (`dict`, *optional*):
|
189 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
190 |
+
`self.processor` in
|
191 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
192 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
193 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
194 |
+
tuple.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
198 |
+
`tuple` where the first element is the sample tensor.
|
199 |
+
"""
|
200 |
+
if joint_attention_kwargs is not None:
|
201 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
202 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
203 |
+
else:
|
204 |
+
lora_scale = 1.0
|
205 |
+
|
206 |
+
if USE_PEFT_BACKEND:
|
207 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
208 |
+
scale_lora_layers(self, lora_scale)
|
209 |
+
else:
|
210 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
211 |
+
logger.warning(
|
212 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
213 |
+
)
|
214 |
+
|
215 |
+
bs, n_layers, channel_latent, height, width = hidden_states.shape # [bs, n_layers, c_latent, h, w]
|
216 |
+
|
217 |
+
hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2) # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
|
218 |
+
hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
|
219 |
+
hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4) # [bs, n_layers, h/2, w/2, c_latent*4]
|
220 |
+
hidden_states = self.x_embedder(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
|
221 |
+
|
222 |
+
full_hidden_states = torch.zeros_like(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
|
223 |
+
layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim) # [1, max_n_layers, 1, 1, inner_dim]
|
224 |
+
hidden_states = hidden_states + layer_pe[:, :n_layers] # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, f, h/2, w/2, inner_dim]
|
225 |
+
hidden_states = self.crop_each_layer(hidden_states, list_layer_box) # [bs, token_len, inner_dim]
|
226 |
+
|
227 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
228 |
+
if guidance is not None:
|
229 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
230 |
+
else:
|
231 |
+
guidance = None
|
232 |
+
temb = (
|
233 |
+
self.time_text_embed(timestep, pooled_projections)
|
234 |
+
if guidance is None
|
235 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
236 |
+
)
|
237 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
238 |
+
|
239 |
+
if txt_ids.ndim == 3:
|
240 |
+
logger.warning(
|
241 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
242 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
243 |
+
)
|
244 |
+
txt_ids = txt_ids[0]
|
245 |
+
if img_ids.ndim == 3:
|
246 |
+
logger.warning(
|
247 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
248 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
249 |
+
)
|
250 |
+
img_ids = img_ids[0]
|
251 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
252 |
+
image_rotary_emb = self.pos_embed(ids)
|
253 |
+
|
254 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
255 |
+
if self.training and self.gradient_checkpointing:
|
256 |
+
|
257 |
+
def create_custom_forward(module, return_dict=None):
|
258 |
+
def custom_forward(*inputs):
|
259 |
+
if return_dict is not None:
|
260 |
+
return module(*inputs, return_dict=return_dict)
|
261 |
+
else:
|
262 |
+
return module(*inputs)
|
263 |
+
|
264 |
+
return custom_forward
|
265 |
+
|
266 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
267 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
268 |
+
create_custom_forward(block),
|
269 |
+
hidden_states,
|
270 |
+
encoder_hidden_states,
|
271 |
+
temb,
|
272 |
+
image_rotary_emb,
|
273 |
+
**ckpt_kwargs,
|
274 |
+
)
|
275 |
+
|
276 |
+
else:
|
277 |
+
encoder_hidden_states, hidden_states = block(
|
278 |
+
hidden_states=hidden_states,
|
279 |
+
encoder_hidden_states=encoder_hidden_states,
|
280 |
+
temb=temb,
|
281 |
+
image_rotary_emb=image_rotary_emb,
|
282 |
+
)
|
283 |
+
|
284 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
285 |
+
|
286 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
287 |
+
if self.training and self.gradient_checkpointing:
|
288 |
+
|
289 |
+
def create_custom_forward(module, return_dict=None):
|
290 |
+
def custom_forward(*inputs):
|
291 |
+
if return_dict is not None:
|
292 |
+
return module(*inputs, return_dict=return_dict)
|
293 |
+
else:
|
294 |
+
return module(*inputs)
|
295 |
+
|
296 |
+
return custom_forward
|
297 |
+
|
298 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
299 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
300 |
+
create_custom_forward(block),
|
301 |
+
hidden_states,
|
302 |
+
temb,
|
303 |
+
image_rotary_emb,
|
304 |
+
**ckpt_kwargs,
|
305 |
+
)
|
306 |
+
|
307 |
+
else:
|
308 |
+
hidden_states = block(
|
309 |
+
hidden_states=hidden_states,
|
310 |
+
temb=temb,
|
311 |
+
image_rotary_emb=image_rotary_emb,
|
312 |
+
)
|
313 |
+
|
314 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
315 |
+
|
316 |
+
hidden_states = self.fill_in_processed_tokens(hidden_states, full_hidden_states, list_layer_box) # [bs, n_layers, h/2, w/2, inner_dim]
|
317 |
+
hidden_states = hidden_states.view(bs, -1, self.inner_dim) # [bs, n_layers * full_len, inner_dim]
|
318 |
+
|
319 |
+
hidden_states = self.norm_out(hidden_states, temb) # [bs, n_layers * full_len, inner_dim]
|
320 |
+
hidden_states = self.proj_out(hidden_states) # [bs, n_layers * full_len, c_latent*4]
|
321 |
+
|
322 |
+
# unpatchify
|
323 |
+
hidden_states = hidden_states.view(bs, n_layers, height//2, width//2, channel_latent, 2, 2) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
|
324 |
+
hidden_states = hidden_states.permute(0, 1, 4, 2, 5, 3, 6)
|
325 |
+
output = hidden_states.reshape(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
|
326 |
+
|
327 |
+
if USE_PEFT_BACKEND:
|
328 |
+
# remove `lora_scale` from each PEFT layer
|
329 |
+
unscale_lora_layers(self, lora_scale)
|
330 |
+
|
331 |
+
if not return_dict:
|
332 |
+
return (output,)
|
333 |
+
|
334 |
+
return Transformer2DModelOutput(sample=output)
|
custom_model_transp_vae.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
from collections import OrderedDict
|
3 |
+
from functools import partial
|
4 |
+
from typing import Callable
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torchvision
|
9 |
+
from torch.utils.checkpoint import checkpoint
|
10 |
+
|
11 |
+
from accelerate.utils import set_module_tensor_to_device
|
12 |
+
from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed
|
13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
14 |
+
from diffusers.configuration_utils import ConfigMixin
|
15 |
+
from diffusers.loaders import FromOriginalModelMixin
|
16 |
+
|
17 |
+
|
18 |
+
class MLPBlock(torchvision.ops.misc.MLP):
|
19 |
+
"""Transformer MLP block."""
|
20 |
+
|
21 |
+
_version = 2
|
22 |
+
|
23 |
+
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
|
24 |
+
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
|
25 |
+
|
26 |
+
for m in self.modules():
|
27 |
+
if isinstance(m, nn.Linear):
|
28 |
+
nn.init.xavier_uniform_(m.weight)
|
29 |
+
if m.bias is not None:
|
30 |
+
nn.init.normal_(m.bias, std=1e-6)
|
31 |
+
|
32 |
+
def _load_from_state_dict(
|
33 |
+
self,
|
34 |
+
state_dict,
|
35 |
+
prefix,
|
36 |
+
local_metadata,
|
37 |
+
strict,
|
38 |
+
missing_keys,
|
39 |
+
unexpected_keys,
|
40 |
+
error_msgs,
|
41 |
+
):
|
42 |
+
version = local_metadata.get("version", None)
|
43 |
+
|
44 |
+
if version is None or version < 2:
|
45 |
+
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
|
46 |
+
for i in range(2):
|
47 |
+
for type in ["weight", "bias"]:
|
48 |
+
old_key = f"{prefix}linear_{i+1}.{type}"
|
49 |
+
new_key = f"{prefix}{3*i}.{type}"
|
50 |
+
if old_key in state_dict:
|
51 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
52 |
+
|
53 |
+
super()._load_from_state_dict(
|
54 |
+
state_dict,
|
55 |
+
prefix,
|
56 |
+
local_metadata,
|
57 |
+
strict,
|
58 |
+
missing_keys,
|
59 |
+
unexpected_keys,
|
60 |
+
error_msgs,
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
class EncoderBlock(nn.Module):
|
65 |
+
"""Transformer encoder block."""
|
66 |
+
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
num_heads: int,
|
70 |
+
hidden_dim: int,
|
71 |
+
mlp_dim: int,
|
72 |
+
dropout: float,
|
73 |
+
attention_dropout: float,
|
74 |
+
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
self.num_heads = num_heads
|
78 |
+
self.hidden_dim = hidden_dim
|
79 |
+
self.num_heads = num_heads
|
80 |
+
|
81 |
+
# Attention block
|
82 |
+
self.ln_1 = norm_layer(hidden_dim)
|
83 |
+
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
|
84 |
+
self.dropout = nn.Dropout(dropout)
|
85 |
+
|
86 |
+
# MLP block
|
87 |
+
self.ln_2 = norm_layer(hidden_dim)
|
88 |
+
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
|
89 |
+
|
90 |
+
def forward(self, input: torch.Tensor, freqs_cis):
|
91 |
+
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
|
92 |
+
B, L, C = input.shape
|
93 |
+
x = self.ln_1(input)
|
94 |
+
if freqs_cis is not None:
|
95 |
+
query = x.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2)
|
96 |
+
query = apply_rotary_emb(query, freqs_cis)
|
97 |
+
query = query.transpose(1, 2).reshape(B, L, self.hidden_dim)
|
98 |
+
x, _ = self.self_attention(query, query, x, need_weights=False)
|
99 |
+
x = self.dropout(x)
|
100 |
+
x = x + input
|
101 |
+
|
102 |
+
y = self.ln_2(x)
|
103 |
+
y = self.mlp(y)
|
104 |
+
return x + y
|
105 |
+
|
106 |
+
|
107 |
+
class Encoder(nn.Module):
|
108 |
+
"""Transformer Model Encoder for sequence to sequence translation."""
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
seq_length: int,
|
113 |
+
num_layers: int,
|
114 |
+
num_heads: int,
|
115 |
+
hidden_dim: int,
|
116 |
+
mlp_dim: int,
|
117 |
+
dropout: float,
|
118 |
+
attention_dropout: float,
|
119 |
+
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
# Note that batch_size is on the first dim because
|
123 |
+
# we have batch_first=True in nn.MultiAttention() by default
|
124 |
+
# self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
|
125 |
+
self.dropout = nn.Dropout(dropout)
|
126 |
+
layers: OrderedDict[str, nn.Module] = OrderedDict()
|
127 |
+
for i in range(num_layers):
|
128 |
+
layers[f"encoder_layer_{i}"] = EncoderBlock(
|
129 |
+
num_heads,
|
130 |
+
hidden_dim,
|
131 |
+
mlp_dim,
|
132 |
+
dropout,
|
133 |
+
attention_dropout,
|
134 |
+
norm_layer,
|
135 |
+
)
|
136 |
+
self.layers = nn.Sequential(layers)
|
137 |
+
self.ln = norm_layer(hidden_dim)
|
138 |
+
|
139 |
+
def forward(self, input: torch.Tensor, freqs_cis):
|
140 |
+
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
|
141 |
+
input = input # + self.pos_embedding
|
142 |
+
x = self.dropout(input)
|
143 |
+
for l in self.layers:
|
144 |
+
x = checkpoint(l, x, freqs_cis)
|
145 |
+
x = self.ln(x)
|
146 |
+
return x
|
147 |
+
|
148 |
+
|
149 |
+
class ViTEncoder(nn.Module):
|
150 |
+
def __init__(self, arch='vit-b/32'):
|
151 |
+
super().__init__()
|
152 |
+
self.arch = arch
|
153 |
+
|
154 |
+
if self.arch == 'vit-b/32':
|
155 |
+
ch = 768
|
156 |
+
layers = 12
|
157 |
+
heads = 12
|
158 |
+
elif self.arch == 'vit-h/14':
|
159 |
+
ch = 1280
|
160 |
+
layers = 32
|
161 |
+
heads = 16
|
162 |
+
|
163 |
+
self.encoder = Encoder(
|
164 |
+
seq_length=-1,
|
165 |
+
num_layers=layers,
|
166 |
+
num_heads=heads,
|
167 |
+
hidden_dim=ch,
|
168 |
+
mlp_dim=ch*4,
|
169 |
+
dropout=0.0,
|
170 |
+
attention_dropout=0.0,
|
171 |
+
)
|
172 |
+
self.fc_in = nn.Linear(16, ch)
|
173 |
+
self.fc_out = nn.Linear(ch, 256)
|
174 |
+
|
175 |
+
if self.arch == 'vit-b/32':
|
176 |
+
from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights
|
177 |
+
vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT)
|
178 |
+
elif self.arch == 'vit-h/14':
|
179 |
+
from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights
|
180 |
+
vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)
|
181 |
+
|
182 |
+
missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False)
|
183 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
184 |
+
print(f"ViT Encoder Missing keys: {missing_keys}")
|
185 |
+
print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
|
186 |
+
del vit
|
187 |
+
|
188 |
+
def forward(self, x, freqs_cis):
|
189 |
+
out = self.fc_in(x)
|
190 |
+
out = self.encoder(out, freqs_cis)
|
191 |
+
out = checkpoint(self.fc_out, out)
|
192 |
+
return out
|
193 |
+
|
194 |
+
|
195 |
+
def patchify(x, patch_size=8):
|
196 |
+
if len(x.shape) == 4:
|
197 |
+
bs, c, h, w = x.shape
|
198 |
+
x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size)
|
199 |
+
elif len(x.shape) == 3:
|
200 |
+
c, h, w = x.shape
|
201 |
+
x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size)
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
def unpatchify(x, patch_size=8):
|
206 |
+
if len(x.shape) == 4:
|
207 |
+
bs, c, h, w = x.shape
|
208 |
+
x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size)
|
209 |
+
elif len(x.shape) == 3:
|
210 |
+
c, h, w = x.shape
|
211 |
+
x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size)
|
212 |
+
return x
|
213 |
+
|
214 |
+
|
215 |
+
def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding):
|
216 |
+
token_list = []
|
217 |
+
cos_list, sin_list = [], []
|
218 |
+
for layer_idx in range(hidden_states.shape[1]):
|
219 |
+
if list_layer_box[layer_idx] is None:
|
220 |
+
continue
|
221 |
+
else:
|
222 |
+
x1, y1, x2, y2 = list_layer_box[layer_idx]
|
223 |
+
x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
|
224 |
+
layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2]
|
225 |
+
c, h, w = layer_token.shape
|
226 |
+
layer_token = layer_token.reshape(c, -1)
|
227 |
+
token_list.append(layer_token)
|
228 |
+
ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype)
|
229 |
+
ids[:, 0] = use_layers[layer_idx]
|
230 |
+
image_rotary_emb = pos_embedding(ids)
|
231 |
+
pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1)
|
232 |
+
cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64))
|
233 |
+
sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64))
|
234 |
+
token_list = torch.cat(token_list, dim=1).permute(1, 0)
|
235 |
+
cos_list = torch.cat(cos_list, dim=0)
|
236 |
+
sin_list = torch.cat(sin_list, dim=0)
|
237 |
+
return token_list, (cos_list, sin_list)
|
238 |
+
|
239 |
+
|
240 |
+
def prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
241 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
242 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
243 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
244 |
+
|
245 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
246 |
+
|
247 |
+
latent_image_ids = latent_image_ids.reshape(
|
248 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
249 |
+
)
|
250 |
+
|
251 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
252 |
+
|
253 |
+
|
254 |
+
class AutoencoderKLTransformerTraining(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
255 |
+
def __init__(self):
|
256 |
+
super().__init__()
|
257 |
+
|
258 |
+
self.decoder_arch = 'vit'
|
259 |
+
self.layer_embedding = 'rope'
|
260 |
+
|
261 |
+
self.decoder = ViTEncoder()
|
262 |
+
self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28))
|
263 |
+
if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
|
264 |
+
self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.max_layers, 1, 1).normal_(std=0.02), requires_grad=True)
|
265 |
+
|
266 |
+
def zero_module(module):
|
267 |
+
"""
|
268 |
+
Zero out the parameters of a module and return it.
|
269 |
+
"""
|
270 |
+
for p in module.parameters():
|
271 |
+
p.detach().zero_()
|
272 |
+
return module
|
273 |
+
|
274 |
+
def encode(self, z_2d, box, use_layers):
|
275 |
+
B, C, T, H, W = z_2d.shape
|
276 |
+
|
277 |
+
z, freqs_cis = [], []
|
278 |
+
for b in range(B):
|
279 |
+
_z = z_2d[b]
|
280 |
+
if 'vit' in self.decoder_arch:
|
281 |
+
_use_layers = torch.tensor(use_layers[b], device=z_2d.device)
|
282 |
+
if 'rel' in self.layer_embedding:
|
283 |
+
_use_layers[_use_layers > 2] = 2
|
284 |
+
if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
|
285 |
+
_z = _z + self.layer_embedding[:, _use_layers] # + self.pos_embedding
|
286 |
+
if 'rope' not in self.layer_embedding:
|
287 |
+
use_layers[b] = [0] * len(use_layers[b])
|
288 |
+
_z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding) ### modified
|
289 |
+
z.append(_z)
|
290 |
+
freqs_cis.append(cis)
|
291 |
+
|
292 |
+
return z, freqs_cis
|
293 |
+
|
294 |
+
def decode(self, z, freqs_cis, box, H, W):
|
295 |
+
B = len(z)
|
296 |
+
pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype)
|
297 |
+
pad[3, :, :] = -1
|
298 |
+
x = []
|
299 |
+
for b in range(B):
|
300 |
+
_x = []
|
301 |
+
_z = self.decoder(z[b].unsqueeze(0), freqs_cis[b]).squeeze(0)
|
302 |
+
current_index = 0
|
303 |
+
for layer_idx in range(len(box[b])):
|
304 |
+
if box[b][layer_idx] == None:
|
305 |
+
_x.append(pad.clone())
|
306 |
+
else:
|
307 |
+
x1, y1, x2, y2 = box[b][layer_idx]
|
308 |
+
x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8
|
309 |
+
token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok)
|
310 |
+
tokens = _z[current_index:current_index + token_length]
|
311 |
+
pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok)
|
312 |
+
unpatched = unpatchify(pixels)
|
313 |
+
pixels = pad.clone()
|
314 |
+
pixels[:, y1:y2, x1:x2] = unpatched
|
315 |
+
_x.append(pixels)
|
316 |
+
current_index += token_length
|
317 |
+
_x = torch.stack(_x, dim=1)
|
318 |
+
x.append(_x)
|
319 |
+
x = torch.stack(x, dim=0)
|
320 |
+
return x
|
321 |
+
|
322 |
+
def forward(self, z_2d, box, use_layers=None):
|
323 |
+
z_2d = z_2d.transpose(0, 1).unsqueeze(0)
|
324 |
+
use_layers = use_layers or [list(range(z_2d.shape[2]))]
|
325 |
+
z, freqs_cis = self.encode(z_2d, box, use_layers)
|
326 |
+
H, W = z_2d.shape[-2:]
|
327 |
+
x_hat = self.decode(z, freqs_cis, box, H * 8, W * 8)
|
328 |
+
assert x_hat.shape[0] == 1, x_hat.shape
|
329 |
+
x_hat = einops.rearrange(x_hat[0], "c t h w -> t c h w")
|
330 |
+
x_hat_rgb, x_hat_alpha = x_hat[:, :3], x_hat[:, 3:]
|
331 |
+
return x_hat_rgb, x_hat_alpha
|
custom_pipeline.py
ADDED
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
8 |
+
from diffusers.utils import is_torch_xla_available, logging
|
9 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
10 |
+
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline
|
11 |
+
|
12 |
+
if is_torch_xla_available():
|
13 |
+
import torch_xla.core.xla_model as xm # type: ignore
|
14 |
+
XLA_AVAILABLE = True
|
15 |
+
else:
|
16 |
+
XLA_AVAILABLE = False
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
20 |
+
|
21 |
+
|
22 |
+
def _get_clip_prompt_embeds(
|
23 |
+
tokenizer,
|
24 |
+
text_encoder,
|
25 |
+
prompt: Union[str, List[str]],
|
26 |
+
num_images_per_prompt: int = 1,
|
27 |
+
device: Optional[torch.device] = None,
|
28 |
+
):
|
29 |
+
device = device or text_encoder.device
|
30 |
+
dtype = text_encoder.dtype
|
31 |
+
|
32 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
33 |
+
batch_size = len(prompt)
|
34 |
+
|
35 |
+
text_inputs = tokenizer(
|
36 |
+
prompt,
|
37 |
+
padding="max_length",
|
38 |
+
max_length=text_encoder.config.max_position_embeddings,
|
39 |
+
truncation=True,
|
40 |
+
return_overflowing_tokens=False,
|
41 |
+
return_length=False,
|
42 |
+
return_tensors="pt",
|
43 |
+
)
|
44 |
+
|
45 |
+
text_input_ids = text_inputs.input_ids
|
46 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
47 |
+
|
48 |
+
# Use pooled output of CLIPTextModel
|
49 |
+
prompt_embeds = prompt_embeds.pooler_output
|
50 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
51 |
+
|
52 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
53 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
54 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
55 |
+
|
56 |
+
return prompt_embeds
|
57 |
+
|
58 |
+
|
59 |
+
def _get_t5_prompt_embeds(
|
60 |
+
tokenizer,
|
61 |
+
text_encoder,
|
62 |
+
prompt: Union[str, List[str]] = None,
|
63 |
+
num_images_per_prompt: int = 1,
|
64 |
+
max_sequence_length: int = 512,
|
65 |
+
device: Optional[torch.device] = None,
|
66 |
+
dtype: Optional[torch.dtype] = None,
|
67 |
+
):
|
68 |
+
device = device or text_encoder.device
|
69 |
+
dtype = dtype or text_encoder.dtype
|
70 |
+
|
71 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
72 |
+
batch_size = len(prompt)
|
73 |
+
|
74 |
+
text_inputs = tokenizer(
|
75 |
+
prompt,
|
76 |
+
padding="max_length",
|
77 |
+
max_length=max_sequence_length,
|
78 |
+
truncation=True,
|
79 |
+
return_length=False,
|
80 |
+
return_overflowing_tokens=False,
|
81 |
+
return_tensors="pt",
|
82 |
+
)
|
83 |
+
text_input_ids = text_inputs.input_ids
|
84 |
+
|
85 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
|
86 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
87 |
+
|
88 |
+
_, seq_len, _ = prompt_embeds.shape
|
89 |
+
|
90 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
91 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
92 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
93 |
+
|
94 |
+
return prompt_embeds
|
95 |
+
|
96 |
+
|
97 |
+
def encode_prompt(
|
98 |
+
tokenizers,
|
99 |
+
text_encoders,
|
100 |
+
prompt: Union[str, List[str]],
|
101 |
+
prompt_2: Union[str, List[str]] = None,
|
102 |
+
num_images_per_prompt: int = 1,
|
103 |
+
max_sequence_length: int = 512,
|
104 |
+
):
|
105 |
+
|
106 |
+
tokenizer_1, tokenizer_2 = tokenizers
|
107 |
+
text_encoder_1, text_encoder_2 = text_encoders
|
108 |
+
device = text_encoder_1.device
|
109 |
+
dtype = text_encoder_1.dtype
|
110 |
+
|
111 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
112 |
+
prompt_2 = prompt_2 or prompt
|
113 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
114 |
+
|
115 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
116 |
+
pooled_prompt_embeds = _get_clip_prompt_embeds(
|
117 |
+
tokenizer=tokenizer_1,
|
118 |
+
text_encoder=text_encoder_1,
|
119 |
+
prompt=prompt,
|
120 |
+
num_images_per_prompt=num_images_per_prompt,
|
121 |
+
)
|
122 |
+
prompt_embeds = _get_t5_prompt_embeds(
|
123 |
+
tokenizer=tokenizer_2,
|
124 |
+
text_encoder=text_encoder_2,
|
125 |
+
prompt=prompt_2,
|
126 |
+
num_images_per_prompt=num_images_per_prompt,
|
127 |
+
max_sequence_length=max_sequence_length,
|
128 |
+
)
|
129 |
+
|
130 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
131 |
+
|
132 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
133 |
+
|
134 |
+
|
135 |
+
class CustomFluxPipeline(FluxPipeline):
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
|
139 |
+
|
140 |
+
latent_image_ids_list = []
|
141 |
+
for layer_idx in range(len(list_layer_box)):
|
142 |
+
if list_layer_box[layer_idx] == None:
|
143 |
+
continue
|
144 |
+
else:
|
145 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
|
146 |
+
latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
|
147 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
148 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
149 |
+
|
150 |
+
x1, y1, x2, y2 = list_layer_box[layer_idx]
|
151 |
+
x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
|
152 |
+
latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
|
153 |
+
|
154 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
155 |
+
latent_image_ids = latent_image_ids.reshape(
|
156 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
157 |
+
)
|
158 |
+
|
159 |
+
latent_image_ids_list.append(latent_image_ids)
|
160 |
+
|
161 |
+
full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
|
162 |
+
|
163 |
+
return full_latent_image_ids.to(device=device, dtype=dtype)
|
164 |
+
|
165 |
+
def prepare_latents(
|
166 |
+
self,
|
167 |
+
batch_size,
|
168 |
+
num_layers,
|
169 |
+
num_channels_latents,
|
170 |
+
height,
|
171 |
+
width,
|
172 |
+
list_layer_box,
|
173 |
+
dtype,
|
174 |
+
device,
|
175 |
+
generator,
|
176 |
+
latents=None,
|
177 |
+
):
|
178 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
179 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
180 |
+
|
181 |
+
shape = (batch_size, num_layers, num_channels_latents, height, width)
|
182 |
+
|
183 |
+
if latents is not None:
|
184 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
185 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
186 |
+
|
187 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
188 |
+
raise ValueError(
|
189 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
190 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
191 |
+
)
|
192 |
+
|
193 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, f, c_latent, h, w]
|
194 |
+
|
195 |
+
latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
|
196 |
+
|
197 |
+
return latents, latent_image_ids
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def __call__(
|
201 |
+
self,
|
202 |
+
prompt: Union[str, List[str]] = None,
|
203 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
204 |
+
validation_box: List[tuple] = None,
|
205 |
+
height: Optional[int] = None,
|
206 |
+
width: Optional[int] = None,
|
207 |
+
num_inference_steps: int = 28,
|
208 |
+
timesteps: List[int] = None,
|
209 |
+
guidance_scale: float = 3.5,
|
210 |
+
num_images_per_prompt: Optional[int] = 1,
|
211 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
212 |
+
latents: Optional[torch.FloatTensor] = None,
|
213 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
214 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
215 |
+
output_type: Optional[str] = "pil",
|
216 |
+
return_dict: bool = True,
|
217 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
218 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
219 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
220 |
+
max_sequence_length: int = 512,
|
221 |
+
num_layers: int = 5,
|
222 |
+
sdxl_vae: nn.Module = None,
|
223 |
+
transparent_decoder: nn.Module = None,
|
224 |
+
):
|
225 |
+
r"""
|
226 |
+
Function invoked when calling the pipeline for generation.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
prompt (`str` or `List[str]`, *optional*):
|
230 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
231 |
+
instead.
|
232 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
233 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
234 |
+
will be used instead
|
235 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
236 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
237 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
238 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
239 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
240 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
241 |
+
expense of slower inference.
|
242 |
+
timesteps (`List[int]`, *optional*):
|
243 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
244 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
245 |
+
passed will be used. Must be in descending order.
|
246 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
247 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
248 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
249 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
250 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
251 |
+
usually at the expense of lower image quality.
|
252 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
253 |
+
The number of images to generate per prompt.
|
254 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
255 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
256 |
+
to make generation deterministic.
|
257 |
+
latents (`torch.FloatTensor`, *optional*):
|
258 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
259 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
260 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
261 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
262 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
263 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
264 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
265 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
266 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
267 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
268 |
+
The output format of the generate image. Choose between
|
269 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
270 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
271 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
272 |
+
joint_attention_kwargs (`dict`, *optional*):
|
273 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
274 |
+
`self.processor` in
|
275 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
276 |
+
callback_on_step_end (`Callable`, *optional*):
|
277 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
278 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
279 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
280 |
+
`callback_on_step_end_tensor_inputs`.
|
281 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
282 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
283 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
284 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
285 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
286 |
+
|
287 |
+
Examples:
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
291 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
292 |
+
images.
|
293 |
+
"""
|
294 |
+
|
295 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
296 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
297 |
+
|
298 |
+
# 1. Check inputs. Raise error if not correct
|
299 |
+
self.check_inputs(
|
300 |
+
prompt,
|
301 |
+
prompt_2,
|
302 |
+
height,
|
303 |
+
width,
|
304 |
+
prompt_embeds=prompt_embeds,
|
305 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
306 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
307 |
+
max_sequence_length=max_sequence_length,
|
308 |
+
)
|
309 |
+
|
310 |
+
self._guidance_scale = guidance_scale
|
311 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
312 |
+
self._interrupt = False
|
313 |
+
|
314 |
+
# 2. Define call parameters
|
315 |
+
if prompt is not None and isinstance(prompt, str):
|
316 |
+
batch_size = 1
|
317 |
+
elif prompt is not None and isinstance(prompt, list):
|
318 |
+
batch_size = len(prompt)
|
319 |
+
else:
|
320 |
+
batch_size = prompt_embeds.shape[0]
|
321 |
+
|
322 |
+
device = self._execution_device
|
323 |
+
|
324 |
+
lora_scale = (
|
325 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
326 |
+
)
|
327 |
+
(
|
328 |
+
prompt_embeds,
|
329 |
+
pooled_prompt_embeds,
|
330 |
+
text_ids,
|
331 |
+
) = self.encode_prompt(
|
332 |
+
prompt=prompt,
|
333 |
+
prompt_2=prompt_2,
|
334 |
+
prompt_embeds=prompt_embeds,
|
335 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
336 |
+
device=device,
|
337 |
+
num_images_per_prompt=num_images_per_prompt,
|
338 |
+
max_sequence_length=max_sequence_length,
|
339 |
+
lora_scale=lora_scale,
|
340 |
+
)
|
341 |
+
|
342 |
+
# 4. Prepare latent variables
|
343 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
344 |
+
latents, latent_image_ids = self.prepare_latents(
|
345 |
+
batch_size * num_images_per_prompt,
|
346 |
+
num_layers,
|
347 |
+
num_channels_latents,
|
348 |
+
height,
|
349 |
+
width,
|
350 |
+
validation_box,
|
351 |
+
prompt_embeds.dtype,
|
352 |
+
device,
|
353 |
+
generator,
|
354 |
+
latents,
|
355 |
+
)
|
356 |
+
|
357 |
+
# 5. Prepare timesteps
|
358 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
359 |
+
image_seq_len = latent_image_ids.shape[0] # ???
|
360 |
+
mu = calculate_shift(
|
361 |
+
image_seq_len,
|
362 |
+
self.scheduler.config.base_image_seq_len,
|
363 |
+
self.scheduler.config.max_image_seq_len,
|
364 |
+
self.scheduler.config.base_shift,
|
365 |
+
self.scheduler.config.max_shift,
|
366 |
+
)
|
367 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
368 |
+
self.scheduler,
|
369 |
+
num_inference_steps,
|
370 |
+
device,
|
371 |
+
timesteps,
|
372 |
+
sigmas,
|
373 |
+
mu=mu,
|
374 |
+
)
|
375 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
376 |
+
self._num_timesteps = len(timesteps)
|
377 |
+
|
378 |
+
# handle guidance
|
379 |
+
if self.transformer.config.guidance_embeds:
|
380 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
381 |
+
guidance = guidance.expand(latents.shape[0])
|
382 |
+
else:
|
383 |
+
guidance = None
|
384 |
+
|
385 |
+
# 6. Denoising loop
|
386 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
387 |
+
for i, t in enumerate(timesteps):
|
388 |
+
if self.interrupt:
|
389 |
+
continue
|
390 |
+
|
391 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
392 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
393 |
+
|
394 |
+
noise_pred = self.transformer(
|
395 |
+
hidden_states=latents,
|
396 |
+
list_layer_box=validation_box,
|
397 |
+
timestep=timestep / 1000,
|
398 |
+
guidance=guidance,
|
399 |
+
pooled_projections=pooled_prompt_embeds,
|
400 |
+
encoder_hidden_states=prompt_embeds,
|
401 |
+
txt_ids=text_ids,
|
402 |
+
img_ids=latent_image_ids,
|
403 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
404 |
+
return_dict=False,
|
405 |
+
)[0]
|
406 |
+
|
407 |
+
# compute the previous noisy sample x_t -> x_t-1
|
408 |
+
latents_dtype = latents.dtype
|
409 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
410 |
+
|
411 |
+
if latents.dtype != latents_dtype:
|
412 |
+
if torch.backends.mps.is_available():
|
413 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
414 |
+
latents = latents.to(latents_dtype)
|
415 |
+
|
416 |
+
if callback_on_step_end is not None:
|
417 |
+
callback_kwargs = {}
|
418 |
+
for k in callback_on_step_end_tensor_inputs:
|
419 |
+
callback_kwargs[k] = locals()[k]
|
420 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
421 |
+
|
422 |
+
latents = callback_outputs.pop("latents", latents)
|
423 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
424 |
+
|
425 |
+
# call the callback, if provided
|
426 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
427 |
+
progress_bar.update()
|
428 |
+
|
429 |
+
if XLA_AVAILABLE:
|
430 |
+
xm.mark_step()
|
431 |
+
|
432 |
+
# create a grey latent
|
433 |
+
bs, n_frames, channel_latent, height, width = latents.shape
|
434 |
+
|
435 |
+
pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
|
436 |
+
latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
|
437 |
+
latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
438 |
+
latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w]
|
439 |
+
|
440 |
+
# fill in the latents
|
441 |
+
for layer_idx in range(latent_grey.shape[1]):
|
442 |
+
x1, y1, x2, y2 = validation_box[layer_idx]
|
443 |
+
x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
|
444 |
+
latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
|
445 |
+
latents = latent_grey
|
446 |
+
|
447 |
+
if output_type == "latent":
|
448 |
+
image = latents
|
449 |
+
|
450 |
+
else:
|
451 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
452 |
+
latents = latents.reshape(bs * n_frames, channel_latent, height, width)
|
453 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
454 |
+
if sdxl_vae is not None:
|
455 |
+
sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
|
456 |
+
sdxl_latents = sdxl_vae.encode(image).latent_dist.sample()
|
457 |
+
transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
|
458 |
+
result_list, vis_list = transparent_decoder(sdxl_vae, sdxl_latents)
|
459 |
+
else:
|
460 |
+
result_list, vis_list = None, None
|
461 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
462 |
+
|
463 |
+
# Offload all models
|
464 |
+
self.maybe_free_model_hooks()
|
465 |
+
|
466 |
+
if not return_dict:
|
467 |
+
return (image, result_list, vis_list)
|
468 |
+
|
469 |
+
return FluxPipelineOutput(images=image), result_list, vis_list
|
470 |
+
|
471 |
+
|
472 |
+
class CustomFluxPipelineCfg(FluxPipeline):
|
473 |
+
|
474 |
+
@staticmethod
|
475 |
+
def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
|
476 |
+
|
477 |
+
latent_image_ids_list = []
|
478 |
+
for layer_idx in range(len(list_layer_box)):
|
479 |
+
if list_layer_box[layer_idx] == None:
|
480 |
+
continue
|
481 |
+
else:
|
482 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
|
483 |
+
latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
|
484 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
485 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
486 |
+
|
487 |
+
x1, y1, x2, y2 = list_layer_box[layer_idx]
|
488 |
+
x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
|
489 |
+
latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
|
490 |
+
|
491 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
492 |
+
latent_image_ids = latent_image_ids.reshape(
|
493 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
494 |
+
)
|
495 |
+
|
496 |
+
latent_image_ids_list.append(latent_image_ids)
|
497 |
+
|
498 |
+
full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
|
499 |
+
|
500 |
+
return full_latent_image_ids.to(device=device, dtype=dtype)
|
501 |
+
|
502 |
+
def prepare_latents(
|
503 |
+
self,
|
504 |
+
batch_size,
|
505 |
+
num_layers,
|
506 |
+
num_channels_latents,
|
507 |
+
height,
|
508 |
+
width,
|
509 |
+
list_layer_box,
|
510 |
+
dtype,
|
511 |
+
device,
|
512 |
+
generator,
|
513 |
+
latents=None,
|
514 |
+
):
|
515 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
516 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
517 |
+
|
518 |
+
shape = (batch_size, num_layers, num_channels_latents, height, width)
|
519 |
+
|
520 |
+
if latents is not None:
|
521 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
522 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
523 |
+
|
524 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
525 |
+
raise ValueError(
|
526 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
527 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
528 |
+
)
|
529 |
+
|
530 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, n_layers, c_latent, h, w]
|
531 |
+
|
532 |
+
latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
|
533 |
+
|
534 |
+
return latents, latent_image_ids
|
535 |
+
|
536 |
+
@torch.no_grad()
|
537 |
+
def __call__(
|
538 |
+
self,
|
539 |
+
prompt: Union[str, List[str]] = None,
|
540 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
541 |
+
validation_box: List[tuple] = None,
|
542 |
+
height: Optional[int] = None,
|
543 |
+
width: Optional[int] = None,
|
544 |
+
num_inference_steps: int = 28,
|
545 |
+
timesteps: List[int] = None,
|
546 |
+
guidance_scale: float = 3.5,
|
547 |
+
true_gs: float = 3.5,
|
548 |
+
num_images_per_prompt: Optional[int] = 1,
|
549 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
550 |
+
latents: Optional[torch.FloatTensor] = None,
|
551 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
552 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
553 |
+
output_type: Optional[str] = "pil",
|
554 |
+
return_dict: bool = True,
|
555 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
556 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
557 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
558 |
+
max_sequence_length: int = 512,
|
559 |
+
num_layers: int = 5,
|
560 |
+
transparent_decoder: nn.Module = None,
|
561 |
+
):
|
562 |
+
r"""
|
563 |
+
Function invoked when calling the pipeline for generation.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
prompt (`str` or `List[str]`, *optional*):
|
567 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
568 |
+
instead.
|
569 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
570 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
571 |
+
will be used instead
|
572 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
573 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
574 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
575 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
576 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
577 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
578 |
+
expense of slower inference.
|
579 |
+
timesteps (`List[int]`, *optional*):
|
580 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
581 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
582 |
+
passed will be used. Must be in descending order.
|
583 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
584 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
585 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
586 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
587 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
588 |
+
usually at the expense of lower image quality.
|
589 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
590 |
+
The number of images to generate per prompt.
|
591 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
592 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
593 |
+
to make generation deterministic.
|
594 |
+
latents (`torch.FloatTensor`, *optional*):
|
595 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
596 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
597 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
598 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
599 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
600 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
601 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
602 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
603 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
604 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
605 |
+
The output format of the generate image. Choose between
|
606 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
607 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
608 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
609 |
+
joint_attention_kwargs (`dict`, *optional*):
|
610 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
611 |
+
`self.processor` in
|
612 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
613 |
+
callback_on_step_end (`Callable`, *optional*):
|
614 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
615 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
616 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
617 |
+
`callback_on_step_end_tensor_inputs`.
|
618 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
619 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
620 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
621 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
622 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
623 |
+
|
624 |
+
Examples:
|
625 |
+
|
626 |
+
Returns:
|
627 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
628 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
629 |
+
images.
|
630 |
+
"""
|
631 |
+
|
632 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
633 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
634 |
+
|
635 |
+
# 1. Check inputs. Raise error if not correct
|
636 |
+
self.check_inputs(
|
637 |
+
prompt,
|
638 |
+
prompt_2,
|
639 |
+
height,
|
640 |
+
width,
|
641 |
+
prompt_embeds=prompt_embeds,
|
642 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
643 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
644 |
+
max_sequence_length=max_sequence_length,
|
645 |
+
)
|
646 |
+
|
647 |
+
self._guidance_scale = guidance_scale
|
648 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
649 |
+
self._interrupt = False
|
650 |
+
|
651 |
+
# 2. Define call parameters
|
652 |
+
if prompt is not None and isinstance(prompt, str):
|
653 |
+
batch_size = 1
|
654 |
+
elif prompt is not None and isinstance(prompt, list):
|
655 |
+
batch_size = len(prompt)
|
656 |
+
else:
|
657 |
+
batch_size = prompt_embeds.shape[0]
|
658 |
+
|
659 |
+
device = self._execution_device
|
660 |
+
|
661 |
+
lora_scale = (
|
662 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
663 |
+
)
|
664 |
+
(
|
665 |
+
prompt_embeds,
|
666 |
+
pooled_prompt_embeds,
|
667 |
+
text_ids,
|
668 |
+
) = self.encode_prompt(
|
669 |
+
prompt=prompt,
|
670 |
+
prompt_2=prompt_2,
|
671 |
+
prompt_embeds=prompt_embeds,
|
672 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
673 |
+
device=device,
|
674 |
+
num_images_per_prompt=num_images_per_prompt,
|
675 |
+
max_sequence_length=max_sequence_length,
|
676 |
+
lora_scale=lora_scale,
|
677 |
+
)
|
678 |
+
(
|
679 |
+
neg_prompt_embeds,
|
680 |
+
neg_pooled_prompt_embeds,
|
681 |
+
neg_text_ids,
|
682 |
+
) = self.encode_prompt(
|
683 |
+
prompt="",
|
684 |
+
prompt_2=None,
|
685 |
+
device=device,
|
686 |
+
num_images_per_prompt=num_images_per_prompt,
|
687 |
+
max_sequence_length=max_sequence_length,
|
688 |
+
lora_scale=lora_scale,
|
689 |
+
)
|
690 |
+
|
691 |
+
# 4. Prepare latent variables
|
692 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
693 |
+
latents, latent_image_ids = self.prepare_latents(
|
694 |
+
batch_size * num_images_per_prompt,
|
695 |
+
num_layers,
|
696 |
+
num_channels_latents,
|
697 |
+
height,
|
698 |
+
width,
|
699 |
+
validation_box,
|
700 |
+
prompt_embeds.dtype,
|
701 |
+
device,
|
702 |
+
generator,
|
703 |
+
latents,
|
704 |
+
)
|
705 |
+
|
706 |
+
# 5. Prepare timesteps
|
707 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
708 |
+
image_seq_len = latent_image_ids.shape[0]
|
709 |
+
mu = calculate_shift(
|
710 |
+
image_seq_len,
|
711 |
+
self.scheduler.config.base_image_seq_len,
|
712 |
+
self.scheduler.config.max_image_seq_len,
|
713 |
+
self.scheduler.config.base_shift,
|
714 |
+
self.scheduler.config.max_shift,
|
715 |
+
)
|
716 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
717 |
+
self.scheduler,
|
718 |
+
num_inference_steps,
|
719 |
+
device,
|
720 |
+
timesteps,
|
721 |
+
sigmas,
|
722 |
+
mu=mu,
|
723 |
+
)
|
724 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
725 |
+
self._num_timesteps = len(timesteps)
|
726 |
+
|
727 |
+
# handle guidance
|
728 |
+
if self.transformer.config.guidance_embeds:
|
729 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
730 |
+
guidance = guidance.expand(latents.shape[0])
|
731 |
+
else:
|
732 |
+
guidance = None
|
733 |
+
|
734 |
+
# 6. Denoising loop
|
735 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
736 |
+
for i, t in enumerate(timesteps):
|
737 |
+
if self.interrupt:
|
738 |
+
continue
|
739 |
+
|
740 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
741 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
742 |
+
|
743 |
+
noise_pred = self.transformer(
|
744 |
+
hidden_states=latents,
|
745 |
+
list_layer_box=validation_box,
|
746 |
+
timestep=timestep / 1000,
|
747 |
+
guidance=guidance,
|
748 |
+
pooled_projections=pooled_prompt_embeds,
|
749 |
+
encoder_hidden_states=prompt_embeds,
|
750 |
+
txt_ids=text_ids,
|
751 |
+
img_ids=latent_image_ids,
|
752 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
753 |
+
return_dict=False,
|
754 |
+
)[0]
|
755 |
+
|
756 |
+
neg_noise_pred = self.transformer(
|
757 |
+
hidden_states=latents,
|
758 |
+
list_layer_box=validation_box,
|
759 |
+
timestep=timestep / 1000,
|
760 |
+
guidance=guidance,
|
761 |
+
pooled_projections=neg_pooled_prompt_embeds,
|
762 |
+
encoder_hidden_states=neg_prompt_embeds,
|
763 |
+
txt_ids=neg_text_ids,
|
764 |
+
img_ids=latent_image_ids,
|
765 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
766 |
+
return_dict=False,
|
767 |
+
)[0]
|
768 |
+
|
769 |
+
noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)
|
770 |
+
|
771 |
+
# compute the previous noisy sample x_t -> x_t-1
|
772 |
+
latents_dtype = latents.dtype
|
773 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
774 |
+
|
775 |
+
if latents.dtype != latents_dtype:
|
776 |
+
if torch.backends.mps.is_available():
|
777 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
778 |
+
latents = latents.to(latents_dtype)
|
779 |
+
|
780 |
+
if callback_on_step_end is not None:
|
781 |
+
callback_kwargs = {}
|
782 |
+
for k in callback_on_step_end_tensor_inputs:
|
783 |
+
callback_kwargs[k] = locals()[k]
|
784 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
785 |
+
|
786 |
+
latents = callback_outputs.pop("latents", latents)
|
787 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
788 |
+
|
789 |
+
# call the callback, if provided
|
790 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
791 |
+
progress_bar.update()
|
792 |
+
|
793 |
+
if XLA_AVAILABLE:
|
794 |
+
xm.mark_step()
|
795 |
+
|
796 |
+
# create a grey latent
|
797 |
+
bs, n_layers, channel_latent, height, width = latents.shape
|
798 |
+
|
799 |
+
pixel_grey = torch.zeros(size=(bs*n_layers, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
|
800 |
+
latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
|
801 |
+
latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
802 |
+
latent_grey = latent_grey.view(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
|
803 |
+
|
804 |
+
# fill in the latents
|
805 |
+
for layer_idx in range(latent_grey.shape[1]):
|
806 |
+
if validation_box[layer_idx] == None:
|
807 |
+
continue
|
808 |
+
x1, y1, x2, y2 = validation_box[layer_idx]
|
809 |
+
x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
|
810 |
+
latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
|
811 |
+
latents = latent_grey
|
812 |
+
|
813 |
+
if output_type == "latent":
|
814 |
+
image = latents
|
815 |
+
|
816 |
+
else:
|
817 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
818 |
+
latents = latents.reshape(bs * n_layers, channel_latent, height, width)
|
819 |
+
latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
|
820 |
+
image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
|
821 |
+
image = torch.cat(image_segs, dim=0)
|
822 |
+
if transparent_decoder is not None:
|
823 |
+
transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
|
824 |
+
|
825 |
+
decoded_fg, decoded_alpha = transparent_decoder(latents, [validation_box])
|
826 |
+
decoded_alpha = (decoded_alpha + 1.0) / 2.0
|
827 |
+
decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1)
|
828 |
+
|
829 |
+
decoded_fg = (decoded_fg + 1.0) / 2.0
|
830 |
+
decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)
|
831 |
+
|
832 |
+
vis_list = None
|
833 |
+
png = torch.cat([decoded_fg, decoded_alpha], dim=3)
|
834 |
+
result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
|
835 |
+
else:
|
836 |
+
result_list, vis_list = None, None
|
837 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
838 |
+
|
839 |
+
# Offload all models
|
840 |
+
self.maybe_free_model_hooks()
|
841 |
+
|
842 |
+
if not return_dict:
|
843 |
+
return (image, result_list, vis_list, latents)
|
844 |
+
|
845 |
+
return FluxPipelineOutput(images=image), result_list, vis_list, latents
|
modeling_crello.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM
|
3 |
+
# from transformers import BitsAndBytesConfig
|
4 |
+
from torch import nn
|
5 |
+
import os
|
6 |
+
from typing import Optional, List
|
7 |
+
import os
|
8 |
+
|
9 |
+
def kmp_preprocess(pattern):
|
10 |
+
pattern_len = len(pattern)
|
11 |
+
prefix_suffix = [0] * pattern_len
|
12 |
+
j = 0
|
13 |
+
|
14 |
+
for i in range(1, pattern_len):
|
15 |
+
while j > 0 and pattern[i] != pattern[j]:
|
16 |
+
j = prefix_suffix[j - 1]
|
17 |
+
|
18 |
+
if pattern[i] == pattern[j]:
|
19 |
+
j += 1
|
20 |
+
|
21 |
+
prefix_suffix[i] = j
|
22 |
+
|
23 |
+
return prefix_suffix
|
24 |
+
|
25 |
+
def kmp_search(text, pattern):
|
26 |
+
text_len = len(text)
|
27 |
+
pattern_len = len(pattern)
|
28 |
+
prefix_suffix = kmp_preprocess(pattern)
|
29 |
+
matches = []
|
30 |
+
|
31 |
+
j = 0
|
32 |
+
for i in range(text_len):
|
33 |
+
while j > 0 and text[i] != pattern[j]:
|
34 |
+
j = prefix_suffix[j - 1]
|
35 |
+
|
36 |
+
if text[i] == pattern[j]:
|
37 |
+
j += 1
|
38 |
+
|
39 |
+
if j == pattern_len:
|
40 |
+
matches.append(i - j + 1)
|
41 |
+
j = prefix_suffix[j - 1]
|
42 |
+
|
43 |
+
return matches
|
44 |
+
|
45 |
+
class ModelWrapper:
|
46 |
+
def __init__(self, model):
|
47 |
+
self.model = model
|
48 |
+
|
49 |
+
def __getattr__(self, name):
|
50 |
+
return getattr(self.model, name)
|
51 |
+
|
52 |
+
@torch.no_grad()
|
53 |
+
def __call__(self, pixel_values):
|
54 |
+
return self.model(pixel_values)
|
55 |
+
|
56 |
+
def eval(self):
|
57 |
+
pass
|
58 |
+
|
59 |
+
def train(self):
|
60 |
+
pass
|
61 |
+
|
62 |
+
|
63 |
+
def parameters(self):
|
64 |
+
return self.model.parameters()
|
65 |
+
|
66 |
+
|
67 |
+
class CrelloModelConfig(PretrainedConfig):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
old_vocab_size: int = 32000,
|
71 |
+
vocab_size: int = 32000,
|
72 |
+
pad_token_id: int = 2,
|
73 |
+
ignore_ids: List[int] = [],
|
74 |
+
|
75 |
+
freeze_lm: bool = True, # lm.eval()
|
76 |
+
opt_version: str = 'facebook/opt-6.7b',
|
77 |
+
|
78 |
+
task: str = 'captioning',
|
79 |
+
|
80 |
+
use_lora: bool = False,
|
81 |
+
lora_alpha: int = 32,
|
82 |
+
lora_r: int = 8,
|
83 |
+
lora_dropout: float = 0.05,
|
84 |
+
lora_target_modules: str = r'.*\.(q_proj|v_proj)',
|
85 |
+
|
86 |
+
hidden_size: int = -1,
|
87 |
+
load_in_4bit: Optional[bool] = False,
|
88 |
+
|
89 |
+
**kwargs,
|
90 |
+
):
|
91 |
+
super().__init__(**kwargs)
|
92 |
+
assert old_vocab_size > 0, 'old_vocab_size must be positive'
|
93 |
+
assert vocab_size > 0, 'vocab_size must be positive'
|
94 |
+
|
95 |
+
self.old_vocab_size = old_vocab_size
|
96 |
+
self.vocab_size = vocab_size
|
97 |
+
self.pad_token_id = pad_token_id
|
98 |
+
self.freeze_lm = freeze_lm
|
99 |
+
self.opt_version = opt_version
|
100 |
+
self.task = task
|
101 |
+
self.use_lora = use_lora
|
102 |
+
self.lora_alpha = lora_alpha
|
103 |
+
self.lora_r = lora_r
|
104 |
+
self.lora_dropout = lora_dropout
|
105 |
+
self.lora_target_modules = lora_target_modules
|
106 |
+
self.hidden_size = hidden_size
|
107 |
+
self.load_in_4bit = load_in_4bit
|
108 |
+
self.ignore_ids = ignore_ids
|
109 |
+
|
110 |
+
|
111 |
+
class CrelloModel(PreTrainedModel):
|
112 |
+
config_class = CrelloModelConfig
|
113 |
+
supports_gradient_checkpointing = True
|
114 |
+
|
115 |
+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
116 |
+
self.lm.gradient_checkpointing_enable()
|
117 |
+
|
118 |
+
def __init__(self, config: CrelloModelConfig): # 显示声明config类型
|
119 |
+
super().__init__(config)
|
120 |
+
|
121 |
+
self.pad_token_id = config.pad_token_id
|
122 |
+
|
123 |
+
self.args = config
|
124 |
+
|
125 |
+
opt_version = "WYBar/LLM_For_Layout_Planning"
|
126 |
+
|
127 |
+
print(f"Using {opt_version} for the language model.")
|
128 |
+
|
129 |
+
if 'facebook/opt' in opt_version:
|
130 |
+
self.lm = OPTForCausalLM.from_pretrained(opt_version)
|
131 |
+
word_embed_proj_dim = self.lm.config.word_embed_proj_dim
|
132 |
+
else:
|
133 |
+
if config.load_in_4bit:
|
134 |
+
print("\n would load_in_4bit")
|
135 |
+
quantization_config = None
|
136 |
+
# This means: fit the entire model on the GPU:0
|
137 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
138 |
+
device_map = {"": local_rank}
|
139 |
+
torch_dtype = torch.bfloat16
|
140 |
+
else:
|
141 |
+
print("\n wouldn't load_in_4bit")
|
142 |
+
quantization_config = None
|
143 |
+
device_map = None
|
144 |
+
torch_dtype = None
|
145 |
+
|
146 |
+
self.lm = AutoModelForCausalLM.from_pretrained(
|
147 |
+
"WYBar/LLM_For_Layout_Planning",
|
148 |
+
subfolder="Meta-Llama-3-8B",
|
149 |
+
# use_auth_token=use_auth_token,
|
150 |
+
# quantization_config=quantization_config,
|
151 |
+
# device_map=device_map,
|
152 |
+
trust_remote_code=True,
|
153 |
+
torch_dtype=torch.bfloat16,
|
154 |
+
# cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
|
155 |
+
)
|
156 |
+
word_embed_proj_dim = self.lm.config.hidden_size
|
157 |
+
self.config.hidden_size = self.lm.config.hidden_size
|
158 |
+
self.opt_version = opt_version
|
159 |
+
|
160 |
+
if self.args.freeze_lm:
|
161 |
+
self.lm.eval()
|
162 |
+
print("Freezing the LM.")
|
163 |
+
# for param in self.lm.parameters():
|
164 |
+
# param.requires_grad = False
|
165 |
+
else:
|
166 |
+
print("\n no freeze lm, so to train lm")
|
167 |
+
self.lm.train()
|
168 |
+
self.lm.config.gradient_checkpointing = True
|
169 |
+
|
170 |
+
# print('resize token embeddings to match the tokenizer', config.vocab_size)
|
171 |
+
# self.lm.resize_token_embeddings(config.vocab_size)
|
172 |
+
# self.input_embeddings = self.lm.get_input_embeddings()
|
173 |
+
# print('after token embeddings to match the tokenizer', config.vocab_size)
|
174 |
+
|
175 |
+
def train(self, mode=True):
|
176 |
+
super().train(mode=mode)
|
177 |
+
# Overwrite train() to ensure frozen models remain frozen.
|
178 |
+
if self.args.freeze_lm:
|
179 |
+
self.lm.eval()
|
180 |
+
|
181 |
+
def forward(
|
182 |
+
self,
|
183 |
+
labels: torch.LongTensor,
|
184 |
+
):
|
185 |
+
batch_size = labels.shape[0]
|
186 |
+
full_labels = labels.detach().clone()
|
187 |
+
|
188 |
+
input_embs = self.input_embeddings(labels) # (N, T, D)
|
189 |
+
input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
|
190 |
+
|
191 |
+
for ignore_id in self.config.ignore_ids:
|
192 |
+
full_labels[full_labels == ignore_id] = -100
|
193 |
+
|
194 |
+
pad_idx = []
|
195 |
+
# 获取每一个batch的 seq 长度,取值为 max_len or padding_position,记录在pad_idx
|
196 |
+
# -100 is the ignore index for cross entropy loss. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
|
197 |
+
for label in full_labels:
|
198 |
+
for k, token in enumerate(label):
|
199 |
+
# Mask out pad tokens if they exist.
|
200 |
+
if token in [self.pad_token_id]:
|
201 |
+
label[k:] = -100 # 将后面的token都mask掉
|
202 |
+
pad_idx.append(k)
|
203 |
+
break
|
204 |
+
if k == len(label) - 1: # No padding found.
|
205 |
+
pad_idx.append(k + 1)
|
206 |
+
assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
|
207 |
+
|
208 |
+
output = self.lm( inputs_embeds=input_embs,
|
209 |
+
# input_ids=labels,
|
210 |
+
labels=full_labels,
|
211 |
+
output_hidden_states=True)
|
212 |
+
|
213 |
+
return output, full_labels, input_embs_norm
|
214 |
+
|
215 |
+
if __name__=="__main__":
|
216 |
+
config = CrelloModelConfig(
|
217 |
+
vocab_size=50265,
|
218 |
+
image_reg_token=50264,
|
219 |
+
image_gt_token=50263,
|
220 |
+
)
|
221 |
+
print("config: ",config)
|
222 |
+
model1 = CrelloModel(config)
|
223 |
+
print("\nmodel1: ",model1)
|
224 |
+
model1.save_pretrained('test')
|
225 |
+
model2 = CrelloModel.from_pretrained('test')
|
226 |
+
print("\nmodel2: ",model2)
|
227 |
+
# compare model1 and model2
|
228 |
+
|
229 |
+
state_dict1 = model1.state_dict()
|
230 |
+
state_dict2 = model2.state_dict()
|
231 |
+
assert set(state_dict1.keys()) == set(state_dict2.keys())
|
232 |
+
for k in state_dict1.keys():
|
233 |
+
assert torch.equal(state_dict1[k], state_dict2[k])
|
234 |
+
print('all parameters are equal')
|
235 |
+
|
quantizer.py
ADDED
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import copy
|
4 |
+
from collections import OrderedDict
|
5 |
+
import json
|
6 |
+
from datasets import ClassLabel
|
7 |
+
import random
|
8 |
+
import math
|
9 |
+
from functools import lru_cache
|
10 |
+
from matplotlib import font_manager
|
11 |
+
from colorama import Fore, Style, init
|
12 |
+
|
13 |
+
|
14 |
+
class BaseQuantizer:
|
15 |
+
@property
|
16 |
+
def ignore_tokens(self):
|
17 |
+
if self.num_mask_tokens > 0:
|
18 |
+
if self.mask_type == 'cm3':
|
19 |
+
return [self.predict_start_token] + self.mask_tokens
|
20 |
+
elif self.mask_type == 'mask_aug':
|
21 |
+
return [self.mask_aug_token]
|
22 |
+
else:
|
23 |
+
raise ValueError(f'Invalid mask type {self.mask_type}')
|
24 |
+
else:
|
25 |
+
return []
|
26 |
+
|
27 |
+
def __init__(self, simplify_json=False, mask_all=False,
|
28 |
+
num_mask_tokens=0, mask_type='cm3', **kwargs):
|
29 |
+
self.simplify_json=simplify_json
|
30 |
+
self.io_ignore_replace_tokens = ['<split-text>']
|
31 |
+
self.mask_all = mask_all
|
32 |
+
self.num_mask_tokens = num_mask_tokens
|
33 |
+
self.mask_type = mask_type
|
34 |
+
if self.mask_type == 'mask_aug':
|
35 |
+
self.mask_aug_token = '<mask-aug>'
|
36 |
+
elif self.mask_type == 'cm3':
|
37 |
+
self.predict_start_token = '<pred-start>'
|
38 |
+
else:
|
39 |
+
raise ValueError(f'Invalid mask type {self.mask_type}')
|
40 |
+
|
41 |
+
def get_additional_mask_tokens(self):
|
42 |
+
if self.mask_type == 'cm3': # 两种配置:1. ['<pred-start>'] + '<mask-%d>',数量和self.num_mask_tokens相关 2. ['<mask-aug>']
|
43 |
+
self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
|
44 |
+
return [self.predict_start_token] + self.mask_tokens
|
45 |
+
elif self.mask_type == 'mask_aug':
|
46 |
+
return [self.mask_aug_token]
|
47 |
+
else:
|
48 |
+
raise ValueError(f'Invalid mask type {self.mask_type}')
|
49 |
+
|
50 |
+
def dump2json(self, json_example):
|
51 |
+
if self.simplify_json: # 将 dict 转化为 str, 如果simplify_json is True,那么缩减空格和换行,删除token的双引号
|
52 |
+
content = json.dumps(json_example, separators=(',',':'))
|
53 |
+
for token in self.additional_special_tokens:
|
54 |
+
content = content.replace(f'"{token}"', token)
|
55 |
+
else:
|
56 |
+
content = json.dumps(json_example)
|
57 |
+
return content
|
58 |
+
|
59 |
+
def load_json(self, content): # 将str转化为json
|
60 |
+
replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens) # sirui change
|
61 |
+
if self.simplify_json:
|
62 |
+
for token in replace_tokens: # 如果simplify_json is True,那么为 token 添加双引号
|
63 |
+
content = content.replace(token, f'"{token}"')
|
64 |
+
return json.loads(content)
|
65 |
+
|
66 |
+
def apply_masking(self,
|
67 |
+
json_example,
|
68 |
+
mask_all=None,
|
69 |
+
return_meta=False,
|
70 |
+
target_keys=['width', 'height', 'left', 'top'],
|
71 |
+
target_element_types=None
|
72 |
+
):
|
73 |
+
if mask_all is None:
|
74 |
+
mask_all = self.mask_all
|
75 |
+
json_example = copy.deepcopy(json_example)
|
76 |
+
target_keys = set(target_keys)
|
77 |
+
target_tokens = []
|
78 |
+
for shape_i, shape in enumerate(json_example['layers']['textlayer']):
|
79 |
+
# element_type = self.general_dequantize(shape['type'],'type',to_float=False)
|
80 |
+
# if target_element_types is not None:
|
81 |
+
# if element_type not in target_element_types:
|
82 |
+
# continue
|
83 |
+
for key_i, key in enumerate(shape.keys()):
|
84 |
+
if key in target_keys:
|
85 |
+
target_tokens.append((shape_i, key_i, key, shape[key]))
|
86 |
+
if not mask_all:
|
87 |
+
target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
|
88 |
+
if len(target_tokens) > target_num_mask_tokens:
|
89 |
+
random.shuffle(target_tokens)
|
90 |
+
target_tokens = target_tokens[:target_num_mask_tokens]
|
91 |
+
# sort by shape_i and key_i
|
92 |
+
target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
|
93 |
+
else:
|
94 |
+
if len(target_tokens) > self.num_mask_tokens:
|
95 |
+
# 取最后面几个
|
96 |
+
target_tokens = target_tokens[-self.num_mask_tokens:]
|
97 |
+
|
98 |
+
tuples = []
|
99 |
+
meta_infos = []
|
100 |
+
for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
|
101 |
+
if self.mask_type == 'cm3':
|
102 |
+
mask_token = self.mask_tokens[mask_i]
|
103 |
+
elif self.mask_type == 'mask_aug':
|
104 |
+
mask_token = self.mask_aug_token
|
105 |
+
else:
|
106 |
+
raise ValueError(f'Invalid mask type {self.mask_type}')
|
107 |
+
# <one-1><decimal0-1><decimal1-2>
|
108 |
+
if '<' in value:
|
109 |
+
num_token = value.count('<')
|
110 |
+
else:
|
111 |
+
num_token = value.count(' ')
|
112 |
+
json_example['layers']['textlayer'][shape_i][key] = mask_token
|
113 |
+
tuples.append((mask_token, value, num_token))
|
114 |
+
meta_infos.append((shape_i,key))
|
115 |
+
if return_meta:
|
116 |
+
return json_example, tuples, meta_infos
|
117 |
+
else:
|
118 |
+
return json_example, tuples
|
119 |
+
|
120 |
+
def make_prediction_postfix(self, tuples):
|
121 |
+
postfix = self.predict_start_token
|
122 |
+
for mask_token, value, num_token in tuples:
|
123 |
+
postfix = postfix+ f'{mask_token}{value}'
|
124 |
+
return postfix
|
125 |
+
|
126 |
+
# specs={
|
127 |
+
# "width":"size",
|
128 |
+
# "height":"size",
|
129 |
+
# "left":"pos",
|
130 |
+
# "top":"pos",
|
131 |
+
# "x":"pos", # center x
|
132 |
+
# "y":"pos", # center y
|
133 |
+
# "opacity":"opacity",
|
134 |
+
# "color":"color",
|
135 |
+
# "angle":"angle",
|
136 |
+
# "font_size":"font_size",
|
137 |
+
# 'ratio':'ratio',
|
138 |
+
# 'letter_spacing': 'spacing',
|
139 |
+
# 'textlen': 'textlen'
|
140 |
+
# }
|
141 |
+
|
142 |
+
specs={
|
143 |
+
"width":"size",
|
144 |
+
"height":"size",
|
145 |
+
"x":"pos", # center x
|
146 |
+
"y":"pos", # center y
|
147 |
+
"color":"color",
|
148 |
+
"font":"font"
|
149 |
+
}
|
150 |
+
|
151 |
+
# TODO change min_max_bins
|
152 |
+
# min_max_bins = {
|
153 |
+
# 'size':(0,2,256),
|
154 |
+
# 'pos':(-1,1,256),
|
155 |
+
# # 'opacity':(0,1,8),
|
156 |
+
# 'opacity':(0,255,8),
|
157 |
+
# 'color':(0,255,32),
|
158 |
+
# 'angle':(0,2*np.pi,64),
|
159 |
+
# 'font_size':(2,200,100),
|
160 |
+
# 'spacing': (0,1,40),
|
161 |
+
# 'textlen': (1,20,20)
|
162 |
+
# }
|
163 |
+
min_max_bins = {
|
164 |
+
'size': (0,1,256),
|
165 |
+
'pos': (0,1,256),
|
166 |
+
'color': (0,137,138),
|
167 |
+
'font': (0,511,512)
|
168 |
+
}
|
169 |
+
|
170 |
+
import numpy as np
|
171 |
+
|
172 |
+
# pre 和 post 分别代表 10 的幂,分别对应大数和小数部分,参数代表位数
|
173 |
+
def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
|
174 |
+
pre_keys = ['one', 'ten', 'hundred', 'thousand']
|
175 |
+
pre_multiplers = [1, 10, 100, 1000]
|
176 |
+
assert pre_decimal <= len(pre_keys)
|
177 |
+
pre_keys = pre_keys[:pre_decimal][::-1]
|
178 |
+
pre_multiplers = pre_multiplers[:pre_decimal][::-1]
|
179 |
+
|
180 |
+
post_keys = [f'decimal{x}' for x in range(post_decimal)]
|
181 |
+
post_multiplers = [10 ** -(x+1) for x in range(post_decimal)]
|
182 |
+
|
183 |
+
keys = pre_keys + post_keys
|
184 |
+
multiplers = pre_multiplers + post_multiplers
|
185 |
+
return keys, multiplers
|
186 |
+
|
187 |
+
class DecimalQuantizer:
|
188 |
+
def __init__(self, max_pre_decimal=3, max_post_decimal=2):
|
189 |
+
self.max_pre_decimal = max_pre_decimal
|
190 |
+
self.max_post_decimal = max_post_decimal
|
191 |
+
self.keys, self.multiplers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
|
192 |
+
self.symbols = {
|
193 |
+
-1: '<symbol-1>',
|
194 |
+
1: '<symbol-0>',
|
195 |
+
}
|
196 |
+
|
197 |
+
def get_vocab(self):
|
198 |
+
special_tokens = [*self.symbols.values()] # ['<symbol-1>', '<symbol-0>']
|
199 |
+
for key in self.keys: # ['one', 'ten', 'hundred', 'thousand'] + ['decimal0', 'decimal1]
|
200 |
+
special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
|
201 |
+
return special_tokens
|
202 |
+
|
203 |
+
def check_valid(self, token):
|
204 |
+
prefix = token.lstrip('<').split('-')[0] # '<symbol-1>' -> 'symbol-1>' -> ['symbol', '1>']
|
205 |
+
if prefix =='symbol' or prefix in self.keys:
|
206 |
+
return True
|
207 |
+
else:
|
208 |
+
return False
|
209 |
+
|
210 |
+
# 小数点后保留两位
|
211 |
+
def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False): # 100.00
|
212 |
+
if pre_decimal is None:
|
213 |
+
pre_decimal = self.max_pre_decimal
|
214 |
+
if post_decimal is None:
|
215 |
+
post_decimal = self.max_post_decimal
|
216 |
+
|
217 |
+
assert pre_decimal <= self.max_pre_decimal
|
218 |
+
assert post_decimal <= self.max_post_decimal
|
219 |
+
|
220 |
+
keys, multiplers = get_keys_and_multipliers(pre_decimal, post_decimal)
|
221 |
+
|
222 |
+
symbol = int(np.sign(val)) # 返回一个浮点数(1.0, -1.0 或 0.0),代表正负和0
|
223 |
+
if symbol == 0: # 两类:>= 0 & < 0
|
224 |
+
symbol = 1
|
225 |
+
val = round(abs(val), post_decimal) # 将 val 的绝对值四舍五入到 post_decimal 位小数
|
226 |
+
|
227 |
+
tokens = []
|
228 |
+
if need_symbol: # self.symbols = {-1: '<symbol-1>', 1: '<symbol-0>',}
|
229 |
+
symbol_type = self.symbols[symbol]
|
230 |
+
tokens.append(symbol_type)
|
231 |
+
else:
|
232 |
+
assert symbol >= 0
|
233 |
+
|
234 |
+
for key, multipler in zip(keys, multiplers):
|
235 |
+
# 用于获取对于给定数值 val,每一位的数字,并且生成为'<one-7>'这样的token
|
236 |
+
v = math.floor(val / multipler)
|
237 |
+
if v > 9:
|
238 |
+
raise ValueError(f'Invalid value {val} for {pre_decimal} pre_decimal and {post_decimal} post_decimal')
|
239 |
+
val = val - v * multipler
|
240 |
+
tokens.append(f'<{key}-{v}>')
|
241 |
+
|
242 |
+
# 对于val,生成每一位数字对应的token,如果need_symbol = True,还会在前面加上 标识 >= 0 和 < 0 的 symbol-1 和 symbol-0
|
243 |
+
return ''.join(tokens)
|
244 |
+
|
245 |
+
def parse_token(self, token):
|
246 |
+
# <hundred-1> -> hundred, 1
|
247 |
+
key, val = token[1:-1].split('-')
|
248 |
+
return key, int(val)
|
249 |
+
|
250 |
+
def decode(self, tokens_str): # 将token_str用 > 先拆开,再添上 > ,然后转化为 list
|
251 |
+
tokens = tokens_str.split('>')
|
252 |
+
tokens = [x+'>' for x in tokens if x != '']
|
253 |
+
if tokens[0].startswith('<symbol'):
|
254 |
+
symbol_type = tokens[0]
|
255 |
+
tokens = tokens[1:]
|
256 |
+
inv_map = {v: k for k, v in self.symbols.items()} # 和 原字典 键、值 对调
|
257 |
+
symbol = inv_map[symbol_type]
|
258 |
+
else:
|
259 |
+
symbol = 1
|
260 |
+
|
261 |
+
accumulater = 0
|
262 |
+
for token in tokens:
|
263 |
+
key, val = self.parse_token(token)
|
264 |
+
multipler_index = self.keys.index(key)
|
265 |
+
multipler = self.multiplers[multipler_index]
|
266 |
+
actual_val = val * multipler
|
267 |
+
# print(key, val, multipler, actual_val)
|
268 |
+
accumulater += actual_val
|
269 |
+
accumulater = accumulater * symbol
|
270 |
+
|
271 |
+
# 还原出原来的整数,带有符号,并且精度 由 pre/post_decimal位数控制
|
272 |
+
return accumulater
|
273 |
+
|
274 |
+
# min_max_bins = {
|
275 |
+
# 'size': (0,1,256),
|
276 |
+
# 'pos': (0,1,256),
|
277 |
+
# 'color': (0,137,138),
|
278 |
+
# 'font': (0,511,512)
|
279 |
+
# }
|
280 |
+
pre_post_decimals={
|
281 |
+
'size': {
|
282 |
+
'pre_decimal': 1,
|
283 |
+
'post_decimal': 2,
|
284 |
+
'need_symbol': False
|
285 |
+
},
|
286 |
+
'pos': {
|
287 |
+
'pre_decimal': 1,
|
288 |
+
'post_decimal': 2,
|
289 |
+
'need_symbol': True
|
290 |
+
},
|
291 |
+
'opacity': {
|
292 |
+
'pre_decimal': 1,
|
293 |
+
'post_decimal': 1,
|
294 |
+
'need_symbol': False
|
295 |
+
},
|
296 |
+
'color':{
|
297 |
+
'pre_decimal': 3,
|
298 |
+
'post_decimal': 0,
|
299 |
+
'need_symbol': False
|
300 |
+
},
|
301 |
+
'angle':{
|
302 |
+
'pre_decimal': 1,
|
303 |
+
'post_decimal': 2,
|
304 |
+
'need_symbol': False
|
305 |
+
},
|
306 |
+
'font_size':{
|
307 |
+
'pre_decimal': 3,
|
308 |
+
'post_decimal': 0,
|
309 |
+
'need_symbol': False
|
310 |
+
},
|
311 |
+
}
|
312 |
+
|
313 |
+
class QuantizerV4(BaseQuantizer):
|
314 |
+
def __init__(self, quant=True,
|
315 |
+
decimal_quantize_types = [],
|
316 |
+
decimal_quantize_kwargs = {'max_pre_decimal':3, 'max_post_decimal':2},
|
317 |
+
mask_values=False,
|
318 |
+
**kwargs):
|
319 |
+
super().__init__(**kwargs)
|
320 |
+
self.min = min
|
321 |
+
self.max = max
|
322 |
+
self.quant = quant
|
323 |
+
self.mask_values = mask_values
|
324 |
+
self.text_split_token = '<split-text>'
|
325 |
+
self.decimal_quantize_types = decimal_quantize_types
|
326 |
+
self.decimal_quantize = len(decimal_quantize_types) > 0
|
327 |
+
if len(decimal_quantize_types) > 0:
|
328 |
+
print('decimal quantize types', decimal_quantize_types)
|
329 |
+
self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
|
330 |
+
else:
|
331 |
+
self.decimal_quantizer = None
|
332 |
+
|
333 |
+
self.set_min_max_bins(min_max_bins)
|
334 |
+
# min_max_bins = {
|
335 |
+
# 'size': (0,1,256),
|
336 |
+
# 'pos': (0,1,256),
|
337 |
+
# 'color': (0,137,138),
|
338 |
+
# 'font': (0,511,512)
|
339 |
+
# }
|
340 |
+
self.width = kwargs.get('width', 1456)
|
341 |
+
self.height = kwargs.get('height', 1457)
|
342 |
+
self.width = int(self.width)
|
343 |
+
self.height = int(self.height)
|
344 |
+
|
345 |
+
def set_min_max_bins(self, min_max_bins): # 检查 n_bins是否是偶数,然后将其 +1
|
346 |
+
min_max_bins = copy.deepcopy(min_max_bins)
|
347 |
+
# adjust the bins to plus one
|
348 |
+
for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
|
349 |
+
assert n_bins % 2 == 0 # must be even
|
350 |
+
min_max_bins[type_name] = (min_val, max_val, n_bins+1)
|
351 |
+
self.min_max_bins = min_max_bins
|
352 |
+
|
353 |
+
def setup_tokenizer(self, tokenizer):
|
354 |
+
# 整个函数生成additional_special_tokens:1. '<split-text>' 2.<one-1> <symbol-1> : decimal quantizer 3. <size-255> quantizerV4 4.self.get_additional_mask_tokens()
|
355 |
+
# 然后tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
|
356 |
+
additional_special_tokens = [self.text_split_token] # self.text_split_token = '<split-text>'
|
357 |
+
if self.decimal_quantize:
|
358 |
+
special_tokens = self.decimal_quantizer.get_vocab() # <one-1> <symbol-1>
|
359 |
+
self.io_ignore_replace_tokens += special_tokens # self.io_ignore_replace_tokens = ['<split-text>'] 在BaseQuantizer中声明
|
360 |
+
additional_special_tokens += special_tokens
|
361 |
+
# the order must be preserved, other wise the tokenizer will be wrong
|
362 |
+
rest_types = [key for key in self.min_max_bins.keys() if key not in self.decimal_quantize_types]
|
363 |
+
for type_name in rest_types:
|
364 |
+
min_val, max_val, n_bins = self.min_max_bins[type_name]
|
365 |
+
additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)] # <size-256>
|
366 |
+
|
367 |
+
if self.num_mask_tokens > 0:
|
368 |
+
additional_special_tokens.extend(self.get_additional_mask_tokens())
|
369 |
+
|
370 |
+
print('additional_special_tokens', additional_special_tokens)
|
371 |
+
|
372 |
+
tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
|
373 |
+
self.additional_special_tokens = set(additional_special_tokens)
|
374 |
+
return tokenizer
|
375 |
+
|
376 |
+
@lru_cache(maxsize=128) # 缓存函数的返回值,以提高性能。maxsize=128 表示缓存最多存储 128 个不同的输入结果
|
377 |
+
def get_bins(self, real_type): # real_type: size, pos, font, color
|
378 |
+
# 返回 最小值,最大值,等距数组
|
379 |
+
min_val, max_val, n_bins = self.min_max_bins[real_type]
|
380 |
+
return min_val, max_val, np.linspace(min_val, max_val, n_bins)
|
381 |
+
|
382 |
+
def quantize(self, x, type): # (0.25, 'y') -> (<size-50>)
|
383 |
+
if not self.quant:
|
384 |
+
return x
|
385 |
+
"""Quantize a float array x into n_bins discrete values."""
|
386 |
+
real_type = specs[type] # x, y, width, height, color, font -> size, pos, font, color
|
387 |
+
min_val, max_val, bins = self.get_bins(real_type)
|
388 |
+
x = np.clip(float(x), min_val, max_val) # 确保 x 的值在 [min_val, max_val] 范围内,否则截断
|
389 |
+
if self.decimal_quantize and real_type in self.decimal_quantize_types:
|
390 |
+
return self.decimal_quantizer(x, **pre_post_decimals[real_type])
|
391 |
+
val = np.digitize(x, bins) - 1 # val是一个整数,取值范围在[0, len(bins)],换句话说就是bins数组的索引
|
392 |
+
n_bins = len(bins)
|
393 |
+
assert val >= 0 and val < n_bins
|
394 |
+
return f'<{real_type}-{val}>' # <size-255>
|
395 |
+
|
396 |
+
def dequantize(self, x): # (<size-255> -> 0.99?)
|
397 |
+
# <pos-1>->1
|
398 |
+
val = x.split('-')[1].strip('>')
|
399 |
+
# <pos-1>->pos
|
400 |
+
real_type = x.split('-')[0][1:]
|
401 |
+
if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
|
402 |
+
return self.decimal_quantizer.decode(x)
|
403 |
+
min_val, max_val, bins = self.get_bins(real_type)
|
404 |
+
return bins[int(val)]
|
405 |
+
|
406 |
+
def construct_map_dict(self):
|
407 |
+
map_dict = {}
|
408 |
+
for i in range(self.min_max_bins['size'][2]): # 'size': (0, 1, 256),
|
409 |
+
name = "<size-%d>" % i
|
410 |
+
value = self.dequantize(name)
|
411 |
+
map_dict[name] = str(value) # 255 -> 0.99?
|
412 |
+
for i in range(self.min_max_bins['pos'][2]):
|
413 |
+
name = "<pos-%d>" % i
|
414 |
+
value = self.dequantize(name)
|
415 |
+
map_dict[name] = str(value)
|
416 |
+
return map_dict
|
417 |
+
|
418 |
+
def postprocess_colorandfont(self, json_example):
|
419 |
+
# 将其中的 正则 匹配部分 用双引号包裹
|
420 |
+
import re
|
421 |
+
json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
|
422 |
+
json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
|
423 |
+
return json_example
|
424 |
+
|
425 |
+
def to_str(self, x, type):
|
426 |
+
feature = self.get_feature(type)
|
427 |
+
return feature.int2str(x)
|
428 |
+
|
429 |
+
def convert2layout(self, example): # 将原始的数据转化为 <size-255> 的 token形式
|
430 |
+
new_example = OrderedDict()
|
431 |
+
new_example['wholecaption'] = example['wholecaption']
|
432 |
+
new_layout = []
|
433 |
+
for meta_layer in example['layout']:
|
434 |
+
new_layout.append({
|
435 |
+
"layer": meta_layer["layer"],
|
436 |
+
"x": self.quantize(meta_layer["x"]/self.width, 'x'),
|
437 |
+
"y": self.quantize(meta_layer["y"]/self.height, 'y'),
|
438 |
+
"width": self.quantize(meta_layer["width"]/self.width, 'width'),
|
439 |
+
"height": self.quantize(meta_layer["height"]/self.height, 'height')
|
440 |
+
})
|
441 |
+
new_example['layout'] = new_layout
|
442 |
+
return new_example
|
443 |
+
|
444 |
+
def apply_masking(self,
|
445 |
+
json_example,
|
446 |
+
mask_all=None,
|
447 |
+
return_meta=False,
|
448 |
+
# target_keys=['width', 'height', 'left', 'top'], # useless
|
449 |
+
# target_element_types=None, # useless
|
450 |
+
mask_values = True
|
451 |
+
):
|
452 |
+
if mask_all is None:
|
453 |
+
mask_all = self.mask_all
|
454 |
+
|
455 |
+
json_example = copy.deepcopy(json_example)
|
456 |
+
|
457 |
+
# 这段内容对json中的一些 value 替换为 <mask-i>,并用self.num_mask_tokens限制mask的数量,根据参数还可能进行随机mask
|
458 |
+
# 并记录 <mask-i> & value & num_token = value.count('<') 的 三元tuple
|
459 |
+
target_tokens = []
|
460 |
+
if self.mask_values and mask_values:
|
461 |
+
target_tokens.append((-1,-1,'globalcaption', json_example['globalcaption']))
|
462 |
+
target_tokens.append((-1,-1,'canvas_width', json_example['canvas_width']))
|
463 |
+
target_tokens.append((-1,-1,'canvas_height', json_example['canvas_height']))
|
464 |
+
target_tokens.append((-1,-1,'category', json_example['category']))
|
465 |
+
target_tokens.append((-1,-1,'keywords', json_example['keywords']))
|
466 |
+
target_tokens.append((-1,-1,'bgcaption', json_example['layers']['bglayer']['bgcaption']))
|
467 |
+
target_tokens.append((-1,-1,'flag', json_example['layers']['objlayer']['flag']))
|
468 |
+
target_tokens.append((-1,-1,'objcaption', json_example['layers']['objlayer']['objcaption']))
|
469 |
+
for layer_i, textlayer in enumerate(json_example['layers']['textlayer']):
|
470 |
+
target_tokens.append((layer_i, -1, 'text', json_example['layers']['textlayer'][textlayer]))
|
471 |
+
if not mask_all: # 随机取值 target_num_mask_tokens, 上界是self.num_mask_tokens
|
472 |
+
target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
|
473 |
+
if len(target_tokens) > target_num_mask_tokens:
|
474 |
+
random.shuffle(target_tokens)
|
475 |
+
target_tokens = target_tokens[:target_num_mask_tokens]
|
476 |
+
# sort by shape_i and key_i
|
477 |
+
target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
|
478 |
+
else: # 取定值 num_mask_tokens
|
479 |
+
if len(target_tokens) > self.num_mask_tokens:
|
480 |
+
# 取最后面几个
|
481 |
+
target_tokens = target_tokens[-self.num_mask_tokens:]
|
482 |
+
|
483 |
+
tuples = []
|
484 |
+
meta_infos = []
|
485 |
+
layer_list = ['heading', 'subheading', 'body']
|
486 |
+
for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
|
487 |
+
if self.mask_type == 'cm3':
|
488 |
+
mask_token = self.mask_tokens[mask_i]
|
489 |
+
elif self.mask_type == 'mask_aug':
|
490 |
+
mask_token = self.mask_aug_token
|
491 |
+
else:
|
492 |
+
raise ValueError(f'Invalid mask type {self.mask_type}')
|
493 |
+
# <one-1><decimal0-1><decimal1-2>
|
494 |
+
if '<' in value:
|
495 |
+
num_token = value.count('<')
|
496 |
+
else:
|
497 |
+
num_token = value.count(' ') + 1
|
498 |
+
if shape_i == -1:
|
499 |
+
if key in ['bgcaption']:
|
500 |
+
json_example['layers']['bglayer']['bgcaption'] = mask_token
|
501 |
+
elif key in ['objcaption']:
|
502 |
+
json_example['layers']['objlayer']['objcaption'] = mask_token
|
503 |
+
elif key in ['flag']:
|
504 |
+
json_example['layers']['objlayer']['flag'] = mask_token
|
505 |
+
else:
|
506 |
+
json_example[key] = mask_token
|
507 |
+
else:
|
508 |
+
curlayer = layer_list[shape_i]
|
509 |
+
json_example['layers']['textlayer'][curlayer] = mask_token
|
510 |
+
tuples.append((mask_token, value, num_token))
|
511 |
+
meta_infos.append((shape_i,key))
|
512 |
+
if return_meta:
|
513 |
+
return json_example, tuples, meta_infos
|
514 |
+
else:
|
515 |
+
return json_example, tuples
|
516 |
+
|
517 |
+
|
518 |
+
# useless orginally used for render
|
519 |
+
def is_font_exists(font_name):
|
520 |
+
font_list = font_manager.findSystemFonts()
|
521 |
+
# print("\nfont_list: ",font_list)
|
522 |
+
for font in font_list:
|
523 |
+
if font_name.lower() in font.lower():
|
524 |
+
return True
|
525 |
+
return False
|
526 |
+
|
527 |
+
def print_info(msg):
|
528 |
+
print(Fore.GREEN + "[INFO] " + msg)
|
529 |
+
|
530 |
+
def print_warning(msg):
|
531 |
+
print(Fore.YELLOW + "[WARNING] " + msg)
|
532 |
+
|
533 |
+
def print_error(msg):
|
534 |
+
print(Fore.RED + "[ERROR] " + msg)
|
535 |
+
|
536 |
+
def load_feature(path):
|
537 |
+
with open(path) as f:
|
538 |
+
content = f.read()
|
539 |
+
content = json.loads(content)
|
540 |
+
names = [content[str(i)] for i in range(len(content))]
|
541 |
+
return ClassLabel(num_classes= len(names), names=names)
|
542 |
+
|
543 |
+
def get_quantizer(version='v1', update_vocab=False, **kwargs):
|
544 |
+
""" if kwargs.pop('separate_alpha', False): # useless
|
545 |
+
kwargs['n_visual_tokens'] *= 2 """
|
546 |
+
if version == 'v4':
|
547 |
+
quantizer = QuantizerV4(**kwargs)
|
548 |
+
else:
|
549 |
+
raise NotImplementedError
|
550 |
+
|
551 |
+
return quantizer
|
552 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 核心框架
|
2 |
+
torch==2.4.0 # 保持与conda环境一致(pypi安装)
|
3 |
+
torchvision==0.19.0 # 保持与conda环境一致(pypi安装)
|
4 |
+
|
5 |
+
# Hugging Face 生态
|
6 |
+
transformers==4.44.0 # 原4.39.1 → conda实际安装4.44.0
|
7 |
+
diffusers==0.31.0 # 保持与conda环境一致
|
8 |
+
accelerate==0.34.2 # 原0.27.2 → conda实际安装0.34.2
|
9 |
+
peft==0.12.0 # 原git提交 → conda实际安装0.12.0(pypi)
|
10 |
+
datasets==2.20.0 # 保持与conda环境一致
|
11 |
+
|
12 |
+
# 工具链
|
13 |
+
deepspeed==0.15.4 # 原0.14.2 → conda实际安装0.15.4
|
14 |
+
# bitsandbytes==0.44.1 # 原0.43.0 → conda实际安装0.44.1
|
15 |
+
protobuf==3.20.0 # 原3.20.3 → conda实际安装3.20.0(需验证tensorboard兼容性)
|
16 |
+
tensorboard==2.18.0 # 新增明确版本(conda实际安装2.18.0)
|
17 |
+
tensorboardx==2.6.2.2 # 新增明确版本(conda实际安装2.6.2.2)
|
18 |
+
webdataset==0.2.100 # 新增明确版本(conda实际安装0.2.100)
|
19 |
+
|
20 |
+
# 训练辅助
|
21 |
+
warmup_scheduler==0.3 # 新增明确版本(conda实际安装0.3)
|
22 |
+
torchmetrics==1.6.0 # 新增明确版本(conda实际安装1.6.0)
|
23 |
+
open_clip_torch==2.29.0 # 新增明确版本(conda实际安装2.29.0)
|
24 |
+
evaluate==0.4.3 # 新增明确版本(conda实际安装0.4.3)
|
25 |
+
bert_score==0.3.13 # 新增明确版本(conda实际安装0.3.13)
|
26 |
+
einops==0.8.0 # 保持与conda环境一致
|
27 |
+
wandb==0.17.7 # 保持与conda环境一致
|
28 |
+
|
29 |
+
# 图像处理
|
30 |
+
matplotlib==3.9.2 # 新增明确版本(conda实际安装3.9.2)
|
31 |
+
opencv-python==4.10.0.84 # 新增明确版本(conda实际安装4.10.0.84)
|
32 |
+
clean-fid==0.1.35 # 新增明确版本(conda实际安装0.1.35)
|
33 |
+
skia-python==87.6 # 新增明确版本(conda实际安装87.6)
|
34 |
+
|
35 |
+
# 部署与接口
|
36 |
+
# gradio==5.5.0 # 新增明确版本(conda实际安装5.5.0)
|
37 |
+
langchain>=0.0.139 # 保持约束(conda实际安装0.3.7符合要求)
|
38 |
+
tiktoken==0.8.0 # 新增明确版本(conda实际安装0.8.0)
|
39 |
+
|
40 |
+
# 系统工具
|
41 |
+
ninja==1.11.1.1 # 新增明确版本(conda实际安装1.11.1.1)
|
42 |
+
pynvml==11.5.3 # 新增明确版本(conda实际安装11.5.3)
|
43 |
+
colorama==0.4.6 # 新增明确版本(conda实际安装0.4.6)
|
44 |
+
click>=8.0.4,<9 # 保持约束(conda实际安装8.1.7符合要求)\
|
45 |
+
|
46 |
+
sentencepiece
|