Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
1d5bb62
0
Parent(s):
init space
Browse files- .gitattributes +1 -0
- app.py +231 -0
- examples/birdhouse.glb +3 -0
- examples/mario.glb +3 -0
- utils/__init__.py +0 -0
- utils/controlnet_union.py +957 -0
- utils/image_generation.py +299 -0
- utils/mesh_utils.py +500 -0
- utils/pipeline_controlnet_union_sd_xl.py +1397 -0
- utils/pipeline_stable_diffusion_switcher.py +1240 -0
- utils/rasterize.py +166 -0
- utils/render_utils.py +352 -0
- utils/texture_generation.py +309 -0
- wan/__init__.py +0 -0
- wan/pipeline_wan_t2tex_extra.py +366 -0
- wan/wan_t2tex_transformer_3d_extra.py +634 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.glb filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from einops import rearrange
|
4 |
+
from PIL import Image
|
5 |
+
from utils.image_generation import generate_image_condition
|
6 |
+
from utils.mesh_utils import Mesh
|
7 |
+
from utils.render_utils import render_views
|
8 |
+
from utils.texture_generation import generate_texture
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
from gradio_litmodel3d import LitModel3D
|
12 |
+
|
13 |
+
EXAMPLES = [
|
14 |
+
["examples/birdhouse.glb", True, False, False, False, 42, "First View", "SDXL", False, "A rustic birdhouse featuring a snow-covered roof, wood textures, and two decorative cardinal birds. It has a circular entryway and conveys a winter-themed aesthetic."],
|
15 |
+
["examples/mario.glb", False, False, False, True, 6666, "Third View", "FLUX", True, "Mario, a cartoon character wearing a red cap and blue overalls, with brown hair and a mustache, and white gloves, in a fighting pose. The clothes he wears are not in a reflection mode."],
|
16 |
+
]
|
17 |
+
|
18 |
+
def tensor_to_pil(tensor, mask=None, normalize: bool = True):
|
19 |
+
"""
|
20 |
+
Convert tensor to PIL Image.
|
21 |
+
:param tensor: torch.Tensor, shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
|
22 |
+
:param mask: torch.Tensor, shape same as tensor, effective when C=3
|
23 |
+
:return: PIL.Image
|
24 |
+
"""
|
25 |
+
# Move to cpu
|
26 |
+
tensor = tensor.detach()
|
27 |
+
if tensor.is_cuda:
|
28 |
+
tensor = tensor.cpu()
|
29 |
+
if mask is not None and mask.is_cuda:
|
30 |
+
mask = mask.cpu()
|
31 |
+
|
32 |
+
# Convert to float32
|
33 |
+
tensor = tensor.float()
|
34 |
+
if mask is not None:
|
35 |
+
mask = mask.float()
|
36 |
+
|
37 |
+
if normalize:
|
38 |
+
tensor = (tensor + 1.0) / 2.0
|
39 |
+
tensor = torch.clamp(tensor, 0.0, 1.0)
|
40 |
+
if mask is not None:
|
41 |
+
if mask.shape[-1] not in [1, 3]:
|
42 |
+
mask = mask.unsqueeze(-1)
|
43 |
+
tensor = torch.cat([tensor, mask], dim=-1)
|
44 |
+
|
45 |
+
shape = tensor.shape
|
46 |
+
# 4D: (Nv, H, W, C) or (Nv, C, H, W)
|
47 |
+
if len(shape) == 4:
|
48 |
+
Nv = shape[0]
|
49 |
+
if shape[-1] in [3, 4]: # (Nv, H, W, C)
|
50 |
+
tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
|
51 |
+
else: # (Nv, C, H, W)
|
52 |
+
tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
|
53 |
+
# 3D: (H, W, C) or (C, H, W)
|
54 |
+
elif len(shape) == 3:
|
55 |
+
if shape[-1] in [3, 4]: # (H, W, C)
|
56 |
+
tensor = rearrange(tensor, 'h w c -> h w c')
|
57 |
+
else: # (C, H, W)
|
58 |
+
tensor = rearrange(tensor, 'c h w -> h w c')
|
59 |
+
else:
|
60 |
+
raise ValueError(f"Unsupported tensor shape: {shape}")
|
61 |
+
|
62 |
+
# Convert to numpy
|
63 |
+
np_img = (tensor.numpy() * 255).round().astype(np.uint8)
|
64 |
+
|
65 |
+
# Create PIL Image
|
66 |
+
if np_img.shape[2] == 3:
|
67 |
+
return Image.fromarray(np_img, mode="RGB")
|
68 |
+
elif np_img.shape[2] == 4:
|
69 |
+
return Image.fromarray(np_img, mode="RGBA")
|
70 |
+
else:
|
71 |
+
raise ValueError("Only support 3 or 4 channel images.")
|
72 |
+
|
73 |
+
if __name__ == '__main__':
|
74 |
+
with gr.Blocks() as demo:
|
75 |
+
gr.Markdown("# 🎨 SeqTex: Generate Mesh Textures in Video Sequence")
|
76 |
+
|
77 |
+
gr.Markdown("""
|
78 |
+
## 🚀 Welcome to SeqTex!
|
79 |
+
**SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes using image prompts (here we use image generator to get them from textual prompts).
|
80 |
+
|
81 |
+
Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
|
82 |
+
""")
|
83 |
+
|
84 |
+
gr.Markdown("---")
|
85 |
+
|
86 |
+
gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
|
87 |
+
gr.Markdown("""
|
88 |
+
**📋 How to prepare your 3D mesh:**
|
89 |
+
- Upload your 3D mesh in **.obj** or **.glb** format
|
90 |
+
- **💡 Pro Tip**:
|
91 |
+
- For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
|
92 |
+
- Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
|
93 |
+
- **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
|
94 |
+
""")
|
95 |
+
position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix = gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State()
|
96 |
+
|
97 |
+
# fixed_texture_map = Image.open("image.webp").convert("RGB")
|
98 |
+
# Step 1
|
99 |
+
with gr.Row():
|
100 |
+
with gr.Column():
|
101 |
+
mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
|
102 |
+
# uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
|
103 |
+
|
104 |
+
gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
|
105 |
+
y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
|
106 |
+
y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
|
107 |
+
z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
|
108 |
+
upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
|
109 |
+
|
110 |
+
with gr.Column():
|
111 |
+
step1_button = gr.Button("🔄 Process Mesh & Generate Views", variant="primary")
|
112 |
+
step1_progress = gr.Textbox(label="📊 Processing Status", interactive=False)
|
113 |
+
model_input = gr.Model3D(label="📐 Processed 3D Model", height=500)
|
114 |
+
|
115 |
+
with gr.Row(equal_height=True):
|
116 |
+
rgb_views = gr.Image(label="📷 Generated Views (Front, Back, Left, Right)", type="pil", scale=3)
|
117 |
+
position_map = gr.Image(label="🗺️ Position Map", type="pil", scale=1)
|
118 |
+
normal_map = gr.Image(label="🧭 Normal Map", type="pil", scale=1)
|
119 |
+
|
120 |
+
step1_button.click(
|
121 |
+
Mesh.process,
|
122 |
+
inputs=[mesh_upload, gr.State("xAtlas"), y2z, y2x, z2x, upside_down],
|
123 |
+
outputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix, step1_progress]
|
124 |
+
).then(
|
125 |
+
tensor_to_pil,
|
126 |
+
inputs=[normal_images_tensor, mask_images_tensor],
|
127 |
+
outputs=[rgb_views]
|
128 |
+
).then(
|
129 |
+
tensor_to_pil,
|
130 |
+
inputs=[position_map_tensor],
|
131 |
+
outputs=[position_map]
|
132 |
+
).then(
|
133 |
+
tensor_to_pil,
|
134 |
+
inputs=[normal_map_tensor],
|
135 |
+
outputs=[normal_map]
|
136 |
+
).then(
|
137 |
+
Mesh.export,
|
138 |
+
inputs=[mesh],
|
139 |
+
outputs=[model_input]
|
140 |
+
)
|
141 |
+
|
142 |
+
# Step 2
|
143 |
+
gr.Markdown("---")
|
144 |
+
gr.Markdown("## 👁️ Step 2: Select View & Generate Image Condition")
|
145 |
+
gr.Markdown("""
|
146 |
+
**📋 How to generate image condition:**
|
147 |
+
- Your mesh will be rendered from **four viewpoints** (front, back, left, right)
|
148 |
+
- Choose **one view** as your image condition
|
149 |
+
- Enter a **descriptive text prompt** for the desired texture
|
150 |
+
- Select your preferred AI model:
|
151 |
+
- <span style="color:#27ae60; font-weight:bold;">🎯 SDXL</span>: Fast generation with depth + normal control, better details
|
152 |
+
- <span style="color:#3498db; font-weight:bold;">⚡ FLUX</span>: High-quality generation with depth control (slower due to CPU offloading). Better work with **Edge Refinement**
|
153 |
+
""")
|
154 |
+
with gr.Row():
|
155 |
+
with gr.Column():
|
156 |
+
img_condition_seed = gr.Number(label="🎲 Random Seed", minimum=0, maximum=9999, step=1, value=42, info="Change for different results")
|
157 |
+
selected_view = gr.Radio(["First View", "Second View", "Third View", "Fourth View"], label="📐 Camera View", value="First View", info="Choose which viewpoint to use as reference")
|
158 |
+
with gr.Row():
|
159 |
+
model_choice = gr.Radio(["SDXL", "FLUX"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing")
|
160 |
+
edge_refinement = gr.Checkbox(label="✨ Edge Refinement", value=True, info="Smooth boundary artifacts (recommended for cleaner results)")
|
161 |
+
text_prompt = gr.Textbox(label="💬 Texture Description", placeholder="Describe the desired texture appearance (e.g., 'rustic wooden surface with weathered paint')", lines=2)
|
162 |
+
step2_button = gr.Button("🎯 Generate Image Condition", variant="primary")
|
163 |
+
step2_progress = gr.Textbox(label="📊 Generation Status", interactive=False)
|
164 |
+
|
165 |
+
with gr.Column():
|
166 |
+
condition_image = gr.Image(label="🖼️ Generated Image Condition", type="pil") # , interactive=False
|
167 |
+
|
168 |
+
step2_button.click(
|
169 |
+
generate_image_condition,
|
170 |
+
inputs=[position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, text_prompt, selected_view, img_condition_seed, model_choice, edge_refinement],
|
171 |
+
outputs=[condition_image, step2_progress],
|
172 |
+
concurrency_id="gpu_intensive"
|
173 |
+
)
|
174 |
+
|
175 |
+
# Step 3
|
176 |
+
gr.Markdown("---")
|
177 |
+
gr.Markdown("## 🎨 Step 3: Generate Final Texture")
|
178 |
+
gr.Markdown("""
|
179 |
+
**📋 How to generate final texture:**
|
180 |
+
- The **SeqTex pipeline** will create a complete texture map for your model
|
181 |
+
- View the results from multiple angles and download your textured 3D model (the viewport is a little bit dark)
|
182 |
+
""")
|
183 |
+
texture_map_tensor, mv_out_tensor = gr.State(), gr.State()
|
184 |
+
with gr.Row():
|
185 |
+
with gr.Column(scale=1):
|
186 |
+
step3_button = gr.Button("🎨 Generate Final Texture", variant="primary")
|
187 |
+
step3_progress = gr.Textbox(label="📊 Texture Generation Status", interactive=False)
|
188 |
+
texture_map = gr.Image(label="🏆 Generated Texture Map", interactive=False)
|
189 |
+
with gr.Column(scale=2):
|
190 |
+
rendered_imgs = gr.Image(label="🖼️ Final Rendered Views")
|
191 |
+
mv_branch_imgs = gr.Image(label="🖼️ SeqTex Direct Output")
|
192 |
+
with gr.Column(scale=1.5):
|
193 |
+
# model_display = gr.Model3D(label="🏆 Final Textured Model", height=500)
|
194 |
+
model_display = LitModel3D(label="Model with Texture",
|
195 |
+
exposure=30.0,
|
196 |
+
height=500)
|
197 |
+
|
198 |
+
step3_button.click(
|
199 |
+
generate_texture,
|
200 |
+
inputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, condition_image, text_prompt, selected_view],
|
201 |
+
outputs=[texture_map_tensor, mv_out_tensor, step3_progress],
|
202 |
+
concurrency_id="gpu_intensive"
|
203 |
+
).then(
|
204 |
+
tensor_to_pil,
|
205 |
+
inputs=[texture_map_tensor, gr.State(None), gr.State(False)],
|
206 |
+
outputs=[texture_map]
|
207 |
+
).then(
|
208 |
+
tensor_to_pil,
|
209 |
+
inputs=[mv_out_tensor, gr.State(None), gr.State(False)],
|
210 |
+
outputs=[mv_branch_imgs]
|
211 |
+
).then(
|
212 |
+
render_views,
|
213 |
+
inputs=[mesh, texture_map_tensor, mvp_matrix],
|
214 |
+
outputs=[rendered_imgs]
|
215 |
+
).then(
|
216 |
+
Mesh.export,
|
217 |
+
inputs=[mesh, gr.State(None), texture_map],
|
218 |
+
outputs=[model_display]
|
219 |
+
)
|
220 |
+
|
221 |
+
# Add example inputs for user convenience
|
222 |
+
gr.Markdown("---")
|
223 |
+
gr.Markdown("## 🚀 Try Our Examples")
|
224 |
+
gr.Markdown("**Quick Start**: Click on any example below to see SeqTex in action with pre-configured settings!")
|
225 |
+
gr.Examples(
|
226 |
+
examples=EXAMPLES,
|
227 |
+
inputs=[mesh_upload, y2z, y2x, z2x, upside_down, img_condition_seed, selected_view, model_choice, edge_refinement, text_prompt],
|
228 |
+
cache_examples=False
|
229 |
+
)
|
230 |
+
|
231 |
+
demo.launch(server_name="0.0.0.0", server_port=52424)
|
examples/birdhouse.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:30a006774b35531831aaf4ba0dd1c7b8a5b5b58433af17ebc52c816cfbd654b9
|
3 |
+
size 10043504
|
examples/mario.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cbe06e0ad2fc52811ba343dcaeccacb0b9cee1705b6f33bcd222d20de770b80c
|
3 |
+
size 1970408
|
utils/__init__.py
ADDED
File without changes
|
utils/controlnet_union.py
ADDED
@@ -0,0 +1,957 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.attention_processor import (
|
25 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
26 |
+
CROSS_ATTENTION_PROCESSORS,
|
27 |
+
AttentionProcessor,
|
28 |
+
AttnAddedKVProcessor,
|
29 |
+
AttnProcessor,
|
30 |
+
)
|
31 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
33 |
+
from diffusers.models.unets.unet_2d_blocks import (
|
34 |
+
CrossAttnDownBlock2D,
|
35 |
+
DownBlock2D,
|
36 |
+
UNetMidBlock2DCrossAttn,
|
37 |
+
get_down_block,
|
38 |
+
)
|
39 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
from collections import OrderedDict
|
46 |
+
|
47 |
+
# Transformer Block
|
48 |
+
# Used to exchange info between different conditions and input image
|
49 |
+
# With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
|
50 |
+
class QuickGELU(nn.Module):
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor):
|
53 |
+
return x * torch.sigmoid(1.702 * x)
|
54 |
+
|
55 |
+
class LayerNorm(nn.LayerNorm):
|
56 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
57 |
+
|
58 |
+
def forward(self, x: torch.Tensor):
|
59 |
+
orig_type = x.dtype
|
60 |
+
ret = super().forward(x)
|
61 |
+
return ret.type(orig_type)
|
62 |
+
|
63 |
+
class ResidualAttentionBlock(nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
69 |
+
self.ln_1 = LayerNorm(d_model)
|
70 |
+
self.mlp = nn.Sequential(
|
71 |
+
OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
|
72 |
+
("c_proj", nn.Linear(d_model * 4, d_model))]))
|
73 |
+
self.ln_2 = LayerNorm(d_model)
|
74 |
+
self.attn_mask = attn_mask
|
75 |
+
|
76 |
+
def attention(self, x: torch.Tensor):
|
77 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
78 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
79 |
+
|
80 |
+
def forward(self, x: torch.Tensor):
|
81 |
+
x = x + self.attention(self.ln_1(x))
|
82 |
+
x = x + self.mlp(self.ln_2(x))
|
83 |
+
return x
|
84 |
+
#-----------------------------------------------------------------------------------------------------
|
85 |
+
|
86 |
+
@dataclass
|
87 |
+
class ControlNetOutput(BaseOutput):
|
88 |
+
"""
|
89 |
+
The output of [`ControlNetModel`].
|
90 |
+
|
91 |
+
Args:
|
92 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
93 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
94 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
95 |
+
used to condition the original UNet's downsampling activations.
|
96 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
97 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
98 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
99 |
+
Output can be used to condition the original UNet's middle block activation.
|
100 |
+
"""
|
101 |
+
|
102 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
103 |
+
mid_block_res_sample: torch.Tensor
|
104 |
+
|
105 |
+
|
106 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
107 |
+
"""
|
108 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
109 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
110 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
111 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
112 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
113 |
+
model) to encode image-space conditions ... into feature maps ..."
|
114 |
+
"""
|
115 |
+
|
116 |
+
# original setting is (16, 32, 96, 256)
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
conditioning_embedding_channels: int,
|
120 |
+
conditioning_channels: int = 3,
|
121 |
+
block_out_channels: Tuple[int] = (48, 96, 192, 384),
|
122 |
+
):
|
123 |
+
super().__init__()
|
124 |
+
|
125 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
126 |
+
|
127 |
+
self.blocks = nn.ModuleList([])
|
128 |
+
|
129 |
+
for i in range(len(block_out_channels) - 1):
|
130 |
+
channel_in = block_out_channels[i]
|
131 |
+
channel_out = block_out_channels[i + 1]
|
132 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
133 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
134 |
+
|
135 |
+
self.conv_out = zero_module(
|
136 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
137 |
+
)
|
138 |
+
|
139 |
+
def forward(self, conditioning):
|
140 |
+
embedding = self.conv_in(conditioning)
|
141 |
+
embedding = F.silu(embedding)
|
142 |
+
|
143 |
+
for block in self.blocks:
|
144 |
+
embedding = block(embedding)
|
145 |
+
embedding = F.silu(embedding)
|
146 |
+
|
147 |
+
embedding = self.conv_out(embedding)
|
148 |
+
|
149 |
+
return embedding
|
150 |
+
|
151 |
+
|
152 |
+
class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
153 |
+
"""
|
154 |
+
A ControlNet model.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
in_channels (`int`, defaults to 4):
|
158 |
+
The number of channels in the input sample.
|
159 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
160 |
+
Whether to flip the sin to cos in the time embedding.
|
161 |
+
freq_shift (`int`, defaults to 0):
|
162 |
+
The frequency shift to apply to the time embedding.
|
163 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
164 |
+
The tuple of downsample blocks to use.
|
165 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
166 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
167 |
+
The tuple of output channels for each block.
|
168 |
+
layers_per_block (`int`, defaults to 2):
|
169 |
+
The number of layers per block.
|
170 |
+
downsample_padding (`int`, defaults to 1):
|
171 |
+
The padding to use for the downsampling convolution.
|
172 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
173 |
+
The scale factor to use for the mid block.
|
174 |
+
act_fn (`str`, defaults to "silu"):
|
175 |
+
The activation function to use.
|
176 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
177 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
178 |
+
in post-processing.
|
179 |
+
norm_eps (`float`, defaults to 1e-5):
|
180 |
+
The epsilon to use for the normalization.
|
181 |
+
cross_attention_dim (`int`, defaults to 1280):
|
182 |
+
The dimension of the cross attention features.
|
183 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
184 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
185 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
186 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
187 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
188 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
189 |
+
dimension to `cross_attention_dim`.
|
190 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
191 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
192 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
193 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
194 |
+
The dimension of the attention heads.
|
195 |
+
use_linear_projection (`bool`, defaults to `False`):
|
196 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
197 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
198 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
199 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
200 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
201 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
202 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
203 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
204 |
+
class conditioning with `class_embed_type` equal to `None`.
|
205 |
+
upcast_attention (`bool`, defaults to `False`):
|
206 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
207 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
208 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
209 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
210 |
+
`class_embed_type="projection"`.
|
211 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
212 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
213 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
214 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
215 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
216 |
+
"""
|
217 |
+
|
218 |
+
_supports_gradient_checkpointing = True
|
219 |
+
|
220 |
+
@register_to_config
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
in_channels: int = 4,
|
224 |
+
conditioning_channels: int = 3,
|
225 |
+
flip_sin_to_cos: bool = True,
|
226 |
+
freq_shift: int = 0,
|
227 |
+
down_block_types: Tuple[str] = (
|
228 |
+
"CrossAttnDownBlock2D",
|
229 |
+
"CrossAttnDownBlock2D",
|
230 |
+
"CrossAttnDownBlock2D",
|
231 |
+
"DownBlock2D",
|
232 |
+
),
|
233 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
234 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
235 |
+
layers_per_block: int = 2,
|
236 |
+
downsample_padding: int = 1,
|
237 |
+
mid_block_scale_factor: float = 1,
|
238 |
+
act_fn: str = "silu",
|
239 |
+
norm_num_groups: Optional[int] = 32,
|
240 |
+
norm_eps: float = 1e-5,
|
241 |
+
cross_attention_dim: int = 1280,
|
242 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
243 |
+
encoder_hid_dim: Optional[int] = None,
|
244 |
+
encoder_hid_dim_type: Optional[str] = None,
|
245 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
246 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
247 |
+
use_linear_projection: bool = False,
|
248 |
+
class_embed_type: Optional[str] = None,
|
249 |
+
addition_embed_type: Optional[str] = None,
|
250 |
+
addition_time_embed_dim: Optional[int] = None,
|
251 |
+
num_class_embeds: Optional[int] = None,
|
252 |
+
upcast_attention: bool = False,
|
253 |
+
resnet_time_scale_shift: str = "default",
|
254 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
255 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
256 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
257 |
+
global_pool_conditions: bool = False,
|
258 |
+
addition_embed_type_num_heads=64,
|
259 |
+
num_control_type = 6,
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
|
263 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
264 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
265 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
266 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
267 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
268 |
+
# which is why we correct for the naming here.
|
269 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
270 |
+
|
271 |
+
# Check inputs
|
272 |
+
if len(block_out_channels) != len(down_block_types):
|
273 |
+
raise ValueError(
|
274 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
275 |
+
)
|
276 |
+
|
277 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
278 |
+
raise ValueError(
|
279 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
280 |
+
)
|
281 |
+
|
282 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
283 |
+
raise ValueError(
|
284 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
285 |
+
)
|
286 |
+
|
287 |
+
if isinstance(transformer_layers_per_block, int):
|
288 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
289 |
+
|
290 |
+
# input
|
291 |
+
conv_in_kernel = 3
|
292 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
293 |
+
self.conv_in = nn.Conv2d(
|
294 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
295 |
+
)
|
296 |
+
|
297 |
+
# time
|
298 |
+
time_embed_dim = block_out_channels[0] * 4
|
299 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
300 |
+
timestep_input_dim = block_out_channels[0]
|
301 |
+
self.time_embedding = TimestepEmbedding(
|
302 |
+
timestep_input_dim,
|
303 |
+
time_embed_dim,
|
304 |
+
act_fn=act_fn,
|
305 |
+
)
|
306 |
+
|
307 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
308 |
+
encoder_hid_dim_type = "text_proj"
|
309 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
310 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
311 |
+
|
312 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
313 |
+
raise ValueError(
|
314 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
315 |
+
)
|
316 |
+
|
317 |
+
if encoder_hid_dim_type == "text_proj":
|
318 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
319 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
320 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
321 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
322 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
323 |
+
self.encoder_hid_proj = TextImageProjection(
|
324 |
+
text_embed_dim=encoder_hid_dim,
|
325 |
+
image_embed_dim=cross_attention_dim,
|
326 |
+
cross_attention_dim=cross_attention_dim,
|
327 |
+
)
|
328 |
+
|
329 |
+
elif encoder_hid_dim_type is not None:
|
330 |
+
raise ValueError(
|
331 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
self.encoder_hid_proj = None
|
335 |
+
|
336 |
+
# class embedding
|
337 |
+
if class_embed_type is None and num_class_embeds is not None:
|
338 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
339 |
+
elif class_embed_type == "timestep":
|
340 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
341 |
+
elif class_embed_type == "identity":
|
342 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
343 |
+
elif class_embed_type == "projection":
|
344 |
+
if projection_class_embeddings_input_dim is None:
|
345 |
+
raise ValueError(
|
346 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
347 |
+
)
|
348 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
349 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
350 |
+
# 2. it projects from an arbitrary input dimension.
|
351 |
+
#
|
352 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
353 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
354 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
355 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
356 |
+
else:
|
357 |
+
self.class_embedding = None
|
358 |
+
|
359 |
+
if addition_embed_type == "text":
|
360 |
+
if encoder_hid_dim is not None:
|
361 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
362 |
+
else:
|
363 |
+
text_time_embedding_from_dim = cross_attention_dim
|
364 |
+
|
365 |
+
self.add_embedding = TextTimeEmbedding(
|
366 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
367 |
+
)
|
368 |
+
elif addition_embed_type == "text_image":
|
369 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
370 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
371 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
372 |
+
self.add_embedding = TextImageTimeEmbedding(
|
373 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
374 |
+
)
|
375 |
+
elif addition_embed_type == "text_time":
|
376 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
377 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
378 |
+
|
379 |
+
elif addition_embed_type is not None:
|
380 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
381 |
+
|
382 |
+
# control net conditioning embedding
|
383 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
384 |
+
conditioning_embedding_channels=block_out_channels[0],
|
385 |
+
block_out_channels=conditioning_embedding_out_channels,
|
386 |
+
conditioning_channels=conditioning_channels,
|
387 |
+
)
|
388 |
+
|
389 |
+
# Copyright by Qi Xin(2024/07/06)
|
390 |
+
# Condition Transformer(fuse single/multi conditions with input image)
|
391 |
+
# The Condition Transformer augment the feature representation of conditions
|
392 |
+
# The overall design is somewhat like resnet. The output of Condition Transformer is used to predict a condition bias adding to the original condition feature.
|
393 |
+
# num_control_type = 6
|
394 |
+
num_trans_channel = 320
|
395 |
+
num_trans_head = 8
|
396 |
+
num_trans_layer = 1
|
397 |
+
num_proj_channel = 320
|
398 |
+
task_scale_factor = num_trans_channel ** 0.5
|
399 |
+
|
400 |
+
self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
|
401 |
+
self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)])
|
402 |
+
self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
|
403 |
+
#-----------------------------------------------------------------------------------------------------
|
404 |
+
|
405 |
+
# Copyright by Qi Xin(2024/07/06)
|
406 |
+
# Control Encoder to distinguish different control conditions
|
407 |
+
# A simple but effective module, consists of an embedding layer and a linear layer, to inject the control info to time embedding.
|
408 |
+
self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
409 |
+
self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
|
410 |
+
#-----------------------------------------------------------------------------------------------------
|
411 |
+
|
412 |
+
self.down_blocks = nn.ModuleList([])
|
413 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
414 |
+
|
415 |
+
if isinstance(only_cross_attention, bool):
|
416 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
417 |
+
|
418 |
+
if isinstance(attention_head_dim, int):
|
419 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
420 |
+
|
421 |
+
if isinstance(num_attention_heads, int):
|
422 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
423 |
+
|
424 |
+
# down
|
425 |
+
output_channel = block_out_channels[0]
|
426 |
+
|
427 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
428 |
+
controlnet_block = zero_module(controlnet_block)
|
429 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
430 |
+
|
431 |
+
for i, down_block_type in enumerate(down_block_types):
|
432 |
+
input_channel = output_channel
|
433 |
+
output_channel = block_out_channels[i]
|
434 |
+
is_final_block = i == len(block_out_channels) - 1
|
435 |
+
|
436 |
+
down_block = get_down_block(
|
437 |
+
down_block_type,
|
438 |
+
num_layers=layers_per_block,
|
439 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
440 |
+
in_channels=input_channel,
|
441 |
+
out_channels=output_channel,
|
442 |
+
temb_channels=time_embed_dim,
|
443 |
+
add_downsample=not is_final_block,
|
444 |
+
resnet_eps=norm_eps,
|
445 |
+
resnet_act_fn=act_fn,
|
446 |
+
resnet_groups=norm_num_groups,
|
447 |
+
cross_attention_dim=cross_attention_dim,
|
448 |
+
num_attention_heads=num_attention_heads[i],
|
449 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
450 |
+
downsample_padding=downsample_padding,
|
451 |
+
use_linear_projection=use_linear_projection,
|
452 |
+
only_cross_attention=only_cross_attention[i],
|
453 |
+
upcast_attention=upcast_attention,
|
454 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
455 |
+
)
|
456 |
+
self.down_blocks.append(down_block)
|
457 |
+
|
458 |
+
for _ in range(layers_per_block):
|
459 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
460 |
+
controlnet_block = zero_module(controlnet_block)
|
461 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
462 |
+
|
463 |
+
if not is_final_block:
|
464 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
465 |
+
controlnet_block = zero_module(controlnet_block)
|
466 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
467 |
+
|
468 |
+
# mid
|
469 |
+
mid_block_channel = block_out_channels[-1]
|
470 |
+
|
471 |
+
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
472 |
+
controlnet_block = zero_module(controlnet_block)
|
473 |
+
self.controlnet_mid_block = controlnet_block
|
474 |
+
|
475 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
476 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
477 |
+
in_channels=mid_block_channel,
|
478 |
+
temb_channels=time_embed_dim,
|
479 |
+
resnet_eps=norm_eps,
|
480 |
+
resnet_act_fn=act_fn,
|
481 |
+
output_scale_factor=mid_block_scale_factor,
|
482 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
483 |
+
cross_attention_dim=cross_attention_dim,
|
484 |
+
num_attention_heads=num_attention_heads[-1],
|
485 |
+
resnet_groups=norm_num_groups,
|
486 |
+
use_linear_projection=use_linear_projection,
|
487 |
+
upcast_attention=upcast_attention,
|
488 |
+
)
|
489 |
+
|
490 |
+
@classmethod
|
491 |
+
def from_unet(
|
492 |
+
cls,
|
493 |
+
unet: UNet2DConditionModel,
|
494 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
495 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
496 |
+
load_weights_from_unet: bool = True,
|
497 |
+
):
|
498 |
+
r"""
|
499 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
500 |
+
|
501 |
+
Parameters:
|
502 |
+
unet (`UNet2DConditionModel`):
|
503 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
504 |
+
where applicable.
|
505 |
+
"""
|
506 |
+
transformer_layers_per_block = (
|
507 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
508 |
+
)
|
509 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
510 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
511 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
512 |
+
addition_time_embed_dim = (
|
513 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
514 |
+
)
|
515 |
+
|
516 |
+
controlnet = cls(
|
517 |
+
encoder_hid_dim=encoder_hid_dim,
|
518 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
519 |
+
addition_embed_type=addition_embed_type,
|
520 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
521 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
522 |
+
# transformer_layers_per_block=[1, 2, 5],
|
523 |
+
in_channels=unet.config.in_channels,
|
524 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
525 |
+
freq_shift=unet.config.freq_shift,
|
526 |
+
down_block_types=unet.config.down_block_types,
|
527 |
+
only_cross_attention=unet.config.only_cross_attention,
|
528 |
+
block_out_channels=unet.config.block_out_channels,
|
529 |
+
layers_per_block=unet.config.layers_per_block,
|
530 |
+
downsample_padding=unet.config.downsample_padding,
|
531 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
532 |
+
act_fn=unet.config.act_fn,
|
533 |
+
norm_num_groups=unet.config.norm_num_groups,
|
534 |
+
norm_eps=unet.config.norm_eps,
|
535 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
536 |
+
attention_head_dim=unet.config.attention_head_dim,
|
537 |
+
num_attention_heads=unet.config.num_attention_heads,
|
538 |
+
use_linear_projection=unet.config.use_linear_projection,
|
539 |
+
class_embed_type=unet.config.class_embed_type,
|
540 |
+
num_class_embeds=unet.config.num_class_embeds,
|
541 |
+
upcast_attention=unet.config.upcast_attention,
|
542 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
543 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
544 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
545 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
546 |
+
)
|
547 |
+
|
548 |
+
if load_weights_from_unet:
|
549 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
550 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
551 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
552 |
+
|
553 |
+
if controlnet.class_embedding:
|
554 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
555 |
+
|
556 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
557 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
558 |
+
|
559 |
+
return controlnet
|
560 |
+
|
561 |
+
@property
|
562 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
563 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
564 |
+
r"""
|
565 |
+
Returns:
|
566 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
567 |
+
indexed by its weight name.
|
568 |
+
"""
|
569 |
+
# set recursively
|
570 |
+
processors = {}
|
571 |
+
|
572 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
573 |
+
if hasattr(module, "get_processor"):
|
574 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
575 |
+
|
576 |
+
for sub_name, child in module.named_children():
|
577 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
578 |
+
|
579 |
+
return processors
|
580 |
+
|
581 |
+
for name, module in self.named_children():
|
582 |
+
fn_recursive_add_processors(name, module, processors)
|
583 |
+
|
584 |
+
return processors
|
585 |
+
|
586 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
587 |
+
def set_attn_processor(
|
588 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
589 |
+
):
|
590 |
+
r"""
|
591 |
+
Sets the attention processor to use to compute attention.
|
592 |
+
|
593 |
+
Parameters:
|
594 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
595 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
596 |
+
for **all** `Attention` layers.
|
597 |
+
|
598 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
599 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
600 |
+
|
601 |
+
"""
|
602 |
+
count = len(self.attn_processors.keys())
|
603 |
+
|
604 |
+
if isinstance(processor, dict) and len(processor) != count:
|
605 |
+
raise ValueError(
|
606 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
607 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
608 |
+
)
|
609 |
+
|
610 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
611 |
+
if hasattr(module, "set_processor"):
|
612 |
+
if not isinstance(processor, dict):
|
613 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
614 |
+
else:
|
615 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
616 |
+
|
617 |
+
for sub_name, child in module.named_children():
|
618 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
619 |
+
|
620 |
+
for name, module in self.named_children():
|
621 |
+
fn_recursive_attn_processor(name, module, processor)
|
622 |
+
|
623 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
624 |
+
def set_default_attn_processor(self):
|
625 |
+
"""
|
626 |
+
Disables custom attention processors and sets the default attention implementation.
|
627 |
+
"""
|
628 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
629 |
+
processor = AttnAddedKVProcessor()
|
630 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
631 |
+
processor = AttnProcessor()
|
632 |
+
else:
|
633 |
+
raise ValueError(
|
634 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
635 |
+
)
|
636 |
+
|
637 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
638 |
+
|
639 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
640 |
+
def set_attention_slice(self, slice_size):
|
641 |
+
r"""
|
642 |
+
Enable sliced attention computation.
|
643 |
+
|
644 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
645 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
646 |
+
|
647 |
+
Args:
|
648 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
649 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
650 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
651 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
652 |
+
must be a multiple of `slice_size`.
|
653 |
+
"""
|
654 |
+
sliceable_head_dims = []
|
655 |
+
|
656 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
657 |
+
if hasattr(module, "set_attention_slice"):
|
658 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
659 |
+
|
660 |
+
for child in module.children():
|
661 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
662 |
+
|
663 |
+
# retrieve number of attention layers
|
664 |
+
for module in self.children():
|
665 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
666 |
+
|
667 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
668 |
+
|
669 |
+
if slice_size == "auto":
|
670 |
+
# half the attention head size is usually a good trade-off between
|
671 |
+
# speed and memory
|
672 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
673 |
+
elif slice_size == "max":
|
674 |
+
# make smallest slice possible
|
675 |
+
slice_size = num_sliceable_layers * [1]
|
676 |
+
|
677 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
678 |
+
|
679 |
+
if len(slice_size) != len(sliceable_head_dims):
|
680 |
+
raise ValueError(
|
681 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
682 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
683 |
+
)
|
684 |
+
|
685 |
+
for i in range(len(slice_size)):
|
686 |
+
size = slice_size[i]
|
687 |
+
dim = sliceable_head_dims[i]
|
688 |
+
if size is not None and size > dim:
|
689 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
690 |
+
|
691 |
+
# Recursively walk through all the children.
|
692 |
+
# Any children which exposes the set_attention_slice method
|
693 |
+
# gets the message
|
694 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
695 |
+
if hasattr(module, "set_attention_slice"):
|
696 |
+
module.set_attention_slice(slice_size.pop())
|
697 |
+
|
698 |
+
for child in module.children():
|
699 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
700 |
+
|
701 |
+
reversed_slice_size = list(reversed(slice_size))
|
702 |
+
for module in self.children():
|
703 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
704 |
+
|
705 |
+
|
706 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
707 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
708 |
+
module.gradient_checkpointing = value
|
709 |
+
|
710 |
+
|
711 |
+
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond_list: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.FloatTensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            added_cond_kwargs (`dict`):
                Additional conditions for the Stable Diffusion XL UNet.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        # elif channel_order == "bgr":
        #     controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        # Copyright by Qi Xin(2024/07/06)
        # inject control type info to time embedding to distinguish different control conditions
        control_type = added_cond_kwargs.get('control_type')
        control_embeds = self.control_type_proj(control_type.flatten())
        control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
        control_embeds = control_embeds.to(emb.dtype)
        control_emb = self.control_add_embedding(control_embeds)
        emb = emb + control_emb
        #---------------------------------------------------------------------------------

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process
        sample = self.conv_in(sample)
        indices = torch.nonzero(control_type[0])

        # Copyright by Qi Xin(2024/07/06)
        # add single/multi conditions to input image.
        # Condition Transformer provides an easy and effective way to fuse different features naturally
        inputs = []
        condition_list = []

        for idx in range(indices.shape[0] + 1):
            if idx == indices.shape[0]:
                controlnet_cond = sample
                feat_seq = torch.mean(controlnet_cond, dim=(2, 3))  # N * C
            else:
                controlnet_cond = self.controlnet_cond_embedding(controlnet_cond_list[indices[idx][0]])
                feat_seq = torch.mean(controlnet_cond, dim=(2, 3))  # N * C
                feat_seq = feat_seq + self.task_embedding[indices[idx][0]]

            inputs.append(feat_seq.unsqueeze(1))
            condition_list.append(controlnet_cond)

        x = torch.cat(inputs, dim=1)  # NxLxC
        x = self.transformer_layes(x)

        controlnet_cond_fuser = sample * 0.0
        for idx in range(indices.shape[0]):
            alpha = self.spatial_ch_projs(x[:, idx])
            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
            controlnet_cond_fuser += condition_list[idx] + alpha

        sample = sample + controlnet_cond_fuser
        #-------------------------------------------------------------------------------------------

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks

        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
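The Condition Transformer fusion above (mean-pool each embedded condition into a token, tag it with a task embedding, append the pooled sample as an extra untagged token, run everything through a small transformer, then add each condition map plus a predicted channel-wise bias back onto `sample`) can be replayed in isolation. The following is a self-contained toy sketch with made-up shapes and freshly initialized stand-in modules; the names mirror `task_embedding`, `transformer_layes` and `spatial_ch_projs` above, but nothing here uses the trained weights:

import torch
import torch.nn as nn

# Made-up sizes standing in for the real model (the true channel width is whatever conv_in produces).
N, C, H, W = 1, 320, 64, 64
sample = torch.randn(N, C, H, W)                          # conv_in output
conds = [torch.randn(N, C, H, W) for _ in range(2)]       # e.g. embedded depth + normal conditions

task_embedding = nn.Parameter(torch.zeros(6, C))          # one learned token per control slot
transformer_layes = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=C, nhead=8, batch_first=True), num_layers=1
)
spatial_ch_projs = nn.Linear(C, C)

# Mean-pool every active condition into a single token and tag it with its task embedding;
# the pooled sample itself is appended as an extra, untagged token.
tokens = [conds[i].mean(dim=(2, 3)) + task_embedding[i] for i in range(len(conds))]
tokens.append(sample.mean(dim=(2, 3)))
x = transformer_layes(torch.stack(tokens, dim=1))         # N x (num_conds + 1) x C

# Each condition contributes its feature map plus a channel-wise bias predicted from its token.
fused = sample * 0.0
for i, cond in enumerate(conds):
    fused = fused + cond + spatial_ch_projs(x[:, i])[:, :, None, None]
sample = sample + fused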
utils/image_generation.py
ADDED
@@ -0,0 +1,299 @@
import threading

import cv2
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# Add FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
                       FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage

import gradio as gr

from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import \
    StableDiffusionXLControlNetUnionPipeline
from .render_utils import get_silhouette_image

IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()
# Add FLUX pipeline variables
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None

def lazy_get_flux_pipe():
    """
    Lazy load the FLUX pipeline with ControlNet for image generation.
    """
    global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
    if FLUX_PIPE is not None:
        return FLUX_PIPE
    gr.Info("First called, loading FLUX pipeline... It may take about 1 minute.")
    with FLUX_PIPE_LOCK:
        if FLUX_PIPE is not None:
            return FLUX_PIPE
        FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
        FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
        base_model = 'black-forest-labs/FLUX.1-dev'
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'

        controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=torch.bfloat16
        )
        # Use model CPU offload for better GPU utilization during inference
        FLUX_PIPE.enable_model_cpu_offload()
        return FLUX_PIPE

def lazy_get_sdxl_pipe():
    """
    Lazy load the SDXL pipeline with ControlNet for image generation.
    """
    global IMG_PIPE
    if IMG_PIPE is not None:
        return IMG_PIPE
    gr.Info("First called, loading SDXL pipeline... It may take about 20 seconds.")
    with IMG_PIPE_LOCK:
        if IMG_PIPE is not None:
            return IMG_PIPE
        eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # when testing with another base model, you need to change the VAE as well.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
        IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model,
            vae=vae,
            torch_dtype=torch.float16,
            scheduler=eulera_scheduler,
        )
        # Move pipeline to CUDA device
        IMG_PIPE = IMG_PIPE.to("cuda")
        return IMG_PIPE


def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate image condition using SDXL model with ControlNet based on depth and normal images.
    :param depth_img: Depth image from the selected view.
    :param normal_img: Normal image (Camera Coordinate System) from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated image condition (e.g., PIL Image).
    """
    progress(0.1, desc="Loading SDXL pipeline...")
    pipeline = lazy_get_sdxl_pipe()
    progress(0.3, desc="SDXL pipeline loaded successfully.")

    positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lighting, uniform color, foreground"
    negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

    img_generation_resolution = 1024  # SDXL performs better at 1024x1024
    image = pipeline(prompt=[positive_prompt]*1,
                     image_list=[0, depth_img, 0, 0, normal_img, 0],
                     negative_prompt=[negative_prompt]*1,
                     generator=torch.Generator(device="cuda").manual_seed(seed),
                     width=img_generation_resolution,
                     height=img_generation_resolution,
                     num_inference_steps=50,
                     union_control=True,
                     union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"),  # use depth and normal images
                     progress=progress,
                     ).images[0]
    progress(0.9, desc="Condition tensor generated successfully.")

    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device)  # Ensure mask is in the correct shape
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]

    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)

    # Apply edge refinement if enabled
    if edge_refinement:
        # Convert to CUDA device for edge refinement
        rgb_tensor_cuda = rgb_tensor.to("cuda")
        mask_tensor_cuda = mask_tensor.to("cuda")
        rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
        rgb_tensor = rgb_tensor_cuda.to(pipeline.device)

    background_tensor = torch.zeros_like(rgb_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())

    progress(1, desc="Condition image generated successfully.")
    return condition_image

def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate image condition using FLUX model with ControlNet based on depth image only.
    Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth.
    :param depth_img: Depth image from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :return: Generated image condition (PIL Image).
    """
    progress(0.1, desc="Loading FLUX pipeline...")
    pipeline = lazy_get_flux_pipe()
    progress(0.3, desc="FLUX pipeline loaded successfully.")

    # Enhanced prompt for better results
    positive_prompt = text_prompt + FLUX_SUFFIX
    negative_prompt = FLUX_NEGATIVE

    # Get image dimensions
    width, height = depth_img.size

    progress(0.5, desc="Generating image with FLUX (including onload and cpu offload)...")

    # Generate image using FLUX ControlNet with depth control
    # model_cpu_offload handles GPU loading automatically
    image = pipeline(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        control_image=depth_img,
        width=width,
        height=height,
        controlnet_conditioning_scale=0.8,  # Recommended for depth
        control_guidance_end=0.8,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]

    progress(0.9, desc="Applying mask and resizing...")

    # Convert to tensor and apply mask
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]

    # Resize to target dimensions
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)

    # Apply mask (blend with black background)
    background_tensor = torch.zeros_like(rgb_tensor)
    if edge_refinement:
        # replace edge with inner values
        rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)

    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)

    # Convert back to PIL Image
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.0
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())

    progress(1, desc="FLUX condition image generated successfully.")
    return condition_image

def refine_image_edges(rgb_tensor, mask_tensor):
    """
    Refine image edges using advanced morphological operations to remove white edges while preserving object boundaries.

    Algorithm:
    1. Erode mask to get eroded_mask
    2. Double erode mask to get double_eroded_mask
    3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask
    4. Use circle_valid_mask to extract circle_rgb (clean edge values)
    5. Dilate circle_rgb to cover the edge region
    6. Final result: use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background

    :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on CUDA device
    :param mask_tensor: Mask tensor of shape (1, 1, H, W) on CUDA device, normalized to [0, 1]
    :return: refined_rgb_tensor
    """
    # Convert tensors to numpy for OpenCV processing
    rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)  # (H, W, C)
    mask_np = mask_tensor.squeeze().cpu().numpy()  # Remove batch and channel dimensions
    original_mask_np = (mask_np * 255).astype(np.uint8)  # Convert to 0-255 range

    # Create morphological kernel (3x3 as requested)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Step 1: Erode mask to get eroded_mask
    eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)

    # Step 2: Double erode mask to get double_eroded_mask
    double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)

    # Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask
    circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)

    # Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values)
    circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)

    # Step 5: Dilate circle_rgb to cover the edge region (iterations=8)
    dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)

    # Step 6: Final composition
    # Use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background
    double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0

    # Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere
    refined_rgb_np = (rgb_np * double_eroded_mask_3c +
                      dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)

    # Convert refined RGB back to tensor
    refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")

    return refined_rgb_tensor

@spaces.GPU(duration=120)
def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
    """
    Generate the image condition based on the selected view's silhouette and text prompt.
    :param position_imgs: Position images from different views.
    :param normal_imgs: Normal images from different views.
    :param mask_imgs: Mask images from different views.
    :param w2c: World-to-camera transformation matrices.
    :param text_prompt: The text prompt for image generation.
    :param selected_view: The selected view for image generation.
    :param seed: Random seed for image generation.
    :param model: The image generation model type, supports "SDXL" and "FLUX".
    :param progress: Progress callback for Gradio.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :return: Generated condition image and status message.
    """

    progress(0, desc="Handling geometry information...")
    silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
    depth_img = silhouette[0]
    normal_img = silhouette[1]
    mask = silhouette[2]

    try:
        if model == "SDXL":
            condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "SDXL condition generated successfully."
        elif model == "FLUX":
            # FLUX only supports depth control, not normal
            condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "FLUX condition generated successfully (depth-only control)."
        else:
            raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
    finally:
        torch.cuda.empty_cache()
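The erode/XOR/dilate recipe in `refine_image_edges` is easier to see on a toy input. Below is a self-contained sketch that replays the same steps (kernel size and iteration counts copied from the function) on a synthetic grey disc with a fabricated white rim; the rim, the image sizes, and the printed pixel coordinates are invented test data, not anything produced by this Space:

import cv2
import numpy as np

# Synthetic test case: grey disc with a fake white rim baked into the RGB image.
mask = np.zeros((128, 128), np.uint8)
cv2.circle(mask, (64, 64), 40, 255, -1)
rgb = np.full((128, 128, 3), 128, np.uint8)
rgb[cv2.dilate(mask, None) != cv2.erode(mask, None)] = 255  # fabricated white edge artifact

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
eroded = cv2.erode(mask, kernel, iterations=3)
double_eroded = cv2.erode(eroded, kernel, iterations=5)
ring = cv2.bitwise_xor(eroded, double_eroded)          # thin band of clean interior pixels
ring_rgb = (rgb * (cv2.cvtColor(ring, cv2.COLOR_GRAY2BGR) / 255.0)).astype(np.uint8)
ring_rgb = cv2.dilate(ring_rgb, kernel, iterations=8)  # smear the clean colours outward over the rim
inner = cv2.cvtColor(double_eroded, cv2.COLOR_GRAY2BGR) / 255.0
refined = (rgb * inner + ring_rgb * (1 - inner)).astype(np.uint8)
print(refined[64, 64], refined[64, 104])  # centre keeps its colour; the old white rim pixel is now grey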
utils/mesh_utils.py
ADDED
@@ -0,0 +1,500 @@
import os
import tempfile

import numpy as np
import torch
import trimesh
import xatlas
from PIL import Image

import gradio as gr

from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
                           render_geo_views_tensor, render_views, setup_lights)


class Mesh:
    def __init__(self, mesh_path=None, uv_tool="xAtlas", device='cuda', progress=gr.Progress()):
        """
        Initialize the Mesh object with a mesh file path.
        :param mesh_path: Path to the mesh file (e.g., .obj or .glb).
        """
        self.device = device
        if mesh_path is not None:
            # Initialize _parts dictionary to store all parts
            self._parts = {}

            if mesh_path.endswith('.obj'):
                progress(0., f"Loading mesh in .obj format...")
                mesh_data = trimesh.load(mesh_path, process=False)

                # Check if it's a mesh list (multi-part obj)
                if isinstance(mesh_data, list):
                    progress(0.1, f"Handling part list...")
                    for i, mesh_part in enumerate(mesh_data):
                        self._add_part_to_parts(f"part_{i}", mesh_part)
                # Check if it's a Scene (another multi-part format)
                elif isinstance(mesh_data, trimesh.Scene):
                    progress(0.1, f"Handling Scenes...")
                    geometry = mesh_data.geometry
                    if len(geometry) > 0:
                        for key, mesh_part in geometry.items():
                            self._add_part_to_parts(key, mesh_part)
                    else:
                        raise ValueError("Empty scene, no mesh data found.")
                else:
                    # Single part obj
                    progress(0.1, f"Handling single part...")
                    self._add_part_to_parts("part_0", mesh_data)

            elif mesh_path.endswith('.glb'):
                progress(0., f"Loading mesh in .glb format...")
                mesh_loaded = trimesh.load(mesh_path)

                # Check if it's a Scene (multi-part glb)
                if isinstance(mesh_loaded, trimesh.Scene):
                    progress(0.1, f"Handling Scenes...")
                    geometry = mesh_loaded.geometry
                    if len(geometry) > 0:
                        for key, mesh_part in geometry.items():
                            self._add_part_to_parts(key, mesh_part)
                    else:
                        raise ValueError("Empty scene, no mesh data found.")
                else:
                    # Single part glb
                    progress(0.1, f"Handling single part...")
                    self._add_part_to_parts("part_0", mesh_loaded)
            else:
                raise ValueError(f"Unsupported file format: {mesh_path}")

            # Automatically merge all parts during initialization
            progress(0.2, f"Merging if the mesh has multiple parts.")
            self._merge_parts_internal()
        else:
            raise ValueError("Mesh path cannot be None.")
        self.to(self.device)  # Move to the specified device

        # Initialize transformation flags
        self._upside_down_applied = False

        # UV parameterization
        if self.has_multi_parts or not self.has_uv:
            progress(0.4, f"Using {uv_tool} for UV parameterization. It may take quite a while (several minutes) if there are many faces. We STRONGLY recommend using a mesh with UV parameterization.")
            if uv_tool == "xAtlas":
                self.uv_xatlas_mapping()  # Use default parameters
            elif uv_tool == "UVAtlas":
                raise NotImplementedError("UVAtlas parameterization is not implemented yet.")
            else:
                raise ValueError("Unsupported UV parameterization tool.")
            print("UV parameterization completed.")
        else:
            progress(0.4, f"The model has SINGLE UV parameterization, no need to reparameterize.")
            self._vmapping = None  # No vmapping needed when not reparameterizing

    def to(self, device):
        """
        Move the mesh data to the specified device.
        :param device: The target device (e.g., 'cuda' or 'cpu').
        """
        self._v_pos = self._v_pos.to(device)
        self._t_pos_idx = self._t_pos_idx.to(device)
        if self._v_tex is not None:
            self._v_tex = self._v_tex.to(device)
            self._t_tex_idx = self._t_tex_idx.to(device)
        if hasattr(self, '_vmapping') and self._vmapping is not None:
            self._vmapping = self._vmapping.to(device)
        self._v_normal = self._v_normal.to(device)

    @property
    def has_multi_parts(self):
        """
        Check if the mesh has multiple parts.
        :return: Boolean indicating whether the mesh has multiple parts.
        """
        # If _parts is None, it means already merged, not multi-part
        if self._parts is None:
            return False
        return len(self._parts) > 1

    @property
    def v_pos(self):
        """Vertex positions property."""
        return self._v_pos

    @v_pos.setter
    def v_pos(self, value):
        self._v_pos = value

    @property
    def t_pos_idx(self):
        """Triangle position indices property."""
        return self._t_pos_idx

    @t_pos_idx.setter
    def t_pos_idx(self, value):
        self._t_pos_idx = value

    @property
    def v_tex(self):
        """Vertex texture coordinates property."""
        return self._v_tex

    @v_tex.setter
    def v_tex(self, value):
        self._v_tex = value

    @property
    def t_tex_idx(self):
        """Triangle texture indices property."""
        return self._t_tex_idx

    @t_tex_idx.setter
    def t_tex_idx(self, value):
        self._t_tex_idx = value

    @property
    def v_normal(self):
        """Vertex normals property."""
        return self._v_normal

    @v_normal.setter
    def v_normal(self, value):
        self._v_normal = value

    @property
    def has_uv(self):
        """
        Check if the mesh has a valid UV mapping.
        :return: Boolean indicating whether the mesh has UV mapping.
        """
        return self.v_tex is not None

    def uv_xatlas_mapping(self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}):
        # Merged mesh, directly add_mesh as a whole
        atlas = xatlas.Atlas()
        v_pos_np = self.v_pos.detach().cpu().numpy()
        t_pos_idx_np = self.t_pos_idx.cpu().numpy()
        atlas.add_mesh(v_pos_np, t_pos_idx_np)

        # Set reasonable pack parameters to avoid overlap
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        # Recommended default parameters
        if 'resolution' not in xatlas_pack_options:
            po.resolution = 1024  # or larger
        if 'padding' not in xatlas_pack_options:
            po.padding = 2
        for k, v in xatlas_chart_options.items():
            setattr(co, k, v)
        for k, v in xatlas_pack_options.items():
            setattr(po, k, v)
        atlas.generate(co, po)

        # Get unpacked data
        vmapping, indices, uvs = atlas.get_mesh(0)
        # vmapping: new UV vertex -> original mesh vertex
        # indices: new triangle face indices (based on new UV vertices)
        # uvs: new UV vertex coordinates
        device = self.v_pos.device
        vmapping = torch.from_numpy(vmapping.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()
        uvs = torch.from_numpy(uvs).to(device).float()
        indices = torch.from_numpy(indices.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()

        self.v_tex = uvs  # new UV vertices
        self.t_tex_idx = indices  # new triangle face indices (based on UV vertices)
        self._vmapping = vmapping  # save UV vertex to original vertex mapping for export

    def normalize(self):
        """
        Normalize mesh vertices to [-1, 1] range.
        """
        vertices = self.v_pos
        bounding_box_max = vertices.max(0)[0]
        bounding_box_min = vertices.min(0)[0]
        mesh_scale = 2.0  # Scale to [-1, 1]
        scale = mesh_scale / ((bounding_box_max - bounding_box_min).max() + 1e-6)
        center_offset = (bounding_box_max + bounding_box_min) * 0.5
        self.v_pos = (vertices - center_offset) * scale

    def vertex_transform(self):
        """
        Apply coordinate transformation to mesh vertices and normals.
        """
        # Transform normals
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 1] = -pre_normals[:, 2]  # -z --> y
        normals[:, 2] = pre_normals[:, 1]  # y --> z

        # Transform vertices
        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 1] = -pre_vertices[:, 2]  # -z --> y
        vertices[:, 2] = pre_vertices[:, 1]  # y --> z

        # Update mesh
        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_y2x(self):
        """
        Apply coordinate transformation to mesh vertices and normals.
        """
        # Transform normals
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 1] = -pre_normals[:, 0]  # -x --> y
        normals[:, 0] = pre_normals[:, 1]  # y --> x

        # Transform vertices
        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 1] = -pre_vertices[:, 0]  # -x --> y
        vertices[:, 0] = pre_vertices[:, 1]  # y --> x

        # Update mesh
        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_z2x(self):
        """
        Apply coordinate transformation to mesh vertices and normals.
        """
        # Transform normals
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 2] = -pre_normals[:, 0]  # -x --> z
        normals[:, 0] = pre_normals[:, 2]  # z --> x

        # Transform vertices
        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 2] = -pre_vertices[:, 0]  # -x --> z
        vertices[:, 0] = pre_vertices[:, 2]  # z --> x

        # Update mesh
        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_upsidedown(self):
        """
        Apply upside-down transformation to mesh vertices and normals.
        """
        # Transform normals
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 2] = -pre_normals[:, 2]

        # Transform vertices
        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 2] = -pre_vertices[:, 2]

        # Update mesh
        self.v_normal = normals
        self.v_pos = vertices
        # self.t_pos_idx = faces

        # Mark that the upside-down transform has been applied
        self._upside_down_applied = True

    def _add_part_to_parts(self, key, mesh_part):
        """
        Add a single mesh part to the _parts dictionary.
        :param key: Key name of the part.
        :param mesh_part: trimesh object.
        """
        # exclude PointCloud parts and empty parts
        if hasattr(mesh_part, 'vertices') and hasattr(mesh_part, 'faces') and len(mesh_part.vertices) > 0 and len(mesh_part.faces) > 0:
            raw_uv = getattr(mesh_part.visual, 'uv', None)
            processed_v_tex = None
            processed_t_tex_idx = None

            # Only process when UV data exists and is not empty
            if raw_uv is not None and np.asarray(raw_uv).size > 0 and np.asarray(raw_uv).shape[0] > 0:
                processed_v_tex = torch.tensor(raw_uv, dtype=torch.float32)
                # Assume that when the source data provides UVs, t_tex_idx uses the same face indices as t_pos_idx
                # trimesh usually provides per-vertex UVs
                processed_t_tex_idx = torch.tensor(mesh_part.faces, dtype=torch.int32)

            self._parts[key] = {
                'v_pos': torch.tensor(mesh_part.vertices, dtype=torch.float32),
                't_pos_idx': torch.tensor(mesh_part.faces, dtype=torch.int32),
                'v_tex': processed_v_tex,
                't_tex_idx': processed_t_tex_idx,
                'v_normal': torch.tensor(mesh_part.vertex_normals, dtype=torch.float32)
            }

    def _merge_parts_internal(self):
        """
        Internal merge helper, called automatically during initialization.
        Merges all parts stored in _parts into a single mesh representation.
        """
        # If there are no parts or only a single part, handle the simple cases
        if not self._parts:
            raise ValueError("No mesh parts.")
        elif len(self._parts) == 1:
            key = next(iter(self._parts))
            part = self._parts[key]
            self._v_pos = part['v_pos']
            self._t_pos_idx = part['t_pos_idx']
            self._v_tex = part['v_tex']
            self._t_tex_idx = part['t_tex_idx']
            self._v_normal = part['v_normal']
            self._parts = None  # Clear the _parts dictionary to free memory
            return

        # Initialize merged data
        vertices = []
        faces = []
        normals = []

        # Record vertex count for each part, used to adjust face indices
        v_count = 0

        # Iterate through all parts
        for key, part in self._parts.items():
            # Add vertices
            vertices.append(part['v_pos'])

            # Adjust face indices and add
            if len(faces) > 0:
                adjusted_faces = part['t_pos_idx'] + v_count
                faces.append(adjusted_faces)
            else:
                faces.append(part['t_pos_idx'])

            # Add normals
            normals.append(part['v_normal'])

            # Update vertex count
            v_count += part['v_pos'].shape[0]

        self._parts = None  # Clear _parts dictionary to free memory

        # Merge all data
        self._v_pos = torch.cat(vertices, dim=0)
        self._t_pos_idx = torch.cat(faces, dim=0)
        self._v_normal = torch.cat(normals, dim=0)
        self._v_tex = None  # multi-parts mesh must be reparameterized
        self._t_tex_idx = None  # multi-parts mesh must be reparameterized
        self._vmapping = None  # multi-parts mesh must be reparameterized

    @classmethod
    def export(cls, mesh, save_path=None, texture_map: Image.Image = None):
        """
        Exports the mesh to a GLB file.
        :param mesh: Mesh instance to export
        :param save_path: Optional path to save the GLB file. If None, a temporary file will be created.
        :param texture_map: Optional PIL.Image to use as the texture. If None, a default texture will be used.
        :return: Path to the exported GLB file.
        """
        # The incoming mesh is guaranteed to have been processed, so assert it is a single part with UVs
        assert not mesh.has_multi_parts, "Mesh should be processed and merged to single part"
        assert mesh.has_uv, "Mesh should have UV mapping after processing"

        if save_path is None:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
            save_path = temp_file.name
            temp_file.close()

        # Create material
        if texture_map is not None:
            if type(texture_map) is np.ndarray:
                texture_map = Image.fromarray(texture_map)
            assert type(texture_map) is Image.Image, "texture_map should be a PIL.Image"
            texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
            material = trimesh.visual.texture.SimpleMaterial(image=texture_map)
        else:
            default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
            material = trimesh.visual.texture.SimpleMaterial(image=default_texture)

        # If vmapping exists (processed by xatlas), need to rebuild vertices to match UV layout
        if hasattr(mesh, '_vmapping') and mesh._vmapping is not None:
            # Use xatlas-generated UV layout to rebuild mesh
            vertices = mesh.v_pos[mesh._vmapping].cpu().numpy()
            faces = mesh.t_tex_idx.cpu().numpy()
            uvs = mesh.v_tex.cpu().numpy()
        else:
            # Original UV mapping, directly use original vertices and faces
            vertices = mesh.v_pos.cpu().numpy()
            faces = mesh.t_pos_idx.cpu().numpy()
            uvs = mesh.v_tex.cpu().numpy()

        # If upside_down transformation was applied, need to apply face orientation correction
        if hasattr(mesh, '_upside_down_applied') and mesh._upside_down_applied:
            faces_corrected = faces.copy()
            faces_corrected[:, [1, 2]] = faces[:, [2, 1]]  # (0,1,2) -> (0,2,1)
            faces = faces_corrected

        # Apply inverse transformation to convert vertices from rendering coordinate system back to GLB coordinate system
        # This is the inverse of vertex_transform:
        # vertex_transform: y = -z, z = y
        # inverse transformation: y = z, z = -y
        vertices_export = vertices.copy()
        vertices_export[:, 1] = vertices[:, 2]  # z -> y
        vertices_export[:, 2] = -vertices[:, 1]  # -y -> z

        # Create Trimesh object and set texture
        mesh_export = trimesh.Trimesh(vertices=vertices_export, faces=faces, process=False)
        mesh_export.visual = trimesh.visual.TextureVisuals(uv=uvs, material=material)

        # Export GLB file
        mesh_export.export(file_obj=save_path, file_type='glb')

        return save_path

    @classmethod
    def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
        """
        Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
        Then render the untextured mesh from four views.
        :param mesh_file: uploaded mesh file.
        :param uv_tool: the UV parameterization tool, default is "xAtlas".
        :return: rendered clay model images from four views.
        """
        # load mesh (automatically merge multiple parts)
        mesh: Mesh = cls(mesh_file, uv_tool, device, progress=progress)

        progress(0.7, f"Handling transformation and normalization...")
        # normalize mesh
        if y2z:
            mesh.vertex_transform()  # transform vertices and normals
        if y2x:
            mesh.vertex_transform_y2x()
        if z2x:
            mesh.vertex_transform_z2x()
        if upside_down:
            mesh.vertex_transform_upsidedown()
        mesh.normalize()

        # render preparation
        texture = get_pure_texture(uv_size).to(device)  # tensor of shape (3, height, width)
        # lights = setup_lights()
        lights = None
        mvp_matrix, w2c = get_mvp_matrix(mesh)
        mvp_matrix = mvp_matrix.to(device)
        w2c = w2c.to(device)

        # render untextured mesh from four views
        # images = render_views(mesh, texture, mvp_matrix, lights, img_size)  # PIL.Image
        progress(0.8, f"Rendering clay model views...")
        print(f"Rendering geometry views...")
        position_images, normal_images, mask_images = render_geo_views_tensor(mesh, mvp_matrix, img_size)  # torch.Tensor # [batch_size, height, width, 3]
        progress(0.9, f"Rendering geometry maps...")
        print(f"Rendering geometry maps...")
        position_map, normal_map = render_geo_map(mesh)

        progress(1, f"Mesh processing completed.")
        return position_map, normal_map, position_images, normal_images, mask_images.squeeze(-1), w2c, mesh, mvp_matrix, "Mesh processing completed."


if __name__ == '__main__':
    glb_path = "/mnt/pfs/users/yuanze/projects/clean_seqtex/gradio/examples/multi_parts.glb"
    position_map, normal_map, position_images, normal_images, w2c = Mesh.process(glb_path)
    position_map.save("position_map.png")
    normal_map.save("normal_map.png")

    # Save normal_images in the [-1, 1] range as PIL images
    # normal_images = rearrange(normal_images, "B H W C -> B C H W")
    # save_image(normal_images, "normal_images.png", normalize=True, value_range=(-1, 1))
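The two classmethods are designed to be driven back to back: `Mesh.process` loads, merges, (re)parameterizes, normalizes, and renders the geometry buffers, and `Mesh.export` writes the textured result back to GLB. A minimal sketch of that flow outside the Gradio app, assuming a CUDA device and the rendering utilities imported above are available; the input path and the flat grey texture are placeholders:

import numpy as np
from PIL import Image

from utils.mesh_utils import Mesh

# Load, merge, (re)parameterize and normalize the mesh, then render its geometry buffers.
(position_map, normal_map, position_images, normal_images,
 mask_images, w2c, mesh, mvp_matrix, message) = Mesh.process(
    "examples/some_model.glb", uv_tool="xAtlas", y2z=True, device="cuda"  # placeholder path
)
print(message, tuple(position_images.shape))

# Write the processed mesh back to GLB, here with a flat placeholder texture.
texture = Image.fromarray(np.full((1024, 1024, 3), 180, dtype=np.uint8))
glb_path = Mesh.export(mesh, save_path=None, texture_map=texture)
print("Exported to", glb_path)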
utils/pipeline_controlnet_union_sd_xl.py
ADDED
@@ -0,0 +1,1397 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers.utils.import_utils import is_invisible_watermark_available

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, ImageProjection
from .controlnet_union import ControlNetModel_Union
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    is_accelerate_available,
    is_accelerate_version,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> import cv2
        >>> from PIL import Image

        >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
        >>> negative_prompt = "low quality, bad quality, sketches"

        >>> # download an image
        >>> image = load_image(
        ...     "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
        ... )

        >>> # initialize the models and pipeline
        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization
        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
        ... )
        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> # get canny image
        >>> image = np.array(image)
        >>> image = cv2.Canny(image, 100, 200)
        >>> image = image[:, :, None]
        >>> image = np.concatenate([image, image, image], axis=2)
        >>> canny_image = Image.fromarray(image)

        >>> # generate image
        >>> image = pipe(
        ...     prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
        ... ).images[0]
        ```
"""


class StableDiffusionXLControlNetUnionPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
            Second frozen text-encoder
            ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        tokenizer_2 ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
            ControlNets as a list, the outputs from each ControlNet are added together to create one combined
            additional conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
            Whether the negative prompt embeddings should always be set to 0. Also see the config of
            `stabilityai/stable-diffusion-xl-base-1-0`.
        add_watermarker (`bool`, *optional*):
            Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
            watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
            watermarker is used.
    """
    model_cpu_offload_seq = (
        "text_encoder->text_encoder_2->image_encoder->unet->vae"  # leave controlnet out on purpose because it iterates with unet
    )

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__()

        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None

        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
200 |
+
def enable_vae_slicing(self):
|
201 |
+
r"""
|
202 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
203 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
204 |
+
"""
|
205 |
+
self.vae.enable_slicing()
|
206 |
+
|
207 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
208 |
+
def disable_vae_slicing(self):
|
209 |
+
r"""
|
210 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
211 |
+
computing decoding in one step.
|
212 |
+
"""
|
213 |
+
self.vae.disable_slicing()
|
214 |
+
|
215 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
216 |
+
def enable_vae_tiling(self):
|
217 |
+
r"""
|
218 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
219 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
220 |
+
processing larger images.
|
221 |
+
"""
|
222 |
+
self.vae.enable_tiling()
|
223 |
+
|
224 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
225 |
+
def disable_vae_tiling(self):
|
226 |
+
r"""
|
227 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
228 |
+
computing decoding in one step.
|
229 |
+
"""
|
230 |
+
self.vae.disable_tiling()
|
231 |
+
|
232 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
233 |
+
def encode_prompt(
|
234 |
+
self,
|
235 |
+
prompt: str,
|
236 |
+
prompt_2: Optional[str] = None,
|
237 |
+
device: Optional[torch.device] = None,
|
238 |
+
num_images_per_prompt: int = 1,
|
239 |
+
do_classifier_free_guidance: bool = True,
|
240 |
+
negative_prompt: Optional[str] = None,
|
241 |
+
negative_prompt_2: Optional[str] = None,
|
242 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
243 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
244 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
245 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
246 |
+
lora_scale: Optional[float] = None,
|
247 |
+
):
|
248 |
+
r"""
|
249 |
+
Encodes the prompt into text encoder hidden states.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
prompt (`str` or `List[str]`, *optional*):
|
253 |
+
prompt to be encoded
|
254 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
255 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
256 |
+
used in both text-encoders
|
257 |
+
device: (`torch.device`):
|
258 |
+
torch device
|
259 |
+
num_images_per_prompt (`int`):
|
260 |
+
number of images that should be generated per prompt
|
261 |
+
do_classifier_free_guidance (`bool`):
|
262 |
+
whether to use classifier free guidance or not
|
263 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
264 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
265 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
266 |
+
less than `1`).
|
267 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
268 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
269 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
270 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
271 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
272 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
273 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
274 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
275 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
276 |
+
argument.
|
277 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
278 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
279 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
280 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
281 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
282 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
283 |
+
input argument.
|
284 |
+
lora_scale (`float`, *optional*):
|
285 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
286 |
+
"""
|
287 |
+
device = device or self._execution_device
|
288 |
+
|
289 |
+
# set lora scale so that monkey patched LoRA
|
290 |
+
# function of text encoder can correctly access it
|
291 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin,):
|
292 |
+
self._lora_scale = lora_scale
|
293 |
+
|
294 |
+
# dynamically adjust the LoRA scale
|
295 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
296 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
297 |
+
|
298 |
+
if prompt is not None and isinstance(prompt, str):
|
299 |
+
batch_size = 1
|
300 |
+
elif prompt is not None and isinstance(prompt, list):
|
301 |
+
batch_size = len(prompt)
|
302 |
+
else:
|
303 |
+
batch_size = prompt_embeds.shape[0]
|
304 |
+
|
305 |
+
# Define tokenizers and text encoders
|
306 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
307 |
+
text_encoders = (
|
308 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
309 |
+
)
|
310 |
+
|
311 |
+
if prompt_embeds is None:
|
312 |
+
prompt_2 = prompt_2 or prompt
|
313 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
314 |
+
prompt_embeds_list = []
|
315 |
+
prompts = [prompt, prompt_2]
|
316 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
317 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
318 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
319 |
+
|
320 |
+
text_inputs = tokenizer(
|
321 |
+
prompt,
|
322 |
+
padding="max_length",
|
323 |
+
max_length=tokenizer.model_max_length,
|
324 |
+
truncation=True,
|
325 |
+
return_tensors="pt",
|
326 |
+
)
|
327 |
+
|
328 |
+
text_input_ids = text_inputs.input_ids
|
329 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
330 |
+
|
331 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
332 |
+
text_input_ids, untruncated_ids
|
333 |
+
):
|
334 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
335 |
+
logger.warning(
|
336 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
337 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
338 |
+
)
|
339 |
+
|
340 |
+
prompt_embeds = text_encoder(
|
341 |
+
text_input_ids.to(device),
|
342 |
+
output_hidden_states=True,
|
343 |
+
)
|
344 |
+
|
345 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
346 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
347 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
348 |
+
|
349 |
+
prompt_embeds_list.append(prompt_embeds)
|
350 |
+
|
351 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
352 |
+
|
353 |
+
# get unconditional embeddings for classifier free guidance
|
354 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
355 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
356 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
357 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
358 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
359 |
+
negative_prompt = negative_prompt or ""
|
360 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
361 |
+
|
362 |
+
uncond_tokens: List[str]
|
363 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
364 |
+
raise TypeError(
|
365 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
366 |
+
f" {type(prompt)}."
|
367 |
+
)
|
368 |
+
elif isinstance(negative_prompt, str):
|
369 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
370 |
+
elif batch_size != len(negative_prompt):
|
371 |
+
raise ValueError(
|
372 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
373 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
374 |
+
" the batch size of `prompt`."
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
378 |
+
|
379 |
+
negative_prompt_embeds_list = []
|
380 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
381 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
382 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
383 |
+
|
384 |
+
max_length = prompt_embeds.shape[1]
|
385 |
+
uncond_input = tokenizer(
|
386 |
+
negative_prompt,
|
387 |
+
padding="max_length",
|
388 |
+
max_length=max_length,
|
389 |
+
truncation=True,
|
390 |
+
return_tensors="pt",
|
391 |
+
)
|
392 |
+
|
393 |
+
negative_prompt_embeds = text_encoder(
|
394 |
+
uncond_input.input_ids.to(device),
|
395 |
+
output_hidden_states=True,
|
396 |
+
)
|
397 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
398 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
399 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
400 |
+
|
401 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
402 |
+
|
403 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
404 |
+
|
405 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
406 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
407 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
408 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
409 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
410 |
+
|
411 |
+
if do_classifier_free_guidance:
|
412 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
413 |
+
seq_len = negative_prompt_embeds.shape[1]
|
414 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
415 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
416 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
417 |
+
|
418 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
419 |
+
bs_embed * num_images_per_prompt, -1
|
420 |
+
)
|
421 |
+
if do_classifier_free_guidance:
|
422 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
423 |
+
bs_embed * num_images_per_prompt, -1
|
424 |
+
)
|
425 |
+
|
426 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
427 |
+
|
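    # Illustrative sketch (editor's addition, not part of the original pipeline): `encode_prompt`
    # concatenates the penultimate hidden states of both CLIP text encoders along the channel axis.
    # Assuming the usual SDXL checkpoints (hidden sizes 768 and 1280), the returned shapes are
    # typically:
    #
    #   prompt_embeds:        (batch_size * num_images_per_prompt, 77, 2048)
    #   pooled_prompt_embeds: (batch_size * num_images_per_prompt, 1280)
    #
    # The exact sizes depend on the text encoders that are actually loaded.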
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

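    # Illustrative sketch (editor's addition, not part of the original pipeline): the signature
    # inspection above only forwards `eta`/`generator` when the scheduler's `step()` accepts them,
    # along the lines of:
    #
    #   >>> import inspect
    #   >>> "eta" in inspect.signature(scheduler.step).parameters
    #   # True for DDIMScheduler; typically False for Euler-style schedulers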
    def check_inputs(
        self,
        prompt,
        prompt_2,
        image,
        callback_steps,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # `prompt` needs more sophisticated handling when there are multiple
        # conditionings.
        if isinstance(self.controlnet, MultiControlNetModel):
            if isinstance(prompt, list):
                logger.warning(
                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                    " prompts. The conditionings will be fixed across the prompts."
                )

        # Check `image`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            self.check_image(image, prompt, prompt_embeds)
        elif (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            self.check_image(image, prompt, prompt_embeds)
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

            # When `image` is a nested list:
            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
            elif any(isinstance(i, list) for i in image):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            elif len(image) != len(self.controlnet.nets):
                raise ValueError(
                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                )

            for image_ in image:
                self.check_image(image_, prompt, prompt_embeds)
        else:
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")

        elif (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")

        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if isinstance(controlnet_conditioning_scale, list):
                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                self.controlnet.nets
            ):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
        else:
            assert False

        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
716 |
+
def check_image(self, image, prompt, prompt_embeds):
|
717 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
718 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
719 |
+
image_is_np = isinstance(image, np.ndarray)
|
720 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
721 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
722 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
723 |
+
|
724 |
+
if (
|
725 |
+
not image_is_pil
|
726 |
+
and not image_is_tensor
|
727 |
+
and not image_is_np
|
728 |
+
and not image_is_pil_list
|
729 |
+
and not image_is_tensor_list
|
730 |
+
and not image_is_np_list
|
731 |
+
):
|
732 |
+
raise TypeError(
|
733 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
734 |
+
)
|
735 |
+
|
736 |
+
if image_is_pil:
|
737 |
+
image_batch_size = 1
|
738 |
+
else:
|
739 |
+
image_batch_size = len(image)
|
740 |
+
|
741 |
+
if prompt is not None and isinstance(prompt, str):
|
742 |
+
prompt_batch_size = 1
|
743 |
+
elif prompt is not None and isinstance(prompt, list):
|
744 |
+
prompt_batch_size = len(prompt)
|
745 |
+
elif prompt_embeds is not None:
|
746 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
747 |
+
|
748 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
749 |
+
raise ValueError(
|
750 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
751 |
+
)
|
752 |
+
|
753 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
754 |
+
def prepare_image(
|
755 |
+
self,
|
756 |
+
image,
|
757 |
+
width,
|
758 |
+
height,
|
759 |
+
batch_size,
|
760 |
+
num_images_per_prompt,
|
761 |
+
device,
|
762 |
+
dtype,
|
763 |
+
do_classifier_free_guidance=False,
|
764 |
+
guess_mode=False,
|
765 |
+
):
|
766 |
+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
767 |
+
image_batch_size = image.shape[0]
|
768 |
+
|
769 |
+
if image_batch_size == 1:
|
770 |
+
repeat_by = batch_size
|
771 |
+
else:
|
772 |
+
# image batch size is the same as prompt batch size
|
773 |
+
repeat_by = num_images_per_prompt
|
774 |
+
|
775 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
776 |
+
|
777 |
+
image = image.to(device=device, dtype=dtype)
|
778 |
+
|
779 |
+
if do_classifier_free_guidance and not guess_mode:
|
780 |
+
image = torch.cat([image] * 2)
|
781 |
+
|
782 |
+
return image
|
783 |
+
|
784 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
785 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
786 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
787 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
788 |
+
raise ValueError(
|
789 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
790 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
791 |
+
)
|
792 |
+
|
793 |
+
if latents is None:
|
794 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
795 |
+
else:
|
796 |
+
latents = latents.to(device)
|
797 |
+
|
798 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
799 |
+
latents = latents * self.scheduler.init_noise_sigma
|
800 |
+
return latents
|
801 |
+
|
802 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
803 |
+
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
804 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
805 |
+
|
806 |
+
passed_add_embed_dim = (
|
807 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
808 |
+
)
|
809 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
810 |
+
|
811 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
812 |
+
raise ValueError(
|
813 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
814 |
+
)
|
815 |
+
|
816 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
817 |
+
return add_time_ids
|
818 |
+
|
819 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
820 |
+
def upcast_vae(self):
|
821 |
+
dtype = self.vae.dtype
|
822 |
+
self.vae.to(dtype=torch.float32)
|
823 |
+
use_torch_2_0_or_xformers = isinstance(
|
824 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
825 |
+
(
|
826 |
+
AttnProcessor2_0,
|
827 |
+
XFormersAttnProcessor,
|
828 |
+
LoRAXFormersAttnProcessor,
|
829 |
+
LoRAAttnProcessor2_0,
|
830 |
+
),
|
831 |
+
)
|
832 |
+
# if xformers or torch_2_0 is used attention block does not need
|
833 |
+
# to be in float32 which can save lots of memory
|
834 |
+
if use_torch_2_0_or_xformers:
|
835 |
+
self.vae.post_quant_conv.to(dtype)
|
836 |
+
self.vae.decoder.conv_in.to(dtype)
|
837 |
+
self.vae.decoder.mid_block.to(dtype)
|
838 |
+
|
839 |
+
@torch.no_grad()
|
840 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
841 |
+
def __call__(
|
842 |
+
self,
|
843 |
+
prompt: Union[str, List[str]] = None,
|
844 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
845 |
+
image_list: PipelineImageInput = None,
|
846 |
+
height: Optional[int] = None,
|
847 |
+
width: Optional[int] = None,
|
848 |
+
num_inference_steps: int = 50,
|
849 |
+
guidance_scale: float = 5.0,
|
850 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
851 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
852 |
+
num_images_per_prompt: Optional[int] = 1,
|
853 |
+
eta: float = 0.0,
|
854 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
855 |
+
latents: Optional[torch.FloatTensor] = None,
|
856 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
857 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
858 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
859 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
860 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
861 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
862 |
+
output_type: Optional[str] = "pil",
|
863 |
+
return_dict: bool = True,
|
864 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
865 |
+
callback_steps: int = 1,
|
866 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
867 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
868 |
+
guess_mode: bool = False,
|
869 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
870 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
871 |
+
original_size: Tuple[int, int] = None,
|
872 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
873 |
+
target_size: Tuple[int, int] = None,
|
874 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
875 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
876 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
877 |
+
union_control = False,
|
878 |
+
union_control_type = None,
|
879 |
+
progress=gr.Progress(),
|
880 |
+
|
881 |
+
):
|
882 |
+
r"""
|
883 |
+
The call function to the pipeline for generation.
|
884 |
+
|
885 |
+
Args:
|
886 |
+
prompt (`str` or `List[str]`, *optional*):
|
887 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
888 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
889 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
890 |
+
used in both text-encoders.
|
891 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
892 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
893 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
894 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
895 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
896 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
897 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
898 |
+
input to a single ControlNet.
|
899 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
900 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
901 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
902 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
903 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
904 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
905 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
906 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
907 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
908 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
909 |
+
expense of slower inference.
|
910 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
911 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
912 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
913 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
914 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
915 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
916 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
917 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
918 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
919 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
920 |
+
The number of images to generate per prompt.
|
921 |
+
eta (`float`, *optional*, defaults to 0.0):
|
922 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
923 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
924 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
925 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
926 |
+
generation deterministic.
|
927 |
+
latents (`torch.FloatTensor`, *optional*):
|
928 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
929 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
930 |
+
tensor is generated by sampling using the supplied random `generator`.
|
931 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
932 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
933 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
934 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
935 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
936 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
937 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
938 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
939 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
940 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
941 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
942 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
943 |
+
argument.
|
944 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
945 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
946 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
947 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
948 |
+
plain tuple.
|
949 |
+
callback (`Callable`, *optional*):
|
950 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
951 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
952 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
953 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
954 |
+
every step.
|
955 |
+
cross_attention_kwargs (`dict`, *optional*):
|
956 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
957 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
958 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
959 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
960 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
961 |
+
the corresponding scale as a list.
|
962 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
963 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
964 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
965 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
966 |
+
The percentage of total steps at which the ControlNet starts applying.
|
967 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
968 |
+
The percentage of total steps at which the ControlNet stops applying.
|
969 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
970 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
971 |
+
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
972 |
+
explained in section 2.2 of
|
973 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
974 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
975 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
976 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
977 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
978 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
979 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
980 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
981 |
+
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
982 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
983 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
984 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
985 |
+
micro-conditioning as explained in section 2.2 of
|
986 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
987 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
988 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
989 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
990 |
+
micro-conditioning as explained in section 2.2 of
|
991 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
992 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
993 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
994 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
995 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
996 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
997 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
998 |
+
|
999 |
+
Examples:
|
1000 |
+
|
1001 |
+
Returns:
|
1002 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1003 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
1004 |
+
otherwise a `tuple` is returned containing the output images.
|
1005 |
+
"""
|
1006 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
1007 |
+
|
1008 |
+
# align format for control guidance
|
1009 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
1010 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
1011 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
1012 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
1013 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
1014 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
1015 |
+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
1016 |
+
control_guidance_end
|
1017 |
+
]
|
1018 |
+
|
1019 |
+
# 1. Check inputs. Raise error if not correct
|
1020 |
+
for image in image_list:
|
1021 |
+
if image:
|
1022 |
+
self.check_inputs(
|
1023 |
+
prompt,
|
1024 |
+
prompt_2,
|
1025 |
+
image,
|
1026 |
+
callback_steps,
|
1027 |
+
negative_prompt,
|
1028 |
+
negative_prompt_2,
|
1029 |
+
prompt_embeds,
|
1030 |
+
negative_prompt_embeds,
|
1031 |
+
pooled_prompt_embeds,
|
1032 |
+
negative_pooled_prompt_embeds,
|
1033 |
+
controlnet_conditioning_scale,
|
1034 |
+
control_guidance_start,
|
1035 |
+
control_guidance_end,
|
1036 |
+
ip_adapter_image,
|
1037 |
+
ip_adapter_image_embeds,
|
1038 |
+
)
|
1039 |
+
# 2. Define call parameters
|
1040 |
+
if prompt is not None and isinstance(prompt, str):
|
1041 |
+
batch_size = 1
|
1042 |
+
elif prompt is not None and isinstance(prompt, list):
|
1043 |
+
batch_size = len(prompt)
|
1044 |
+
else:
|
1045 |
+
batch_size = prompt_embeds.shape[0]
|
1046 |
+
|
1047 |
+
device = self._execution_device
|
1048 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1049 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1050 |
+
# corresponds to doing no classifier free guidance.
|
1051 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1052 |
+
|
1053 |
+
global_pool_conditions = (
|
1054 |
+
controlnet.config.global_pool_conditions
|
1055 |
+
)
|
1056 |
+
guess_mode = guess_mode or global_pool_conditions
|
1057 |
+
|
1058 |
+
# 3. Encode input prompt
|
1059 |
+
text_encoder_lora_scale = (
|
1060 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
1061 |
+
)
|
1062 |
+
(
|
1063 |
+
prompt_embeds,
|
1064 |
+
negative_prompt_embeds,
|
1065 |
+
pooled_prompt_embeds,
|
1066 |
+
negative_pooled_prompt_embeds,
|
1067 |
+
) = self.encode_prompt(
|
1068 |
+
prompt,
|
1069 |
+
prompt_2,
|
1070 |
+
device,
|
1071 |
+
num_images_per_prompt,
|
1072 |
+
do_classifier_free_guidance,
|
1073 |
+
negative_prompt,
|
1074 |
+
negative_prompt_2,
|
1075 |
+
prompt_embeds=prompt_embeds,
|
1076 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1077 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1078 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1079 |
+
lora_scale=text_encoder_lora_scale,
|
1080 |
+
)
|
1081 |
+
|
1082 |
+
# 3.2 Encode ip_adapter_image
|
1083 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1084 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
1085 |
+
ip_adapter_image,
|
1086 |
+
ip_adapter_image_embeds,
|
1087 |
+
device,
|
1088 |
+
batch_size * num_images_per_prompt,
|
1089 |
+
do_classifier_free_guidance,
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
# 4. Prepare image
|
1093 |
+
assert isinstance(controlnet, ControlNetModel_Union)
|
1094 |
+
|
1095 |
+
|
1096 |
+
for idx in range(len(image_list)):
|
1097 |
+
if image_list[idx]:
|
1098 |
+
image = self.prepare_image(
|
1099 |
+
image=image_list[idx],
|
1100 |
+
width=width,
|
1101 |
+
height=height,
|
1102 |
+
batch_size=batch_size * num_images_per_prompt,
|
1103 |
+
num_images_per_prompt=num_images_per_prompt,
|
1104 |
+
device=device,
|
1105 |
+
dtype=controlnet.dtype,
|
1106 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
1107 |
+
guess_mode=guess_mode,
|
1108 |
+
)
|
1109 |
+
height, width = image.shape[-2:]
|
1110 |
+
image_list[idx] = image
|
1111 |
+
|
1112 |
+
|
1113 |
+
# 5. Prepare timesteps
|
1114 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1115 |
+
timesteps = self.scheduler.timesteps
|
1116 |
+
|
1117 |
+
# 6. Prepare latent variables
|
1118 |
+
num_channels_latents = self.unet.config.in_channels
|
1119 |
+
latents = self.prepare_latents(
|
1120 |
+
batch_size * num_images_per_prompt,
|
1121 |
+
num_channels_latents,
|
1122 |
+
height,
|
1123 |
+
width,
|
1124 |
+
prompt_embeds.dtype,
|
1125 |
+
device,
|
1126 |
+
generator,
|
1127 |
+
latents,
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1131 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1132 |
+
|
1133 |
+
# 7.1 Create tensor stating which controlnets to keep
|
1134 |
+
controlnet_keep = []
|
1135 |
+
for i in range(len(timesteps)):
|
1136 |
+
keeps = [
|
1137 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
1138 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
1139 |
+
]
|
1140 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) or isinstance(controlnet, ControlNetModel_Union) else keeps)
|
1141 |
+
|
1142 |
+
# 7.2 Prepare added time ids & embeddings
|
1143 |
+
for image in image_list:
|
1144 |
+
if isinstance(image, torch.Tensor):
|
1145 |
+
original_size = original_size or image.shape[-2:]
|
1146 |
+
|
1147 |
+
target_size = target_size or (height, width)
|
1148 |
+
# print(original_size)
|
1149 |
+
# print(target_size)
|
1150 |
+
add_text_embeds = pooled_prompt_embeds
|
1151 |
+
add_time_ids = self._get_add_time_ids(
|
1152 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1156 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1157 |
+
negative_original_size,
|
1158 |
+
negative_crops_coords_top_left,
|
1159 |
+
negative_target_size,
|
1160 |
+
dtype=prompt_embeds.dtype,
|
1161 |
+
)
|
1162 |
+
else:
|
1163 |
+
negative_add_time_ids = add_time_ids
|
1164 |
+
|
1165 |
+
if do_classifier_free_guidance:
|
1166 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1167 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1168 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1169 |
+
|
1170 |
+
prompt_embeds = prompt_embeds.to(device)
|
1171 |
+
add_text_embeds = add_text_embeds.to(device)
|
1172 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1173 |
+
|
1174 |
+
# 8. Denoising loop
|
1175 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1176 |
+
# with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1177 |
+
# with progress.tqdm(range(num_inference_steps), desc="Diffusing...") as progress_bar:
|
1178 |
+
for i, t in progress.tqdm(enumerate(timesteps), desc="Diffusing..."):
|
1179 |
+
# expand the latents if we are doing classifier free guidance
|
1180 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
1181 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1182 |
+
|
1183 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids, \
|
1184 |
+
"control_type":union_control_type.reshape(1, -1).to(device, dtype=prompt_embeds.dtype).repeat(batch_size * num_images_per_prompt * 2, 1)}
|
1185 |
+
|
1186 |
+
# controlnet(s) inference
|
1187 |
+
if guess_mode and do_classifier_free_guidance:
|
1188 |
+
# Infer ControlNet only for the conditional batch.
|
1189 |
+
control_model_input = latents
|
1190 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
1191 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
1192 |
+
controlnet_added_cond_kwargs = {
|
1193 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
1194 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
1195 |
+
}
|
1196 |
+
else:
|
1197 |
+
control_model_input = latent_model_input
|
1198 |
+
controlnet_prompt_embeds = prompt_embeds
|
1199 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
1200 |
+
|
1201 |
+
if isinstance(controlnet_keep[i], list):
|
1202 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
1203 |
+
else:
|
1204 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
1205 |
+
if isinstance(controlnet_cond_scale, list):
|
1206 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
1207 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
1208 |
+
|
1209 |
+
|
1210 |
+
# print(image.shape)
|
1211 |
+
if isinstance(controlnet, ControlNetModel_Union):
|
1212 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1213 |
+
control_model_input,
|
1214 |
+
t,
|
1215 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
1216 |
+
controlnet_cond_list=image_list,
|
1217 |
+
conditioning_scale=cond_scale,
|
1218 |
+
guess_mode=guess_mode,
|
1219 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
1220 |
+
return_dict=False,
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
if guess_mode and do_classifier_free_guidance:
|
1224 |
+
# Infered ControlNet only for the conditional batch.
|
1225 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
1226 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1227 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1228 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1229 |
+
|
1230 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1231 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
1232 |
+
# predict the noise residual
|
1233 |
+
noise_pred = self.unet(
|
1234 |
+
latent_model_input,
|
1235 |
+
t,
|
1236 |
+
encoder_hidden_states=prompt_embeds,
|
1237 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1238 |
+
down_block_additional_residuals=down_block_res_samples,
|
1239 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1240 |
+
added_cond_kwargs=added_cond_kwargs,
|
1241 |
+
return_dict=False,
|
1242 |
+
)[0]
|
1243 |
+
|
1244 |
+
# perform guidance
|
1245 |
+
if do_classifier_free_guidance:
|
1246 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1247 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1248 |
+
|
1249 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1250 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1251 |
+
|
1252 |
+
# call the callback, if provided
|
1253 |
+
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1254 |
+
# progress_bar.update()
|
1255 |
+
# if callback is not None and i % callback_steps == 0:
|
1256 |
+
# callback(i, t, latents)
|
1257 |
+
|
1258 |
+
# manually for max memory savings
|
1259 |
+
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
1260 |
+
self.upcast_vae()
|
1261 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1262 |
+
|
1263 |
+
if not output_type == "latent":
|
1264 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1265 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1266 |
+
|
1267 |
+
if needs_upcasting:
|
1268 |
+
self.upcast_vae()
|
1269 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1270 |
+
|
1271 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1272 |
+
|
1273 |
+
# cast back to fp16 if needed
|
1274 |
+
if needs_upcasting:
|
1275 |
+
self.vae.to(dtype=torch.float16)
|
1276 |
+
else:
|
1277 |
+
image = latents
|
1278 |
+
|
1279 |
+
if not output_type == "latent":
|
1280 |
+
# apply watermark if available
|
1281 |
+
if self.watermark is not None:
|
1282 |
+
image = self.watermark.apply_watermark(image)
|
1283 |
+
|
1284 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1285 |
+
|
1286 |
+
# Offload all models
|
1287 |
+
self.maybe_free_model_hooks()
|
1288 |
+
|
1289 |
+
if not return_dict:
|
1290 |
+
return (image,)
|
1291 |
+
|
1292 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
1293 |
+
|
1294 |
+
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
1295 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
|
1296 |
+
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
1297 |
+
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
1298 |
+
# it here explicitly to be able to tell that it's coming from an SDXL
|
1299 |
+
# pipeline.
|
1300 |
+
|
1301 |
+
# Remove any existing hooks.
|
1302 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
1303 |
+
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
1304 |
+
else:
|
1305 |
+
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
|
1306 |
+
|
1307 |
+
is_model_cpu_offload = False
|
1308 |
+
is_sequential_cpu_offload = False
|
1309 |
+
recursive = False
|
1310 |
+
for _, component in self.components.items():
|
1311 |
+
if isinstance(component, torch.nn.Module):
|
1312 |
+
if hasattr(component, "_hf_hook"):
|
1313 |
+
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
1314 |
+
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
1315 |
+
logger.info(
|
1316 |
+
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
1317 |
+
)
|
1318 |
+
recursive = is_sequential_cpu_offload
|
1319 |
+
remove_hook_from_module(component, recurse=recursive)
|
1320 |
+
state_dict, network_alphas = self.lora_state_dict(
|
1321 |
+
pretrained_model_name_or_path_or_dict,
|
1322 |
+
unet_config=self.unet.config,
|
1323 |
+
**kwargs,
|
1324 |
+
)
|
1325 |
+
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
1326 |
+
|
1327 |
+
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
1328 |
+
if len(text_encoder_state_dict) > 0:
|
1329 |
+
self.load_lora_into_text_encoder(
|
1330 |
+
text_encoder_state_dict,
|
1331 |
+
network_alphas=network_alphas,
|
1332 |
+
text_encoder=self.text_encoder,
|
1333 |
+
prefix="text_encoder",
|
1334 |
+
lora_scale=self.lora_scale,
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
1338 |
+
if len(text_encoder_2_state_dict) > 0:
|
1339 |
+
self.load_lora_into_text_encoder(
|
1340 |
+
text_encoder_2_state_dict,
|
1341 |
+
network_alphas=network_alphas,
|
1342 |
+
text_encoder=self.text_encoder_2,
|
1343 |
+
prefix="text_encoder_2",
|
1344 |
+
lora_scale=self.lora_scale,
|
1345 |
+
)
|
1346 |
+
|
1347 |
+
# Offload back.
|
1348 |
+
if is_model_cpu_offload:
|
1349 |
+
self.enable_model_cpu_offload()
|
1350 |
+
elif is_sequential_cpu_offload:
|
1351 |
+
self.enable_sequential_cpu_offload()
|
1352 |
+
|
1353 |
+
@classmethod
|
1354 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
|
1355 |
+
def save_lora_weights(
|
1356 |
+
self,
|
1357 |
+
save_directory: Union[str, os.PathLike],
|
1358 |
+
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1359 |
+
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1360 |
+
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1361 |
+
is_main_process: bool = True,
|
1362 |
+
weight_name: str = None,
|
1363 |
+
save_function: Callable = None,
|
1364 |
+
safe_serialization: bool = True,
|
1365 |
+
):
|
1366 |
+
state_dict = {}
|
1367 |
+
|
1368 |
+
def pack_weights(layers, prefix):
|
1369 |
+
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
1370 |
+
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
1371 |
+
return layers_state_dict
|
1372 |
+
|
1373 |
+
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
1374 |
+
raise ValueError(
|
1375 |
+
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
1376 |
+
)
|
1377 |
+
|
1378 |
+
if unet_lora_layers:
|
1379 |
+
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
1380 |
+
|
1381 |
+
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
1382 |
+
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
1383 |
+
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
1384 |
+
|
1385 |
+
self.write_lora_layers(
|
1386 |
+
state_dict=state_dict,
|
1387 |
+
save_directory=save_directory,
|
1388 |
+
is_main_process=is_main_process,
|
1389 |
+
weight_name=weight_name,
|
1390 |
+
save_function=save_function,
|
1391 |
+
safe_serialization=safe_serialization,
|
1392 |
+
)
|
1393 |
+
|
1394 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
|
1395 |
+
def _remove_text_encoder_monkey_patch(self):
|
1396 |
+
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
1397 |
+
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
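The `__call__` method above drives the SDXL ControlNet-Union denoising loop from a list of per-condition control images. A minimal usage sketch follows; the pipeline class name and checkpoint ids are assumptions (they are not visible in this excerpt), and only the argument names (`prompt`, `image_list`, `union_control_type`, `num_inference_steps`, `controlnet_conditioning_scale`) are taken from the code above. The number and ordering of slots in `image_list` and `union_control_type` depend on the ControlNet-Union checkpoint being used.

```py
import torch
# Class name assumed; import from the module shown above.
from utils.pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline

# `controlnet` is a previously loaded ControlNetModel_Union instance (placeholder).
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base checkpoint
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# One slot per condition type supported by the union ControlNet; unused slots stay None / 0.
image_list = [None, depth_image, None, None, normal_image, None]  # placeholder condition images
union_control_type = torch.Tensor([0, 1, 0, 0, 1, 0])

result = pipe(
    prompt="a rustic birdhouse, winter scene",
    image_list=image_list,
    union_control_type=union_control_type,
    num_inference_steps=30,
    controlnet_conditioning_scale=1.0,
)
image = result.images[0]  # StableDiffusionXLPipelineOutput, as returned by __call__ above
```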
utils/pipeline_stable_diffusion_switcher.py
ADDED
@@ -0,0 +1,1240 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
from PIL import Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
import torchvision.transforms.functional as TF

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""


def scale_latents_rm(latents):
    latents = latents * 0.9702 - 0.5742
    return latents


def unscale_latents_rm(latents):
    latents = (latents + 0.5742) / 0.9702
    return latents


def scale_latents_bump(latents):
    latents = latents * 0.9462 + 0.3770
    return latents


def unscale_latents_bump(latents):
    latents = (latents - 0.3770) / 0.9462
    return latents

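# The two `scale_*`/`unscale_*` pairs above are exact affine inverses; the constants are
# presumably precomputed latent statistics for the roughness/metallic and bump channels.
# A quick round-trip check (sketch, using the helpers defined above):
#
#   x = torch.randn(1, 4, 64, 64)
#   assert torch.allclose(unscale_latents_rm(scale_latents_rm(x)), x, atol=1e-5)
#   assert torch.allclose(unscale_latents_bump(scale_latents_bump(x)), x, atol=1e-5)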

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


138 |
+
class StableDiffusionPipeline(
|
139 |
+
DiffusionPipeline,
|
140 |
+
StableDiffusionMixin,
|
141 |
+
TextualInversionLoaderMixin,
|
142 |
+
LoraLoaderMixin,
|
143 |
+
IPAdapterMixin,
|
144 |
+
FromSingleFileMixin,
|
145 |
+
):
|
146 |
+
r"""
|
147 |
+
Pipeline for text-to-image generation using Stable Diffusion.
|
148 |
+
|
149 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
150 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
151 |
+
|
152 |
+
The pipeline also inherits the following loading methods:
|
153 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
154 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
155 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
156 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
157 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
158 |
+
|
159 |
+
Args:
|
160 |
+
vae ([`AutoencoderKL`]):
|
161 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
162 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
163 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
164 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
165 |
+
A `CLIPTokenizer` to tokenize text.
|
166 |
+
unet ([`UNet2DConditionModel`]):
|
167 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
168 |
+
scheduler ([`SchedulerMixin`]):
|
169 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
170 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
171 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
172 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
173 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
174 |
+
about a model's potential harms.
|
175 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
176 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
177 |
+
"""
|
178 |
+
|
179 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
180 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
181 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
182 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
vae: AutoencoderKL,
|
187 |
+
text_encoder: CLIPTextModel,
|
188 |
+
tokenizer: CLIPTokenizer,
|
189 |
+
unet: UNet2DConditionModel,
|
190 |
+
scheduler: KarrasDiffusionSchedulers,
|
191 |
+
safety_checker: StableDiffusionSafetyChecker,
|
192 |
+
feature_extractor: CLIPImageProcessor,
|
193 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
194 |
+
requires_safety_checker: bool = True,
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
199 |
+
deprecation_message = (
|
200 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
201 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
202 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
203 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
204 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
205 |
+
" file"
|
206 |
+
)
|
207 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
208 |
+
new_config = dict(scheduler.config)
|
209 |
+
new_config["steps_offset"] = 1
|
210 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
211 |
+
|
212 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
213 |
+
deprecation_message = (
|
214 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
215 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
216 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
217 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
218 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
219 |
+
)
|
220 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
221 |
+
new_config = dict(scheduler.config)
|
222 |
+
new_config["clip_sample"] = False
|
223 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
224 |
+
|
225 |
+
if safety_checker is None and requires_safety_checker:
|
226 |
+
logger.warning(
|
227 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
228 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
229 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
230 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
231 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
232 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
233 |
+
)
|
234 |
+
|
235 |
+
if safety_checker is not None and feature_extractor is None:
|
236 |
+
raise ValueError(
|
237 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
238 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
239 |
+
)
|
240 |
+
|
241 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
242 |
+
version.parse(unet.config._diffusers_version).base_version
|
243 |
+
) < version.parse("0.9.0.dev0")
|
244 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
245 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
246 |
+
deprecation_message = (
|
247 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
248 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
249 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
250 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
251 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
252 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
253 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
254 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
255 |
+
" the `unet/config.json` file"
|
256 |
+
)
|
257 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
258 |
+
new_config = dict(unet.config)
|
259 |
+
new_config["sample_size"] = 64
|
260 |
+
unet._internal_dict = FrozenDict(new_config)
|
261 |
+
|
262 |
+
self.register_modules(
|
263 |
+
vae=vae,
|
264 |
+
text_encoder=text_encoder,
|
265 |
+
tokenizer=tokenizer,
|
266 |
+
unet=unet,
|
267 |
+
scheduler=scheduler,
|
268 |
+
safety_checker=safety_checker,
|
269 |
+
feature_extractor=feature_extractor,
|
270 |
+
image_encoder=image_encoder,
|
271 |
+
)
|
272 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
273 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
274 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
275 |
+
|
276 |
+
def _encode_prompt(
|
277 |
+
self,
|
278 |
+
prompt,
|
279 |
+
device,
|
280 |
+
num_images_per_prompt,
|
281 |
+
do_classifier_free_guidance,
|
282 |
+
negative_prompt=None,
|
283 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
284 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
285 |
+
lora_scale: Optional[float] = None,
|
286 |
+
**kwargs,
|
287 |
+
):
|
288 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
289 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
290 |
+
|
291 |
+
prompt_embeds_tuple = self.encode_prompt(
|
292 |
+
prompt=prompt,
|
293 |
+
device=device,
|
294 |
+
num_images_per_prompt=num_images_per_prompt,
|
295 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
296 |
+
negative_prompt=negative_prompt,
|
297 |
+
prompt_embeds=prompt_embeds,
|
298 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
299 |
+
lora_scale=lora_scale,
|
300 |
+
**kwargs,
|
301 |
+
)
|
302 |
+
|
303 |
+
# concatenate for backwards comp
|
304 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
305 |
+
|
306 |
+
return prompt_embeds
|
307 |
+
|
308 |
+
def encode_prompt(
|
309 |
+
self,
|
310 |
+
prompt,
|
311 |
+
device,
|
312 |
+
num_images_per_prompt,
|
313 |
+
do_classifier_free_guidance,
|
314 |
+
negative_prompt=None,
|
315 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
316 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
317 |
+
lora_scale: Optional[float] = None,
|
318 |
+
clip_skip: Optional[int] = None,
|
319 |
+
):
|
320 |
+
r"""
|
321 |
+
Encodes the prompt into text encoder hidden states.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
prompt (`str` or `List[str]`, *optional*):
|
325 |
+
prompt to be encoded
|
326 |
+
device: (`torch.device`):
|
327 |
+
torch device
|
328 |
+
num_images_per_prompt (`int`):
|
329 |
+
number of images that should be generated per prompt
|
330 |
+
do_classifier_free_guidance (`bool`):
|
331 |
+
whether to use classifier free guidance or not
|
332 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
333 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
334 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
335 |
+
less than `1`).
|
336 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
337 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
338 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
339 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
340 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
341 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
342 |
+
argument.
|
343 |
+
lora_scale (`float`, *optional*):
|
344 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
345 |
+
clip_skip (`int`, *optional*):
|
346 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
347 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
348 |
+
"""
|
349 |
+
# set lora scale so that monkey patched LoRA
|
350 |
+
# function of text encoder can correctly access it
|
351 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
352 |
+
self._lora_scale = lora_scale
|
353 |
+
|
354 |
+
# dynamically adjust the LoRA scale
|
355 |
+
if not USE_PEFT_BACKEND:
|
356 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
357 |
+
else:
|
358 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
359 |
+
|
360 |
+
if prompt is not None and isinstance(prompt, str):
|
361 |
+
batch_size = 1
|
362 |
+
elif prompt is not None and isinstance(prompt, list):
|
363 |
+
batch_size = len(prompt)
|
364 |
+
else:
|
365 |
+
batch_size = prompt_embeds.shape[0]
|
366 |
+
|
367 |
+
if prompt_embeds is None:
|
368 |
+
# textual inversion: process multi-vector tokens if necessary
|
369 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
370 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
371 |
+
|
372 |
+
text_inputs = self.tokenizer(
|
373 |
+
prompt,
|
374 |
+
padding="max_length",
|
375 |
+
max_length=self.tokenizer.model_max_length,
|
376 |
+
truncation=True,
|
377 |
+
return_tensors="pt",
|
378 |
+
)
|
379 |
+
text_input_ids = text_inputs.input_ids
|
380 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
381 |
+
|
382 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
383 |
+
text_input_ids, untruncated_ids
|
384 |
+
):
|
385 |
+
removed_text = self.tokenizer.batch_decode(
|
386 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
387 |
+
)
|
388 |
+
logger.warning(
|
389 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
390 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
391 |
+
)
|
392 |
+
|
393 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
394 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
395 |
+
else:
|
396 |
+
attention_mask = None
|
397 |
+
|
398 |
+
if clip_skip is None:
|
399 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
400 |
+
prompt_embeds = prompt_embeds[0]
|
401 |
+
else:
|
402 |
+
prompt_embeds = self.text_encoder(
|
403 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
404 |
+
)
|
405 |
+
# Access the `hidden_states` first, that contains a tuple of
|
406 |
+
# all the hidden states from the encoder layers. Then index into
|
407 |
+
# the tuple to access the hidden states from the desired layer.
|
408 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
409 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
410 |
+
# representations. The `last_hidden_states` that we typically use for
|
411 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
412 |
+
# layer.
|
413 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
414 |
+
|
415 |
+
if self.text_encoder is not None:
|
416 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
417 |
+
elif self.unet is not None:
|
418 |
+
prompt_embeds_dtype = self.unet.dtype
|
419 |
+
else:
|
420 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
421 |
+
|
422 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
423 |
+
|
424 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
425 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
426 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
427 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
428 |
+
|
429 |
+
# get unconditional embeddings for classifier free guidance
|
430 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
431 |
+
uncond_tokens: List[str]
|
432 |
+
if negative_prompt is None:
|
433 |
+
uncond_tokens = [""] * batch_size
|
434 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
435 |
+
raise TypeError(
|
436 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
437 |
+
f" {type(prompt)}."
|
438 |
+
)
|
439 |
+
elif isinstance(negative_prompt, str):
|
440 |
+
uncond_tokens = [negative_prompt]
|
441 |
+
elif batch_size != len(negative_prompt):
|
442 |
+
raise ValueError(
|
443 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
444 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
445 |
+
" the batch size of `prompt`."
|
446 |
+
)
|
447 |
+
else:
|
448 |
+
uncond_tokens = negative_prompt
|
449 |
+
|
450 |
+
# textual inversion: process multi-vector tokens if necessary
|
451 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
452 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
453 |
+
|
454 |
+
max_length = prompt_embeds.shape[1]
|
455 |
+
uncond_input = self.tokenizer(
|
456 |
+
uncond_tokens,
|
457 |
+
padding="max_length",
|
458 |
+
max_length=max_length,
|
459 |
+
truncation=True,
|
460 |
+
return_tensors="pt",
|
461 |
+
)
|
462 |
+
|
463 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
464 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
465 |
+
else:
|
466 |
+
attention_mask = None
|
467 |
+
|
468 |
+
negative_prompt_embeds = self.text_encoder(
|
469 |
+
uncond_input.input_ids.to(device),
|
470 |
+
attention_mask=attention_mask,
|
471 |
+
)
|
472 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
473 |
+
|
474 |
+
if do_classifier_free_guidance:
|
475 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
476 |
+
seq_len = negative_prompt_embeds.shape[1]
|
477 |
+
|
478 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
479 |
+
|
480 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
481 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
482 |
+
|
483 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
484 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
485 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
486 |
+
|
487 |
+
return prompt_embeds, negative_prompt_embeds
|
488 |
+
|
489 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
490 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
491 |
+
|
492 |
+
if not isinstance(image, torch.Tensor):
|
493 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
494 |
+
|
495 |
+
image = image.to(device=device, dtype=dtype)
|
496 |
+
if output_hidden_states:
|
497 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
498 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
499 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
500 |
+
torch.zeros_like(image), output_hidden_states=True
|
501 |
+
).hidden_states[-2]
|
502 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
503 |
+
num_images_per_prompt, dim=0
|
504 |
+
)
|
505 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
506 |
+
else:
|
507 |
+
image_embeds = self.image_encoder(image).image_embeds
|
508 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
509 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
510 |
+
|
511 |
+
return image_embeds, uncond_image_embeds
|
512 |
+
    def prepare_cond_image_latents(self, image, normal, mask, cond_vae, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = self.vae.dtype

        if isinstance(image, list):
            image = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(device=device, dtype=dtype)
        elif isinstance(image, torch.Tensor):
            image = image.to(device=device, dtype=dtype)

        if isinstance(normal, list):
            normal = torch.stack([TF.to_tensor(img) for img in normal], dim=0).to(device=device, dtype=dtype)
        elif isinstance(normal, torch.Tensor):
            normal = normal.to(device=device, dtype=dtype)

        if isinstance(mask, list):
            if isinstance(mask[0], np.ndarray):
                mask = [Image.fromarray((img*255).astype(np.uint8), mode='L') for img in mask]
                mask = [img.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST) for img in mask]
            elif isinstance(mask[0], Image.Image):
                mask = [img.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST) for img in mask]
            mask = torch.stack([TF.to_tensor(img) for img in mask], dim=0).to(device=device, dtype=dtype)
        elif isinstance(mask, torch.Tensor):
            mask = Image.fromarray((mask.cpu().numpy()*255).astype(np.uint8), mode='L')
            mask = mask.resize((image.shape[3]//8, image.shape[2]//8), resample=Image.NEAREST)
            mask = TF.to_tensor(mask).to(device=device, dtype=dtype)

        if cond_vae is not None:
            image = image * 2.0 - 1.0
            if normal is not None:
                normal = normal * 2.0 - 1.0
                image = torch.cat([image, normal], dim=1)
            latents = cond_vae(image) * self.vae.config.scaling_factor
        else:
            # vae encoder
            image = image * 2.0 - 1.0
            latents = self.vae.encode(image).latent_dist.mode() * self.vae.config.scaling_factor
            latents = latents.repeat(num_images_per_prompt, 1, 1, 1)

            if normal is not None:
                normal = normal * 2.0 - 1.0
                normal_latents = self.vae.encode(normal).latent_dist.mode() * self.vae.config.scaling_factor
                normal_latents = normal_latents.repeat(num_images_per_prompt, 1, 1, 1)
                latents = torch.cat([latents, normal_latents], dim=1)

            if mask is not None:
                # mask = torch.ones_like(mask)
                mask = mask * 2.0 - 1.0
                mask_latents = mask.repeat(num_images_per_prompt, 1, 1, 1)
                latents = torch.cat([latents, mask_latents.to(latents)], dim=1)


        if do_classifier_free_guidance:
            # uncond_latens = self.vae.encode(torch.zeros_like(image)).latent_dist.mode() * self.vae.config.scaling_factor
            # uncond_latens.repeat(num_images_per_prompt, 1, 1, 1)
            uncond_latens = torch.zeros_like(latents)
            latents = torch.cat([latents, latents])

        return latents

    def prepare_init_latents(self, init_materials, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = self.vae.dtype

        image = torch.cat([
            init_materials['albedo'][...,:3].permute(0, 3, 1, 2),
            init_materials['roughness_metallic'][...,:3].permute(0, 3, 1, 2),
            init_materials['bump'][...,:3].permute(0, 3, 1, 2),
        ], dim=0).to(device=device, dtype=dtype)

        from einops import rearrange
        # vae encoder
        image = image * 2.0 - 1.0
        latents = self.vae.encode(image).latent_dist.mode() * self.vae.config.scaling_factor
        latents = rearrange(latents, '(s b) c h w -> b (s c) h w', s=3)
        latents = latents.repeat(num_images_per_prompt, 1, 1, 1)

        # if do_classifier_free_guidance:
        #     # uncond_latens = self.vae.encode(torch.zeros_like(image)).latent_dist.mode() * self.vae.config.scaling_factor
        #     # uncond_latens.repeat(num_images_per_prompt, 1, 1, 1)
        #     # uncond_latens = torch.zeros_like(latents)
        #     latents = torch.cat([latents, latents])

        return latents

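    # `prepare_init_latents` encodes the albedo, roughness/metallic and bump renders as one
    # stacked batch and regroups the result with rearrange('(s b) c h w -> b (s c) h w', s=3),
    # i.e. the three material latents are packed along the channel axis. With a standard
    # 4-channel SD VAE (an assumption), B views yield latents of shape (B, 12, H/8, W/8),
    # which matches the `num_channels_latents // 3` split used by `prepare_latents` further down.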
|
595 |
+
def prepare_ip_adapter_image_embeds(
|
596 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
597 |
+
):
|
598 |
+
if ip_adapter_image_embeds is None:
|
599 |
+
if not isinstance(ip_adapter_image, list):
|
600 |
+
ip_adapter_image = [ip_adapter_image]
|
601 |
+
|
602 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
603 |
+
raise ValueError(
|
604 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
605 |
+
)
|
606 |
+
|
607 |
+
image_embeds = []
|
608 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
609 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
610 |
+
):
|
611 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
612 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
613 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
614 |
+
)
|
615 |
+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
616 |
+
single_negative_image_embeds = torch.stack(
|
617 |
+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
618 |
+
)
|
619 |
+
|
620 |
+
if do_classifier_free_guidance:
|
621 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
622 |
+
single_image_embeds = single_image_embeds.to(device)
|
623 |
+
|
624 |
+
image_embeds.append(single_image_embeds)
|
625 |
+
else:
|
626 |
+
repeat_dims = [1]
|
627 |
+
image_embeds = []
|
628 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
629 |
+
if do_classifier_free_guidance:
|
630 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
631 |
+
single_image_embeds = single_image_embeds.repeat(
|
632 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
633 |
+
)
|
634 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
635 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
636 |
+
)
|
637 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
638 |
+
else:
|
639 |
+
single_image_embeds = single_image_embeds.repeat(
|
640 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
641 |
+
)
|
642 |
+
image_embeds.append(single_image_embeds)
|
643 |
+
|
644 |
+
return image_embeds
|
645 |
+
|
646 |
+
def run_safety_checker(self, image, device, dtype):
|
647 |
+
if self.safety_checker is None:
|
648 |
+
has_nsfw_concept = None
|
649 |
+
else:
|
650 |
+
if torch.is_tensor(image):
|
651 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
652 |
+
else:
|
653 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
654 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
655 |
+
image, has_nsfw_concept = self.safety_checker(
|
656 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
657 |
+
)
|
658 |
+
return image, has_nsfw_concept
|
659 |
+
|
660 |
+
def decode_latents(self, latents):
|
661 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
662 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
663 |
+
|
664 |
+
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, copy_noise=False):
        if copy_noise:
            shape = (batch_size, num_channels_latents//3, height // self.vae_scale_factor, width // self.vae_scale_factor)
        else:
            shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            if copy_noise:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
                latents = torch.cat([latents, latents, latents], dim=1)
            else:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
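
    # Note on `copy_noise` (explanatory comment, added for clarity): when enabled,
    # `prepare_latents` draws a single 4-channel Gaussian sample and repeats it three
    # times along the channel axis, so the albedo, roughness/metallic and bump groups
    # all start denoising from identical noise. Illustrative shapes only (names here
    # are hypothetical, not part of this module):
    #
    #   noise = randn_tensor((1, 4, 64, 64), generator=g, device=device, dtype=dtype)
    #   latents = torch.cat([noise, noise, noise], dim=1)  # -> (1, 12, 64, 64)
    #   assert torch.equal(latents[:, 0:4], latents[:, 4:8])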

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.FloatTensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    def _get_add_time_ids(
        self, albedo_label, rough_meta_label, bump_label, dtype
    ):
        add_time_ids = list(albedo_label + rough_meta_label + bump_label)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) // 3
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt
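
    # Note (added for clarity): `_get_add_time_ids` concatenates the three one-hot
    # material labels passed by `__call__` — albedo (1, 0, 0), roughness/metallic
    # (0, 1, 0) and bump (0, 0, 1) — into a single 9-element `time_ids` row, so the
    # SDXL-style additional embedding tells the UNet which material each 4-channel
    # latent group encodes. Illustrative only:
    #
    #   self._get_add_time_ids((1, 0, 0), (0, 1, 0), (0, 0, 1), dtype=torch.float32)
    #   # -> tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1.]])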

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        cond_image: Optional[PipelineImageInput] = None,
        normal_image: Optional[PipelineImageInput] = None,
        mask_image: Optional[PipelineImageInput] = None,
        init_materials: Optional[dict] = None,
        masks: Optional[torch.FloatTensor] = None,
        cond_vae = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        unscale_latents: bool = False,
        copy_noise: bool = False,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
                if `do_classifier_free_guidance` is set to `True`.
                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
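
        # Illustrative call sketch (comment only; the conditioning tensors and their
        # names below are assumptions for illustration, not part of this file):
        #
        #   out = pipe(
        #       prompt="weathered bronze statue",
        #       cond_image=albedo_views,      # rendered RGB condition views
        #       normal_image=normal_views,    # camera-space normal maps
        #       masks=view_masks,             # per-view foreground masks
        #       num_inference_steps=50,
        #       guidance_scale=7.5,
        #   ).images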

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0] // 3

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 4.1 Prepare additional class embedding
        if self.unet.config.addition_time_embed_dim is not None:
            albedo_label = (1, 0, 0)
            rough_meta_label = (0, 1, 0)
            nump_label = (0, 0, 1)
            add_time_ids = self._get_add_time_ids(
                albedo_label,
                rough_meta_label,
                nump_label,
                dtype=prompt_embeds.dtype,
            )
            negative_add_time_ids = add_time_ids

            if self.do_classifier_free_guidance:
                add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
            add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels_no_cond
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            copy_noise,
        )

        # 5.1 Prepare conditional image latents
        cond_latents = None
        mask_image = [mask.cpu().numpy() for mask in masks]
        if cond_image is not None:
            cond_latents = self.prepare_cond_image_latents(
                cond_image,
                normal_image,
                mask_image,
                cond_vae,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance
            )

        init_latents = None
        if init_materials is not None:
            init_latents = self.prepare_init_latents(
                init_materials,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance
            )

        import cv2
        import numpy as np
        from PIL import Image
        masks = cv2.erode((masks[0].cpu().numpy()*255).astype(np.uint8), kernel=np.ones((5, 5), np.uint8), iterations=4)
        masks = Image.fromarray(masks.astype(np.uint8)).convert("L")
        masks = masks.resize((height // 8, width // 8), Image.NEAREST)
        masks = TF.to_tensor(masks).to(init_latents.device, init_latents.dtype).unsqueeze(1)
        # masks = torch.zeros_like(masks)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # # 6.1 Add image embeds for IP-Adapter
        # added_cond_kwargs = (
        #     {"image_embeds": image_embeds}
        #     if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
        #     else None
        # )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
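
        # Note (added for clarity): inside the loop below, classifier-free guidance is
        # applied only to the roughness/metallic and bump channels (indices 4:) while
        # the albedo channels (indices :4) keep the text-conditioned prediction as-is;
        # the custom scheduler `step` additionally receives `init_latents` and `masks`
        # so known material regions can be kept fixed during denoising.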
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if cond_latents is not None:
                    latent_model_input = torch.cat([latent_model_input, cond_latents], dim=1)

                # predict the noise residual
                added_cond_kwargs = {}
                if self.unet.config.addition_time_embed_dim is not None:
                    added_cond_kwargs["time_ids"] = add_time_ids
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    # only do cfg for roughness, metallic and bump
                    noise_pred = noise_pred_uncond[:, 4:] + self.guidance_scale * (noise_pred_text[:, 4:] - noise_pred_uncond[:, 4:])
                    noise_pred = torch.cat([noise_pred_text[:, :4], noise_pred], dim=1)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False, init_latents=init_latents, masks=masks)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)
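
        # Note (added for clarity): for the 12-channel layout, the three 4-channel groups
        # (albedo, roughness/metallic, bump) are moved onto the batch dimension before
        # VAE decoding, so the decoded batch holds the three material maps as consecutive
        # entries (e.g. indices 0/1/2 for a single prompt).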
        if not output_type == "latent":
            if num_channels_latents == 12:
                latents = latents / self.vae.config.scaling_factor
                if unscale_latents:
                    latents[:, 4:8] = unscale_latents_rm(latents[:, 4:8])
                    latents[:, 8:] = unscale_latents_bump(latents[:, 8:])
                latents = torch.cat([latents[:, :4], latents[:, 4:8], latents[:, 8:]], dim=0)
                image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
            else:
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
            has_nsfw_concept = None
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
utils/rasterize.py
ADDED
@@ -0,0 +1,166 @@
import nvdiffrast.torch as dr
import torch

from torch import Tensor
from jaxtyping import Float, Integer
from typing import Union, Tuple

class NVDiffRasterizerContext:
    def __init__(self, context_type: str, device: torch.device) -> None:
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        with torch.amp.autocast("cuda", enabled=False):
            verts_homo = torch.cat(
                [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
            )
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        # rasterize one single mesh under a single viewpoint
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)

def texture_map_to_rgb(tex_map, uv_coordinates):
    return dr.texture(tex_map.float(), uv_coordinates)

def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    batch_size = mvp_matrix.shape[0]
    tex_map = tex_map.contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing

    vertex_positions_clip = ctx.vertex_transform(mesh.v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(vertex_positions_clip, mesh.t_pos_idx, (image_height, image_width))
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(mask.float(), rasterized_output, vertex_positions_clip, mesh.t_pos_idx)

    interpolated_texture_coords, _ = ctx.interpolate_one(mesh._v_tex, rasterized_output, mesh._t_tex_idx)
    rgb_foreground = texture_map_to_rgb(tex_map.float(), interpolated_texture_coords)
    rgb_foreground_batched = torch.zeros(batch_size, image_height, image_width, 3).to(rgb_foreground)
    rgb_background_batched = torch.zeros(batch_size, image_height, image_width, 3).to(rgb_foreground)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(rgb_background_batched, rgb_foreground_batched, mask_antialiased)
    final_rgb_aa = ctx.antialias(final_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx)

    return final_rgb_aa, selector


def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    device = mvp_matrix.device
    vertex_positions_clip = ctx.vertex_transform(mesh.v_pos.to(device), mvp_matrix)
    rasterized_output, _ = ctx.rasterize(vertex_positions_clip, mesh.t_pos_idx.to(device), (image_height, image_width))
    interpolated_positions, _ = ctx.interpolate_one(mesh.v_pos.to(device), rasterized_output, mesh.t_pos_idx.to(device))
    interpolated_normals, _ = ctx.interpolate_one(mesh.v_normal.to(device).contiguous(), rasterized_output, mesh.t_pos_idx.to(device))

    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(mask.float(), rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))

    batch_size = mvp_matrix.shape[0]
    rgb_foreground_pos_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)
    rgb_foreground_norm_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)
    rgb_background_batched = torch.zeros(batch_size, image_height, image_width, 3).to(interpolated_positions)

    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    final_pos_rgb = torch.lerp(rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased)
    final_norm_rgb = torch.lerp(rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased)
    final_pos_rgb_aa = ctx.antialias(final_pos_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))
    final_norm_rgb_aa = ctx.antialias(final_norm_rgb, rasterized_output, vertex_positions_clip, mesh.t_pos_idx.to(device))

    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased

def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    device = ctx.device
    # Convert mesh data to torch tensors
    mesh_v = mesh.v_pos.to(device)
    mesh_f = mesh.t_pos_idx.to(device)
    uvs_tensor = mesh._v_tex.to(device)
    indices_tensor = mesh._t_tex_idx.to(device)
    normal_v = mesh.v_normal.to(device).contiguous()

    # Interpolate mesh data
    uv_clip = uvs_tensor[None, ...] * 2.0 - 1.0
    uv_clip_padded = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., :1]), torch.ones_like(uv_clip[..., :1])), dim=-1)
    rasterized_output, _ = ctx.rasterize(uv_clip_padded, indices_tensor.int(), (rasterize_height, rasterize_width))

    # Interpolate positions.
    position_map, _ = ctx.interpolate_one(mesh_v, rasterized_output, mesh_f.int())
    normal_map, _ = ctx.interpolate_one(normal_v, rasterized_output, mesh_f.int())
    rasterization_mask = rasterized_output[..., 3:4] > 0

    return position_map, normal_map, rasterization_mask
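
# Illustrative usage sketch (comment only, not part of the original module; the mesh
# attributes assumed here follow utils/mesh_utils.Mesh):
#
#   ctx = NVDiffRasterizerContext("cuda", torch.device("cuda"))
#   rgb, mask = render_rgb_from_texture_mesh_with_mask(
#       ctx, mesh, tex_map, mvp_matrix, 512, 512,
#       background_color=torch.tensor([0.0, 0.0, 0.0], device="cuda"),
#   )
#   pos_map, normal_map, uv_mask = rasterize_position_and_normal_maps(ctx, mesh, 1024, 1024)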
utils/render_utils.py
ADDED
@@ -0,0 +1,352 @@
import math
from functools import cache
from typing import Dict, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToPILImage

from .rasterize import (NVDiffRasterizerContext,
                        rasterize_position_and_normal_maps,
                        render_geo_from_mesh,
                        render_rgb_from_texture_mesh_with_mask)

CTX = NVDiffRasterizerContext('cuda', 'cuda')

def setup_lights():
    """
    Set three random point lights in the scene.
    """
    raise NotImplementedError("setup_lights function is not implemented yet.")

def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
    """
    Render the RGB color images of the mesh. The background will be transparent.
    :param mesh: The mesh to be rendered. Class: Mesh.
    :param texture: The texture of the mesh, a tensor of shape (H, W, 3).
    :param mvp_matrix: The Model-View-Projection matrix for rendering, a tensor of shape (n_v, 4, 4).
    :param lights: The lights in the scene.
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A concatenated PIL Image.
    """
    if texture.shape[-1] != 3:
        texture = texture.permute(1, 2, 0)
    image_height, image_width = img_size
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
        CTX, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device='cuda'))

    if mvp_matrix.shape[0] == 0:
        return None

    pil_images = []
    for i in range(mvp_matrix.shape[0]):
        rgba_img = torch.cat([rgb_cond[i], mask[i].unsqueeze(-1)], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
        rgba_img = (rgba_img * 255).to(torch.uint8)  # Convert to uint8
        rgba_img = rgba_img.cpu().numpy()  # Convert to numpy array
        pil_images.append(Image.fromarray(rgba_img, mode='RGBA'))

    if not pil_images:
        return None

    total_width = sum(img.width for img in pil_images)
    max_height = max(img.height for img in pil_images)

    concatenated_image = Image.new('RGBA', (total_width, max_height))

    current_x = 0
    for img in pil_images:
        concatenated_image.paste(img, (current_x, 0))
        current_x += img.width

    return concatenated_image

def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Render the geometry information, including position and normal maps, from the views implied by the MVP matrix.
    """
    image_height, image_width = img_size
    position_images, normal_images, mask_images = render_geo_from_mesh(CTX, mesh, mvp_matrix, image_height, image_width)
    return position_images, normal_images, mask_images

def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the geometry information, including position and normal maps, from the UV parameterization.
    """
    map_height, map_width = map_size
    position_images, normal_images, mask = rasterize_position_and_normal_maps(CTX, mesh, map_height, map_width)
    # out_imgs = []
    # if mask.ndim == 4:
    #     mask = mask[0]
    # for img_map in [position_images, normal_images]:
    #     if img_map.ndim == 4:
    #         img_map = img_map[0]
    #     # normalize to [0, 1]
    #     img_map = (img_map - img_map.min()) / (img_map.max() - img_map.min() + 1e-6)

    #     rgba_img = torch.cat([img_map, mask], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
    #     rgba_img = (rgba_img * 255).to(torch.uint8)  # Convert to uint8
    #     rgba_img = rgba_img.cpu().numpy()  # Convert to numpy array
    #     out_imgs.append(Image.fromarray(rgba_img, mode='RGBA'))
    return position_images, normal_images

@cache
def get_pure_texture(uv_size, color=(int("0x55", 16), int("0x55", 16), int("0x55", 16))) -> torch.Tensor:
    """
    Get a pure texture image with the specified color.
    :param uv_size: The size of the UV map (height, width).
    :param color: The color of the texture, default is "0x555555" (light gray).
    :return: A texture image tensor of shape (height, width, 3).
    """
    height, width = uv_size

    color = torch.tensor(color, dtype=torch.float32).view(1, 1, 3) / 255.0
    texture = color.repeat(height, width, 1)

    return texture

def get_c2w(
    azimuth_deg,
    elevation_deg,
    camera_distances,):
    assert len(azimuth_deg) == len(elevation_deg) == len(camera_distances)
    n_views = len(azimuth_deg)
    # camera_distances = torch.full_like(elevation_deg, dis)
    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )
    center = torch.zeros_like(camera_positions)
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0
    return c2w

def camera_strategy_test_4_90deg(
        mesh: Dict,
        num_views: int = 4,
        **kwargs) -> Dict:
    """
    For sup views: Random elevation and azimuth, fixed distance and close fov.
    :param num_views: number of supervision views
    :param kwargs: additional arguments
    """
    # Default camera intrinsics
    default_elevation = 10
    default_camera_lens = 50
    default_camera_sensor_width = 36
    default_fovy = 2 * np.arctan(default_camera_sensor_width / (2 * default_camera_lens))

    bbox_size = mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0]
    distance = default_camera_lens / default_camera_sensor_width * \
        math.sqrt(bbox_size[0] ** 2 + bbox_size[1] ** 2 + bbox_size[2] ** 2)

    all_azimuth_deg = torch.linspace(0, 360.0, num_views + 1)[:num_views] - 90

    all_elevation_deg = torch.full_like(all_azimuth_deg, default_elevation)

    # Get the corresponding azimuth and elevation
    view_idxs = torch.arange(0, num_views)
    azimuth = all_azimuth_deg[view_idxs]
    elevation = all_elevation_deg[view_idxs]
    camera_distances = torch.full_like(elevation, distance)
    c2w = get_c2w(azimuth, elevation, camera_distances)

    if c2w.ndim == 2:
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0

    fovy = torch.full_like(azimuth, default_fovy)

    return {
        'cond_sup_view_idxs': view_idxs,
        'cond_sup_c2w': c2w,
        'cond_sup_w2c': w2c,
        'cond_sup_fovy': fovy,
        # 'cond_sup_azimuth': azimuth,
        # 'cond_sup_elevation': elevation,
    }

def _get_projection_matrix(
    fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
    if isinstance(fovy, float):
        proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
        proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[1, 1] = -1.0 / math.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[2, 2] = -(far + near) / (far - near)
        proj_mtx[2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[3, 2] = -1.0
    else:
        batch_size = fovy.shape[0]
        proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
        proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[:, 1, 1] = -1.0 / torch.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[:, 2, 2] = -(far + near) / (far - near)
        proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[:, 3, 2] = -1.0
    return proj_mtx

def _get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    # calculate w2c from c2w: R' = Rt, t' = -Rt * t
    # mathematically equivalent to (c2w)^-1
    if c2w.ndim == 2:
        assert proj_mtx.ndim == 2
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0
    # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
    mvp_mtx = proj_mtx @ w2c
    return mvp_mtx

def get_mvp_matrix(mesh, num_views=4, width=512, height=512, strategy="strategy_test_4_90deg"):
    """
    Get Model-View-Projection (MVP) matrix for rendering views.
    :param mesh: The mesh object to determine camera positioning.
    :param num_views: Number of views to generate, default is 4.
    :param width: Image width for projection matrix calculation.
    :param height: Image height for projection matrix calculation.
    :param strategy: Camera positioning strategy, default is "strategy_test_4_90deg".
    :return: MVP matrix and world-to-camera transformation matrix.
    """
    if strategy == "strategy_test_4_90deg":
        camera_info = camera_strategy_test_4_90deg(
            mesh=mesh,  # Dummy mesh for camera strategy
            num_views=num_views,
        )
        cond_sup_fovy = camera_info["cond_sup_fovy"]
        cond_sup_c2w = camera_info["cond_sup_c2w"]
        cond_sup_w2c = camera_info["cond_sup_w2c"]
        # cond_sup_azimuth = camera_info["cond_sup_azimuth"]
        # cond_sup_elevation = camera_info["cond_sup_elevation"]
    else:
        raise ValueError(f"Unsupported camera strategy: {strategy}")
    cond_sup_proj_mtx: Float[Tensor, "B 4 4"] = _get_projection_matrix(
        cond_sup_fovy, width / height, 0.1, 1000.0
    )
    mvp_mtx: Float[Tensor, "B 4 4"] = _get_mvp_matrix(cond_sup_c2w, cond_sup_proj_mtx)
    return mvp_mtx, cond_sup_w2c

@torch.cuda.amp.autocast(enabled=False)
def _get_depth_noraml_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda", background_color=(0, 0, 0)):
    """
    Get depth and normal map with mask from position and normal images.
    :param xyz_map: Position images in world coordinate, shape [B, Nv, H, W, 3]. It is the return value of `render_geo_views`.
    :param normal_map: Normal images in world coordinate, shape [B, Nv, H, W, 3]. It is the return value of `render_geo_views`.
    :param mask: Mask for the images, shape [B, Nv, H, W]. It is the return value of `render_geo_views`.
    :param w2c: World to camera transformation matrix, shape [B, Nv, 4, 4].
    :param device: Device to run the computation on, default is "cuda".
    :param background_color: Background color for the depth and normal maps.
    :return: depth_map, normal_map, mask
    """
    w2c = w2c.to(device)

    # Render world coordinate position map and mask
    B, Nv, H, W, C = xyz_map.shape  # B: batch size, Nv: number of views, H/W: height/width, C: channels
    assert Nv == 1
    # Rearrange tensors for batch processing
    xyz_map = rearrange(xyz_map, "B Nv H W C -> (B Nv) (H W) C")
    normal_map = rearrange(normal_map, "B Nv H W C -> (B Nv) (H W) C")
    w2c = rearrange(w2c, "B Nv C1 C2 -> (B Nv) C1 C2")

    # Create homogeneous coordinates and correctly transform to camera coordinate system
    # Points in world coordinate system need to be multiplied by world-to-camera transformation matrix
    B_Nv, N, C = xyz_map.shape
    ones = torch.ones(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_xyz = torch.cat([xyz_map, ones], dim=2)  # [x,y,z,1]
    zeros = torch.zeros(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_normal = torch.cat([normal_map, zeros], dim=2)  # [x,y,z,0]

    camera_coords = torch.bmm(homogeneous_xyz, w2c.transpose(1, 2))
    camera_normals = torch.bmm(homogeneous_normal, w2c.transpose(1, 2))

    depth_map = camera_coords[..., 2:3]  # Z-axis is the depth direction in camera coordinate system
    depth_map = rearrange(depth_map, "(B Nv) (H W) 1 -> B Nv H W", B=B, Nv=Nv, H=H, W=W)
    normal_map = camera_normals[..., :3]  # Keep only x, y, z components
    normal_map = rearrange(normal_map, "(B Nv) (H W) c -> B Nv H W c", B=B, Nv=Nv, H=H, W=W)
    assert depth_map.dtype == torch.float32, f"depth_map must be float32, otherwise there will be artifact in controlnet generated pictures, but got {depth_map.dtype}"

    # Calculate min and max values
    min_depth = depth_map.amin((1, 2, 3), keepdim=True)
    max_depth = depth_map.amax((1, 2, 3), keepdim=True)

    depth_map = (depth_map - min_depth) / (max_depth - min_depth + 1e-6)  # Normalize to [0, 1]

    depth_map = depth_map.repeat(1, 3, 1, 1)  # Repeat 3 times to get RGB depth map
    normal_map = normal_map * 0.5 + 0.5  # Normalize to [0, 1], [B, Nv, H, W, 3]
    normal_map = normal_map[:, 0].permute(0, 3, 1, 2)  # [B, 3, H, W]

    rgb_background_batched = torch.tensor(background_color, dtype=torch.float32, device=device).view(1, 3, 1, 1)
    depth_map = torch.lerp(rgb_background_batched, depth_map, mask)
    normal_map = torch.lerp(rgb_background_batched, normal_map, mask)

    return depth_map, normal_map, mask

def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image, Image.Image]:
    """
    Get the silhouette image based on geometry image.
    :param position_imgs: Position images from different views, shape [Nv, H, W, 3].
    :param normal_imgs: Normal images from different views, shape [Nv, H, W, 3].
    :param mask_imgs: Mask for the images, shape [Nv, H, W]. It is the return value of `render_geo_views`.
    :param w2c: World to camera transformation matrix, shape [Nv, 4, 4].
    :param selected_view: The view selected for generating the image condition.
    :return: silhouettes (including depth and normal, which are in the camera coordinate system).
    """
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]
    position_view = position_imgs[view_id: view_id + 1]
    normal_view = normal_imgs[view_id: view_id + 1]
    mask_view = mask_imgs[view_id: view_id + 1]
    w2c = w2c[view_id: view_id + 1]  # Select the corresponding w2c for the view

    depth_img, normal_img, mask = _get_depth_noraml_map_with_mask(
        position_view.unsqueeze(0),  # Add batch dimension
        normal_view.unsqueeze(0),
        mask_view.unsqueeze(0),
        w2c.unsqueeze(0),
    )

    to_img = ToPILImage()
    return to_img(depth_img.squeeze(0)), to_img(normal_img.squeeze(0)), to_img(mask.squeeze(0))
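
# Illustrative end-to-end sketch (comment only, not part of the original module; assumes
# `mesh` is a utils.mesh_utils.Mesh with tensors already on the GPU):
#
#   mvp_mtx, w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
#   position_imgs, normal_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_mtx.cuda())
#   depth_pil, normal_pil, mask_pil = get_silhouette_image(
#       position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View"
#   )
#   preview = render_views(mesh, get_pure_texture((1024, 1024)).cuda(), mvp_mtx.cuda())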
utils/texture_generation.py
ADDED
@@ -0,0 +1,309 @@
import os
|
2 |
+
import threading
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
import spaces
|
9 |
+
import torch
|
10 |
+
from diffusers.models import AutoencoderKLWan
|
11 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
12 |
+
from einops import rearrange
|
13 |
+
from jaxtyping import Float
|
14 |
+
from peft import LoraConfig
|
15 |
+
from PIL import Image
|
16 |
+
from torch import Tensor
|
17 |
+
|
18 |
+
from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
|
19 |
+
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
|
20 |
+
|
21 |
+
TEX_PIPE = None
|
22 |
+
VAE = None
|
23 |
+
LATENTS_MEAN, LATENTS_STD = None, None
|
24 |
+
TEX_PIPE_LOCK = threading.Lock()
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class Config:
|
28 |
+
video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
29 |
+
seqtex_path: str = "https://huggingface.co/VAST-AI/SeqTex/resolve/main/.gitattributes/edm2_ema_12176_clean.pth"
|
30 |
+
min_noise_level_index: int = 15 # which is same as paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)
|
31 |
+
|
32 |
+
use_causal_mask: bool = False
|
33 |
+
addtional_qk_geometry: bool = False
|
34 |
+
use_normal: bool = True
|
35 |
+
use_position: bool = True
|
36 |
+
randomly_init: bool = True # we load the weights from a corresponding ckpt
|
37 |
+
|
38 |
+
num_views: int = 4
|
39 |
+
uv_num_views: int = 1
|
40 |
+
mv_height: int = 512
|
41 |
+
mv_width: int = 512
|
42 |
+
uv_height: int = 1024
|
43 |
+
uv_width: int = 1024
|
44 |
+
|
45 |
+
flow_shift: float = 5.0
|
46 |
+
eval_guidance_scale: float = 1.0
|
47 |
+
eval_num_inference_steps: int = 30
|
48 |
+
eval_seed: int = 42
|
49 |
+
|
50 |
+
lora_rank: int = 128
|
51 |
+
lora_alpha: int = 64
|
52 |
+
|
53 |
+
cfg = Config()
|
54 |
+
|
55 |
+
def load_model_weights(model_path: str, map_location="cpu"):
|
56 |
+
"""
|
57 |
+
Load model weights from either a URL or local file path.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
model_path (str): Path to model weights, can be URL or local file path
|
61 |
+
map_location (str): Device to map the model to
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dict: Loaded state dictionary
|
65 |
+
"""
|
66 |
+
# Check if the path is a URL
|
67 |
+
parsed_url = urlparse(model_path)
|
68 |
+
if parsed_url.scheme in ('http', 'https'):
|
69 |
+
# Load from URL using torch.hub
|
70 |
+
try:
|
71 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
72 |
+
model_path,
|
73 |
+
map_location=map_location,
|
74 |
+
progress=True
|
75 |
+
)
|
76 |
+
return state_dict
|
77 |
+
except Exception as e:
|
78 |
+
gr.Warning(f"Failed to load from URL: {e}")
|
79 |
+
raise e
|
80 |
+
else:
|
81 |
+
# Load from local file path
|
82 |
+
if not os.path.exists(model_path):
|
83 |
+
raise FileNotFoundError(f"Local model file not found: {model_path}")
|
84 |
+
return torch.load(model_path, map_location=map_location)
|
85 |
+
|
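A minimal usage sketch for load_model_weights (illustrative only; both checkpoint paths below are hypothetical placeholders, not files shipped with this Space):

# Hypothetical paths, illustrating the two branches above.
remote_sd = load_model_weights("https://example.com/checkpoints/seqtex.pth")  # fetched via torch.hub
local_sd = load_model_weights("./checkpoints/seqtex.pth")                     # read with torch.load
print(len(remote_sd), len(local_sd))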
86 |
+
def lazy_get_seqtex_pipe():
|
87 |
+
"""
|
88 |
+
Lazy load the SeqTex pipeline for texture generation.
|
89 |
+
"""
|
90 |
+
global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
|
91 |
+
if TEX_PIPE is not None:
|
92 |
+
return TEX_PIPE
|
93 |
+
gr.Info("First call: loading the SeqTex pipeline... This may take about 1 minute.")
|
94 |
+
with TEX_PIPE_LOCK:
|
95 |
+
if TEX_PIPE is not None:
|
96 |
+
return TEX_PIPE
|
97 |
+
|
98 |
+
# Pipeline
|
99 |
+
TEX_PIPE = WanT2TexPipeline.from_pretrained(cfg.video_base_name)
|
100 |
+
|
101 |
+
# Models
|
102 |
+
transformer = WanT2TexTransformer3DModel(
|
103 |
+
TEX_PIPE.transformer,
|
104 |
+
use_causal_mask=cfg.use_causal_mask,
|
105 |
+
addtional_qk_geo=cfg.addtional_qk_geometry,
|
106 |
+
use_normal=cfg.use_normal,
|
107 |
+
use_position=cfg.use_position,
|
108 |
+
randomly_init=cfg.randomly_init,
|
109 |
+
)
|
110 |
+
transformer.add_adapter(
|
111 |
+
LoraConfig(
|
112 |
+
r=cfg.lora_rank,
|
113 |
+
lora_alpha=cfg.lora_alpha,
|
114 |
+
init_lora_weights=True,
|
115 |
+
target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0", "attn1.to_out.2",
|
116 |
+
"ffn.net.0.proj", "ffn.net.2"],
|
117 |
+
)
|
118 |
+
)
|
119 |
+
# load transformer
|
120 |
+
state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
|
121 |
+
transformer.load_state_dict(state_dict, strict=True)
|
122 |
+
TEX_PIPE.transformer = transformer
|
123 |
+
|
124 |
+
VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32).to("cuda").requires_grad_(False)
|
125 |
+
TEX_PIPE.vae = VAE
|
126 |
+
|
127 |
+
# Some useful parameters
|
128 |
+
LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
|
129 |
+
1, VAE.config.z_dim, 1, 1, 1
|
130 |
+
).to("cuda", dtype=torch.float32)
|
131 |
+
LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
|
132 |
+
1, VAE.config.z_dim, 1, 1, 1
|
133 |
+
).to("cuda", dtype=torch.float32)
|
134 |
+
|
135 |
+
scheduler: FlowMatchEulerDiscreteScheduler = (
|
136 |
+
FlowMatchEulerDiscreteScheduler.from_config(
|
137 |
+
TEX_PIPE.scheduler.config, shift=cfg.flow_shift
|
138 |
+
)
|
139 |
+
)
|
140 |
+
min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index  # in this scheduler the first timestep is pure noise, so this is typically 1000 - 15
|
141 |
+
setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
|
142 |
+
min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
|
143 |
+
setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
|
144 |
+
setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
|
145 |
+
|
146 |
+
TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32) # use float32 for inference
|
147 |
+
return TEX_PIPE
|
148 |
+
|
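lazy_get_seqtex_pipe uses double-checked locking: the fast path returns the cached pipeline without taking the lock, and the None check is repeated inside the lock so concurrent first requests build the pipeline only once. A stripped-down sketch of the same pattern with generic names (not code from this repo):

import threading

_INSTANCE = None
_LOCK = threading.Lock()

def lazy_get(factory):
    """Return a process-wide singleton built by `factory`, creating it at most once."""
    global _INSTANCE
    if _INSTANCE is not None:          # fast path, no lock taken
        return _INSTANCE
    with _LOCK:
        if _INSTANCE is None:          # re-check inside the lock
            _INSTANCE = factory()
    return _INSTANCE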
149 |
+
@torch.amp.autocast('cuda', dtype=torch.float32)
|
150 |
+
def encode_images(
|
151 |
+
images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
|
152 |
+
) -> Float[Tensor, "B C' F H/8 W/8"]:
|
153 |
+
"""
|
154 |
+
Encode images to latent space using VAE.
|
155 |
+
Every frame is seen as a separate image, without any awareness of the temporal dimension.
|
156 |
+
:param images: Input images tensor with shape [B, F, H, W, C].
|
157 |
+
:param encode_as_first: Whether to encode all frames as the first frame.
|
158 |
+
:return: Encoded latents with shape [B, C', F, H/8, W/8].
|
159 |
+
"""
|
160 |
+
if images.min() < - 0.1:
|
161 |
+
# images are in [-1, 1] range
|
162 |
+
images = (images + 1.0) / 2.0 # Normalize to [0, 1] range
|
163 |
+
if encode_as_first:
|
164 |
+
# encode every frame as if it were the first frame
|
165 |
+
B = images.shape[0]
|
166 |
+
images = rearrange(images, "B F H W C -> (B F) C 1 H W")
|
167 |
+
latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
|
168 |
+
latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
|
169 |
+
else:
|
170 |
+
raise NotImplementedError("Currently only support encode as first frame.")
|
171 |
+
|
172 |
+
return latents
|
173 |
+
|
174 |
+
# @torch.no_grad()
|
175 |
+
# @torch.amp.autocast('cuda', dtype=torch.float32)
|
176 |
+
# def decode_images(self, latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
|
177 |
+
# if decode_as_first:
|
178 |
+
# F = latents.shape[2]
|
179 |
+
# latents = latents.to(self.vae.dtype)
|
180 |
+
# latents = latents / self.latents_std + self.latents_mean
|
181 |
+
# latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
|
182 |
+
# images = self.vae.decode(latents, return_dict=False)[0]
|
183 |
+
# images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
|
184 |
+
# else:
|
185 |
+
# raise NotImplementedError("Currently only support decode as first frame.")
|
186 |
+
# return images
|
187 |
+
@torch.amp.autocast('cuda', dtype=torch.float32)
|
188 |
+
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
|
189 |
+
"""
|
190 |
+
Decode latents back to images using VAE.
|
191 |
+
:param latents: Input latents with shape [B, C, F, H, W].
|
192 |
+
:param decode_as_first: Whether to decode all frames as the first frame.
|
193 |
+
:return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
|
194 |
+
"""
|
195 |
+
if decode_as_first:
|
196 |
+
F = latents.shape[2]
|
197 |
+
latents = latents.to(VAE.dtype)
|
198 |
+
latents = latents / LATENTS_STD + LATENTS_MEAN
|
199 |
+
latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
|
200 |
+
images = VAE.decode(latents, return_dict=False)[0]
|
201 |
+
images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
|
202 |
+
else:
|
203 |
+
raise NotImplementedError("Currently only support decode as first frame.")
|
204 |
+
return images
|
205 |
+
|
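A rough shape sketch for the encode/decode pair above, assuming lazy_get_seqtex_pipe() has already populated VAE, LATENTS_MEAN and LATENTS_STD, and assuming the Wan 2.1 VAE's 8x spatial downsampling with a 16-channel latent space (hedged assumptions about the base model):

# Sketch only; the shapes follow the rearranges in encode_images/decode_images.
views = torch.rand(1, 4, 512, 512, 3, device="cuda")    # B=1, F=4 views, RGB in [0, 1]
latents = encode_images(views, encode_as_first=True)    # expected: [1, 16, 4, 64, 64]
frames = decode_images(latents, decode_as_first=True)   # expected: [1, 3, 4, 512, 512]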
206 |
+
def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
|
207 |
+
"""
|
208 |
+
Convert a PIL Image to a tensor. If the image has an alpha channel, composite it onto a black background using that alpha mask.
|
209 |
+
:param image: PIL Image to convert. [0, 255]
|
210 |
+
:return: Tensor representation of the image. [0.0, 1.0], still [H, W, C]
|
211 |
+
"""
|
212 |
+
# Convert to RGBA to ensure alpha channel exists
|
213 |
+
image = image.convert("RGBA")
|
214 |
+
np_img = np.array(image)
|
215 |
+
rgb = np_img[..., :3]
|
216 |
+
alpha = np_img[..., 3:4] / 255.0 # Normalize alpha to [0, 1]
|
217 |
+
# Blend with black background using alpha mask
|
218 |
+
rgb = rgb * alpha
|
219 |
+
rgb = rgb.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
220 |
+
tensor = torch.from_numpy(rgb).to(device)
|
221 |
+
return tensor
|
222 |
+
|
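A small usage example for convert_img_to_tensor (synthetic input, illustrative only): an RGBA image is composited onto black via its alpha channel and returned as an [H, W, C] float tensor in [0, 1].

rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))   # half-transparent red square
t = convert_img_to_tensor(rgba, device="cpu")
print(t.shape, t.max().item())                          # torch.Size([64, 64, 3]), max ~0.5 after alpha blending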
223 |
+
@spaces.GPU(duration=120)
|
224 |
+
@torch.cuda.amp.autocast(dtype=torch.float32)
|
225 |
+
@torch.inference_mode
|
226 |
+
@torch.no_grad
|
227 |
+
def generate_texture(position_map, normal_map, position_images, normal_images, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
|
228 |
+
"""
|
229 |
+
Use SeqTex to generate texture for the mesh based on the image condition.
|
230 |
+
:param position_map: UV-space position map of the mesh.
:param normal_map: UV-space normal map of the mesh.
:param position_images: Position images rendered from the different views.
|
231 |
+
:param normal_images: Normal images rendered from the different views.
|
232 |
+
:param condition_image: Image condition generated from the selected view.
|
233 |
+
:param text_prompt: Text prompt for texture generation.
|
234 |
+
:param selected_view: The view selected for generating the image condition.
|
235 |
+
:return: Generated texture map, and multi-view frames in tensor.
|
236 |
+
"""
|
237 |
+
progress(0, desc="Loading SeqTex pipeline...")
|
238 |
+
tex_pipe = lazy_get_seqtex_pipe()
|
239 |
+
progress(0.2, desc="SeqTex pipeline loaded successfully.")
|
240 |
+
view_id_map = {
|
241 |
+
"First View": 0,
|
242 |
+
"Second View": 1,
|
243 |
+
"Third View": 2,
|
244 |
+
"Fourth View": 3
|
245 |
+
}
|
246 |
+
view_id = view_id_map[selected_view]
|
247 |
+
|
248 |
+
progress(0.3, desc="Encoding position and normal images...")
|
249 |
+
nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0) # 1 F H W C
|
250 |
+
uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
|
251 |
+
nat_latents = encode_images(nat_seq, encode_as_first=True) # B C F H W
|
252 |
+
uv_latents = encode_images(uv_seq, encode_as_first=True) # B C F' H' W'
|
253 |
+
nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
|
254 |
+
uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
|
255 |
+
nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
|
256 |
+
uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
|
257 |
+
cond_model_latents = (nat_geo_latents, uv_geo_latents)
|
258 |
+
|
259 |
+
num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
|
260 |
+
uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))
|
261 |
+
|
262 |
+
progress(0.4, desc="Encoding condition image...")
|
263 |
+
if isinstance(condition_image, Image.Image):
|
264 |
+
condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
|
265 |
+
# Convert PIL Image to tensor
|
266 |
+
condition_image = convert_img_to_tensor(condition_image, device=device)
|
267 |
+
condition_image = condition_image.unsqueeze(0).unsqueeze(0)
|
268 |
+
gt_latents = (encode_images(condition_image, encode_as_first=True), None)
|
269 |
+
|
270 |
+
progress(0.5, desc="Generating texture with SeqTex...")
|
271 |
+
latents = tex_pipe(
|
272 |
+
prompt=text_prompt,
|
273 |
+
negative_prompt=negative_prompt,
|
274 |
+
num_frames=num_frames,
|
275 |
+
generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
|
276 |
+
num_inference_steps=cfg.eval_num_inference_steps,
|
277 |
+
guidance_scale=cfg.eval_guidance_scale,
|
278 |
+
height=cfg.mv_height,
|
279 |
+
width=cfg.mv_width,
|
280 |
+
output_type="latent",
|
281 |
+
|
282 |
+
cond_model_latents=cond_model_latents,
|
283 |
+
# mask_indices=test_mask_indices,
|
284 |
+
uv_height=cfg.uv_height,
|
285 |
+
uv_width=cfg.uv_width,
|
286 |
+
uv_num_frames=uv_num_frames,
|
287 |
+
treat_as_first=True,
|
288 |
+
gt_condition=gt_latents,
|
289 |
+
inference_img_cond_frame=view_id,
|
290 |
+
use_qk_geometry=True,
|
291 |
+
task_type="img2tex", # img2tex
|
292 |
+
progress=progress,
|
293 |
+
).frames
|
294 |
+
|
295 |
+
mv_latents, uv_latents = latents
|
296 |
+
|
297 |
+
progress(0.9, desc="Decoding generated latents to images...")
|
298 |
+
mv_frames = decode_images(mv_latents, decode_as_first=True) # B C 4 H W
|
299 |
+
uv_frames = decode_images(uv_latents, decode_as_first=True) # B C 1 H W
|
300 |
+
|
301 |
+
uv_map_pred = uv_frames[:, :, -1, ...]
|
302 |
+
uv_map_pred.squeeze_(0)
|
303 |
+
mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]
|
304 |
+
|
305 |
+
mv_out = torch.clamp(mv_out, 0.0, 1.0)
|
306 |
+
uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
|
307 |
+
|
308 |
+
progress(1, desc="Texture generated successfully.")
|
309 |
+
return uv_map_pred.float(), mv_out.float(), "Step 3: Texture generated successfully."
|
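A shape-level sketch of calling generate_texture directly with dummy geometry. In the Space these tensors come from the mesh rendering utilities; the shapes, value ranges and prompt below are assumptions for illustration only and are not verified against utils/render_utils.py.

# Illustrative stand-ins for rendered geometry; real inputs come from the render utilities.
pos_map    = torch.rand(1, 1024, 1024, 3, device="cuda")   # UV-space position map
norm_map   = torch.rand(1, 1024, 1024, 3, device="cuda")   # UV-space normal map
pos_views  = torch.rand(4, 512, 512, 3, device="cuda")     # position renders of the 4 views
norm_views = torch.rand(4, 512, 512, 3, device="cuda")     # normal renders of the 4 views
cond       = Image.new("RGB", (512, 512), (128, 128, 128)) # image condition for the chosen view
uv_map, mv_frames, status = generate_texture(
    pos_map, norm_map, pos_views, norm_views, cond,
    text_prompt="A rustic birdhouse with wood textures.",
    selected_view="First View",
)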
wan/__init__.py
ADDED
File without changes
|
wan/pipeline_wan_t2tex_extra.py
ADDED
@@ -0,0 +1,366 @@
1 |
+
import copy
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
3 |
+
|
4 |
+
from einops import rearrange
|
5 |
+
import regex as re
|
6 |
+
import torch
|
7 |
+
from diffusers.pipelines.wan.pipeline_wan import WanPipeline
|
8 |
+
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
9 |
+
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
|
10 |
+
from diffusers.utils.torch_utils import randn_tensor
|
11 |
+
from torch import Tensor
|
12 |
+
from transformers import AutoTokenizer, UMT5EncoderModel
|
13 |
+
from jaxtyping import Float
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
|
17 |
+
sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
|
18 |
+
schedule_timesteps = scheduler.timesteps.to(device)
|
19 |
+
timesteps = timesteps.to(device)
|
20 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
21 |
+
|
22 |
+
sigma = sigmas[step_indices].flatten()
|
23 |
+
return sigma
|
24 |
+
|
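get_sigmas is a lookup: for each per-frame timestep it returns the sigma the scheduler has assigned to that timestep. A small self-contained sketch (the scheduler settings here are illustrative, not the pipeline's configured ones):

from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

sched = FlowMatchEulerDiscreteScheduler(shift=5.0)
sched.set_timesteps(30, device="cpu")
t = sched.timesteps[:3]                       # three of the 30 discrete timesteps
sig = get_sigmas(sched, t, device="cpu")      # the matching entries of sched.sigmas
print(t.tolist(), sig.tolist())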
25 |
+
class WanT2TexPipeline(WanPipeline):
|
26 |
+
def __init__(self, tokenizer, text_encoder, transformer, vae, scheduler):
|
27 |
+
super().__init__(tokenizer, text_encoder, transformer, vae, scheduler)
|
28 |
+
self.uv_scheduler = copy.deepcopy(scheduler)
|
29 |
+
|
30 |
+
def prepare_latents(
|
31 |
+
self,
|
32 |
+
batch_size: int,
|
33 |
+
num_channels_latents: int = 16,
|
34 |
+
height: int = 480,
|
35 |
+
width: int = 832,
|
36 |
+
num_frames: int = 81,
|
37 |
+
dtype: Optional[torch.dtype] = None,
|
38 |
+
device: Optional[torch.device] = None,
|
39 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
40 |
+
latents: Optional[torch.Tensor] = None,
|
41 |
+
treat_as_first: Optional[bool] = True,
|
42 |
+
) -> torch.Tensor:
|
43 |
+
if latents is not None:
|
44 |
+
return latents.to(device=device, dtype=dtype)
|
45 |
+
|
46 |
+
####################
|
47 |
+
if treat_as_first:
|
48 |
+
num_latent_frames = num_frames // self.vae_scale_factor_temporal
|
49 |
+
else:
|
50 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
51 |
+
####################
|
52 |
+
|
53 |
+
shape = (
|
54 |
+
batch_size,
|
55 |
+
num_channels_latents,
|
56 |
+
num_latent_frames,
|
57 |
+
int(height) // self.vae_scale_factor_spatial,
|
58 |
+
int(width) // self.vae_scale_factor_spatial,
|
59 |
+
)
|
60 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
61 |
+
raise ValueError(
|
62 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
63 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
64 |
+
)
|
65 |
+
|
66 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
67 |
+
return latents
|
68 |
+
|
69 |
+
@torch.no_grad()
|
70 |
+
def __call__(
|
71 |
+
self,
|
72 |
+
prompt: Union[str, List[str]] = None,
|
73 |
+
negative_prompt: Union[str, List[str]] = None,
|
74 |
+
height: int = 480,
|
75 |
+
width: int = 832,
|
76 |
+
num_frames: int = 81,
|
77 |
+
num_inference_steps: int = 50,
|
78 |
+
guidance_scale: float = 5.0,
|
79 |
+
num_videos_per_prompt: Optional[int] = 1,
|
80 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
81 |
+
latents: Optional[torch.Tensor] = None,
|
82 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
83 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
84 |
+
output_type: Optional[str] = "np",
|
85 |
+
return_dict: bool = True,
|
86 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
87 |
+
callback_on_step_end: Optional[
|
88 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
89 |
+
] = None,
|
90 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
91 |
+
max_sequence_length: int = 512,
|
92 |
+
cond_model_latents: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
93 |
+
uv_height=None,
|
94 |
+
uv_width=None,
|
95 |
+
uv_num_frames=None,
|
96 |
+
# multi_task_cond=None,
|
97 |
+
treat_as_first=True,
|
98 |
+
gt_condition:Tuple[Optional[Float[Tensor, "B C F H W"]], Optional[Float[Tensor, "B C F H W"]]]=None,
|
99 |
+
inference_img_cond_frame=None,
|
100 |
+
use_qk_geometry=False,
|
101 |
+
task_type="all",
|
102 |
+
progress=gr.Progress()
|
103 |
+
):
|
104 |
+
r"""
|
105 |
+
The call function to the pipeline for generation.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
prompt (`str` or `List[str]`, *optional*):
|
109 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
|
110 |
+
|
111 |
+
height (`int`, defaults to `480`):
|
112 |
+
The height in pixels of the generated image.
|
113 |
+
width (`int`, defaults to `832`):
|
114 |
+
The width in pixels of the generated image.
|
115 |
+
num_frames (`int`, defaults to `81`):
|
116 |
+
The number of frames in the generated video.
|
117 |
+
num_inference_steps (`int`, defaults to `50`):
|
118 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
119 |
+
expense of slower inference.
|
120 |
+
guidance_scale (`float`, defaults to `5.0`):
|
121 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
122 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
123 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
124 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
125 |
+
usually at the expense of lower image quality.
|
126 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
127 |
+
The number of images to generate per prompt.
|
128 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
129 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
130 |
+
generation deterministic.
|
131 |
+
latents (`torch.Tensor`, *optional*):
|
132 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
133 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
134 |
+
tensor is generated by sampling using the supplied random `generator`.
|
135 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
136 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
137 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
138 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
139 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
140 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
141 |
+
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
142 |
+
attention_kwargs (`dict`, *optional*):
|
143 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
144 |
+
`self.processor` in
|
145 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
146 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
147 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
148 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
149 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
150 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
151 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
152 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
153 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
154 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
155 |
+
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
156 |
+
The dtype to use for the torch.amp.autocast.
|
157 |
+
|
158 |
+
Examples:
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
[`~WanPipelineOutput`] or `tuple`:
|
162 |
+
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
163 |
+
the first element is a list with the generated images and the second element is a list of `bool`s
|
164 |
+
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
165 |
+
"""
|
166 |
+
|
167 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
168 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
169 |
+
|
170 |
+
# 1. Check inputs. Raise error if not correct
|
171 |
+
self.check_inputs(
|
172 |
+
prompt,
|
173 |
+
negative_prompt,
|
174 |
+
height,
|
175 |
+
width,
|
176 |
+
prompt_embeds,
|
177 |
+
negative_prompt_embeds,
|
178 |
+
callback_on_step_end_tensor_inputs,
|
179 |
+
)
|
180 |
+
|
181 |
+
# ATTENTION: the inputs here are images, so num_frames is 5 and there is no temporal-dimension compression.
|
182 |
+
# if num_frames % self.vae_scale_factor_temporal != 1:
|
183 |
+
# raise ValueError(
|
184 |
+
# f"num_frames should be divisible by {self.vae_scale_factor_temporal} + 1, but got {num_frames}."
|
185 |
+
# )
|
186 |
+
# num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
187 |
+
# num_frames = max(num_frames, 1)
|
188 |
+
|
189 |
+
self._guidance_scale = guidance_scale
|
190 |
+
self._attention_kwargs = attention_kwargs
|
191 |
+
self._current_timestep = None
|
192 |
+
self._interrupt = False
|
193 |
+
|
194 |
+
device = self._execution_device
|
195 |
+
|
196 |
+
# 2. Define call parameters
|
197 |
+
if prompt is not None and isinstance(prompt, str):
|
198 |
+
batch_size = 1
|
199 |
+
elif prompt is not None and isinstance(prompt, list):
|
200 |
+
batch_size = len(prompt)
|
201 |
+
else:
|
202 |
+
batch_size = prompt_embeds.shape[0]
|
203 |
+
|
204 |
+
# 3. Encode input prompt
|
205 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
206 |
+
prompt=prompt,
|
207 |
+
negative_prompt=negative_prompt,
|
208 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
209 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
210 |
+
prompt_embeds=prompt_embeds,
|
211 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
212 |
+
max_sequence_length=max_sequence_length,
|
213 |
+
device=device,
|
214 |
+
)
|
215 |
+
|
216 |
+
transformer_dtype = self.transformer.dtype
|
217 |
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
218 |
+
if self.do_classifier_free_guidance:
|
219 |
+
if negative_prompt_embeds is not None:
|
220 |
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
221 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
222 |
+
|
223 |
+
# 4. Prepare timesteps
|
224 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
225 |
+
self.uv_scheduler.set_timesteps(num_inference_steps, device=device)
|
226 |
+
timesteps = self.scheduler.timesteps
|
227 |
+
|
228 |
+
# 5. Prepare latent variables
|
229 |
+
num_channels_latents = self.transformer.config.in_channels
|
230 |
+
mv_latents = self.prepare_latents(
|
231 |
+
batch_size * num_videos_per_prompt,
|
232 |
+
num_channels_latents,
|
233 |
+
height,
|
234 |
+
width,
|
235 |
+
num_frames,
|
236 |
+
torch.float32,
|
237 |
+
device,
|
238 |
+
generator,
|
239 |
+
treat_as_first=treat_as_first,
|
240 |
+
)
|
241 |
+
uv_latents = self.prepare_latents(
|
242 |
+
batch_size * num_videos_per_prompt,
|
243 |
+
num_channels_latents,
|
244 |
+
uv_height,
|
245 |
+
uv_width,
|
246 |
+
uv_num_frames,
|
247 |
+
torch.float32,
|
248 |
+
device,
|
249 |
+
generator,
|
250 |
+
treat_as_first=True # UV latents are always different from the others, so treat as the first frame
|
251 |
+
)
|
252 |
+
|
253 |
+
# 6. Denoising loop
|
254 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
255 |
+
self._num_timesteps = len(timesteps)
|
256 |
+
|
257 |
+
# with progress.tqdm(total=num_inference_steps, desc="Diffusing...") as progress_bar:
|
258 |
+
for i, t in progress.tqdm(enumerate(timesteps), desc="Diffusing..."):
|
259 |
+
if self.interrupt:
|
260 |
+
continue
|
261 |
+
|
262 |
+
# set conditions
|
263 |
+
timestep_df = torch.ones((batch_size, num_frames // self.vae_scale_factor_temporal + 1)).to(device) * t
|
264 |
+
sigmas = get_sigmas(self.scheduler, rearrange(timestep_df, "B F -> (B F)"), dtype=transformer_dtype, device=device)
|
265 |
+
sigmas = rearrange(sigmas, "(B F) -> B 1 F 1 1", B=batch_size)
|
266 |
+
match task_type:
|
267 |
+
case "geo+mv2tex":
|
268 |
+
timestep_df[:, :num_frames // self.vae_scale_factor_temporal] = self.min_noise_level_timestep
|
269 |
+
sigmas[:, :, :num_frames // self.vae_scale_factor_temporal, ...] = self.min_noise_level_sigma
|
270 |
+
mv_noise = torch.randn_like(mv_latents) # B C 4 H W
|
271 |
+
mv_latents = (1.0 - sigmas[:, :, :-1, ...]) * gt_condition[0] + sigmas[:, :, :-1, ...] * mv_noise
|
272 |
+
case "img2tex":
|
273 |
+
assert inference_img_cond_frame is not None, "inference_img_cond_frame should be specified for img2tex task"
|
274 |
+
# Use specified frame index as condition instead of just first frame
|
275 |
+
timestep_df[:, inference_img_cond_frame: inference_img_cond_frame + 1] = self.min_noise_level_timestep
|
276 |
+
sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] = self.min_noise_level_sigma
|
277 |
+
mv_noise = randn_tensor(mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1].shape, generator=generator, device=device, dtype=self.dtype)
|
278 |
+
# mv_noise = torch.randn_like(mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1], generator=generator) # B C selected_frames H W
|
279 |
+
mv_latents[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] = (1.0 - sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...]) * gt_condition[0] + sigmas[:, :, inference_img_cond_frame: inference_img_cond_frame + 1, ...] * mv_noise
|
280 |
+
case "soft_render":
|
281 |
+
timestep_df[:, -1:] = self.min_noise_level_timestep
|
282 |
+
sigmas[:, :, -1:, ...] = self.min_noise_level_sigma
|
283 |
+
uv_noise = torch.randn_like(uv_latents) # B C 1 H W
|
284 |
+
uv_latents = (1.0 - sigmas[:, :, -1:, ...]) * gt_condition[1] + sigmas[:, :, -1:, ...] * uv_noise
|
285 |
+
case "geo2mv":
|
286 |
+
timestep_df[:, -1:] = 1000.
|
287 |
+
sigmas[:, :, -1:, ...] = 1.
|
288 |
+
case _:
|
289 |
+
pass
|
290 |
+
|
291 |
+
# add geometry information to channel C
|
292 |
+
mv_latents_input = torch.cat([mv_latents, cond_model_latents[0]], dim=1)
|
293 |
+
uv_latents_input = torch.cat([uv_latents, cond_model_latents[1]], dim=1)
|
294 |
+
if self.do_classifier_free_guidance:
|
295 |
+
mv_latents_input = torch.cat([mv_latents_input, mv_latents_input], dim=0)
|
296 |
+
uv_latents_input = torch.cat([uv_latents_input, uv_latents_input], dim=0)
|
297 |
+
|
298 |
+
self._current_timestep = t
|
299 |
+
latent_model_input = (mv_latents_input.to(transformer_dtype), uv_latents_input.to(transformer_dtype))
|
300 |
+
# timestep = t.expand(mv_latents.shape[0])
|
301 |
+
|
302 |
+
noise_out = self.transformer(
|
303 |
+
hidden_states=latent_model_input,
|
304 |
+
timestep=timestep_df,
|
305 |
+
encoder_hidden_states=prompt_embeds,
|
306 |
+
attention_kwargs=attention_kwargs,
|
307 |
+
# task_cond=multi_task_cond,
|
308 |
+
return_dict=False,
|
309 |
+
use_qk_geometry=use_qk_geometry
|
310 |
+
)[0]
|
311 |
+
mv_noise_out, uv_noise_out = noise_out
|
312 |
+
|
313 |
+
if self.do_classifier_free_guidance:
|
314 |
+
mv_noise_uncond, mv_noise_pred = mv_noise_out.chunk(2)
|
315 |
+
uv_noise_uncond, uv_noise_pred = uv_noise_out.chunk(2)
|
316 |
+
mv_noise_pred = mv_noise_uncond + guidance_scale * (mv_noise_pred - mv_noise_uncond)
|
317 |
+
uv_noise_pred = uv_noise_uncond + guidance_scale * (uv_noise_pred - uv_noise_uncond)
|
318 |
+
else:
|
319 |
+
mv_noise_pred = mv_noise_out
|
320 |
+
uv_noise_pred = uv_noise_out
|
321 |
+
|
322 |
+
# compute the previous noisy sample x_t -> x_t-1
|
323 |
+
# The conditions will be replaced anyway, so we may not need to step the frames separately
|
324 |
+
mv_latents = self.scheduler.step(mv_noise_pred, t, mv_latents, return_dict=False)[0]
|
325 |
+
uv_latents = self.uv_scheduler.step(uv_noise_pred, t, uv_latents, return_dict=False)[0]
|
326 |
+
|
327 |
+
if callback_on_step_end is not None:
|
328 |
+
raise NotImplementedError()
|
329 |
+
callback_kwargs = {}
|
330 |
+
for k in callback_on_step_end_tensor_inputs:
|
331 |
+
callback_kwargs[k] = locals()[k]
|
332 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
333 |
+
|
334 |
+
latents = callback_outputs.pop("latents", latents)
|
335 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
336 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
337 |
+
|
338 |
+
# # call the callback, if provided
|
339 |
+
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
340 |
+
# progress_bar.update()
|
341 |
+
|
342 |
+
self._current_timestep = None
|
343 |
+
|
344 |
+
if not output_type == "latent":
|
345 |
+
latents = latents.to(self.vae.dtype)
|
346 |
+
latents_mean = (
|
347 |
+
torch.tensor(self.vae.config.latents_mean)
|
348 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
349 |
+
.to(latents.device, latents.dtype)
|
350 |
+
)
|
351 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
352 |
+
latents.device, latents.dtype
|
353 |
+
)
|
354 |
+
latents = latents / latents_std + latents_mean
|
355 |
+
video = self.vae.decode(latents, return_dict=False)[0]
|
356 |
+
# video = self.video_processor.postprocess_video(video, output_type=output_type)
|
357 |
+
else:
|
358 |
+
video = (mv_latents, uv_latents)
|
359 |
+
|
360 |
+
# Offload all models
|
361 |
+
self.maybe_free_model_hooks()
|
362 |
+
|
363 |
+
if not return_dict:
|
364 |
+
return (video,)
|
365 |
+
|
366 |
+
return WanPipelineOutput(frames=video)
|
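The conditioning mechanism in the denoising loop above is diffusion-forcing style per-frame timesteps: all latent frames share the current denoising timestep except the condition slots, which are pinned to the minimum noise level (and re-noised to that level from the ground-truth latents). A stripped-down sketch of the timestep masking itself, with illustrative numbers rather than real pipeline state:

import torch

t, min_noise_t = 700.0, 15.0                 # illustrative values only
timestep_df = torch.full((1, 5), t)          # 1 batch, 4 multi-view frames + 1 UV frame
cond_frame = 2                               # e.g. "Third View" supplies the image condition
timestep_df[:, cond_frame] = min_noise_t     # img2tex: pin only the condition frame
print(timestep_df)                           # [[700., 700., 15., 700., 700.]]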
wan/wan_t2tex_transformer_3d_extra.py
ADDED
@@ -0,0 +1,634 @@
1 |
+
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import math
|
17 |
+
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn
|
18 |
+
from functools import cache
|
19 |
+
|
20 |
+
from einops import rearrange, repeat
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
25 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
26 |
+
from diffusers.models import WanTransformer3DModel
|
27 |
+
from diffusers.models.attention import FeedForward
|
28 |
+
from diffusers.models.attention_processor import Attention
|
29 |
+
from diffusers.models.cache_utils import CacheMixin
|
30 |
+
from diffusers.models.embeddings import (PixArtAlphaTextProjection,
|
31 |
+
TimestepEmbedding, Timesteps,
|
32 |
+
get_1d_rotary_pos_embed)
|
33 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
34 |
+
from diffusers.models.modeling_utils import ModelMixin
|
35 |
+
from diffusers.models.normalization import FP32LayerNorm
|
36 |
+
from diffusers.models.transformers.transformer_wan import \
|
37 |
+
WanTimeTextImageEmbedding
|
38 |
+
from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
|
39 |
+
unscale_lora_layers)
|
40 |
+
|
41 |
+
|
42 |
+
class WanT2TexAttnProcessor2_0:
|
43 |
+
def __init__(self):
|
44 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
45 |
+
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
46 |
+
|
47 |
+
def __call__(
|
48 |
+
self,
|
49 |
+
attn: Attention,
|
50 |
+
hidden_states: torch.Tensor,
|
51 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
52 |
+
attention_mask: Optional[torch.Tensor] = None,
|
53 |
+
rotary_emb: Optional[torch.Tensor] = None,
|
54 |
+
geometry_embedding: Optional[torch.Tensor] = None,
|
55 |
+
) -> torch.Tensor:
|
56 |
+
encoder_hidden_states_img = None
|
57 |
+
if attn.add_k_proj is not None:
|
58 |
+
encoder_hidden_states_img = encoder_hidden_states[:, :257]
|
59 |
+
encoder_hidden_states = encoder_hidden_states[:, 257:]
|
60 |
+
if encoder_hidden_states is None:
|
61 |
+
encoder_hidden_states = hidden_states
|
62 |
+
|
63 |
+
query = attn.to_q(hidden_states)
|
64 |
+
key = attn.to_k(encoder_hidden_states)
|
65 |
+
value = attn.to_v(encoder_hidden_states)
|
66 |
+
|
67 |
+
if attn.norm_q is not None:
|
68 |
+
query = attn.norm_q(query)
|
69 |
+
if attn.norm_k is not None:
|
70 |
+
key = attn.norm_k(key)
|
71 |
+
|
72 |
+
if geometry_embedding is not None:
|
73 |
+
# add-type geometry embedding
|
74 |
+
if True:
|
75 |
+
if isinstance(geometry_embedding, Tuple):
|
76 |
+
query = query + geometry_embedding[0]
|
77 |
+
key = key + geometry_embedding[1]
|
78 |
+
else:
|
79 |
+
query = query + geometry_embedding
|
80 |
+
key = key + geometry_embedding
|
81 |
+
else:
|
82 |
+
# mul-type geometry embedding
|
83 |
+
if isinstance(geometry_embedding, Tuple):
|
84 |
+
query = query * (1 + geometry_embedding[0])
|
85 |
+
key = key * (1 + geometry_embedding[1])
|
86 |
+
else:
|
87 |
+
query = query * (1 + geometry_embedding)
|
88 |
+
key = key * (1 + geometry_embedding)
|
89 |
+
|
90 |
+
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) # [B, F*H*W, 2C] -> [B, H, F*H*W, 2C//H]
|
91 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
92 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
93 |
+
|
94 |
+
if rotary_emb is not None:
|
95 |
+
|
96 |
+
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
|
97 |
+
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
|
98 |
+
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
|
99 |
+
return x_out.type_as(hidden_states)
|
100 |
+
|
101 |
+
if isinstance(rotary_emb, Tuple):
|
102 |
+
query = apply_rotary_emb(query, rotary_emb[0])
|
103 |
+
key = apply_rotary_emb(key, rotary_emb[1])
|
104 |
+
else:
|
105 |
+
query = apply_rotary_emb(query, rotary_emb)
|
106 |
+
key = apply_rotary_emb(key, rotary_emb)
|
107 |
+
|
108 |
+
# I2V task
|
109 |
+
hidden_states_img = None
|
110 |
+
if encoder_hidden_states_img is not None:
|
111 |
+
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
112 |
+
key_img = attn.norm_added_k(key_img)
|
113 |
+
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
114 |
+
|
115 |
+
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
116 |
+
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
117 |
+
|
118 |
+
hidden_states_img = F.scaled_dot_product_attention(
|
119 |
+
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
|
120 |
+
)
|
121 |
+
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
|
122 |
+
hidden_states_img = hidden_states_img.type_as(query)
|
123 |
+
|
124 |
+
hidden_states = F.scaled_dot_product_attention(
|
125 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
126 |
+
)
|
127 |
+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
128 |
+
hidden_states = hidden_states.type_as(query)
|
129 |
+
|
130 |
+
if hidden_states_img is not None:
|
131 |
+
hidden_states = hidden_states + hidden_states_img
|
132 |
+
|
133 |
+
hidden_states = attn.to_out[0](hidden_states)
|
134 |
+
hidden_states = attn.to_out[1](hidden_states)
|
135 |
+
return hidden_states
|
136 |
+
|
137 |
+
|
138 |
+
class WanTimeTaskTextImageEmbedding(WanTimeTextImageEmbedding):
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
original_model,
|
142 |
+
dim: int,
|
143 |
+
time_freq_dim: int,
|
144 |
+
time_proj_dim: int,
|
145 |
+
text_embed_dim: int,
|
146 |
+
image_embed_dim: Optional[int] = None,
|
147 |
+
randomly_init: bool = False,
|
148 |
+
):
|
149 |
+
super(WanTimeTaskTextImageEmbedding, self).__init__(dim, time_freq_dim, time_proj_dim, text_embed_dim, image_embed_dim)
|
150 |
+
if not randomly_init:
|
151 |
+
self.load_state_dict(original_model.state_dict(), strict=True)
|
152 |
+
# cond_proj = nn.Linear(512, original_model.timesteps_proj.num_channels, bias=False)
|
153 |
+
# setattr(self.time_embedder, "cond_proj", cond_proj)
|
154 |
+
|
155 |
+
def forward(
|
156 |
+
self,
|
157 |
+
timestep: torch.Tensor,
|
158 |
+
encoder_hidden_states: torch.Tensor,
|
159 |
+
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
160 |
+
# time_cond: Optional[torch.Tensor] = None,
|
161 |
+
):
|
162 |
+
B = timestep.shape[0]
|
163 |
+
timestep = rearrange(timestep, "B F -> (B F)")
|
164 |
+
timestep = self.timesteps_proj(timestep)
|
165 |
+
timestep = rearrange(timestep, "(B F) D -> B F D", B=B)
|
166 |
+
|
167 |
+
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
168 |
+
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
169 |
+
timestep = timestep.to(time_embedder_dtype)
|
170 |
+
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
171 |
+
timestep_proj = self.time_proj(self.act_fn(temb))
|
172 |
+
|
173 |
+
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
174 |
+
if encoder_hidden_states_image is not None:
|
175 |
+
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
176 |
+
|
177 |
+
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
178 |
+
|
179 |
+
|
180 |
+
class WanRotaryPosEmbed(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0, addtional_qk_geo: bool = False
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
if addtional_qk_geo: # to add PE to geometry embedding
|
187 |
+
attention_head_dim = attention_head_dim * 2
|
188 |
+
self.attention_head_dim = attention_head_dim
|
189 |
+
self.patch_size = patch_size
|
190 |
+
self.max_seq_len = max_seq_len
|
191 |
+
|
192 |
+
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
193 |
+
t_dim = attention_head_dim - h_dim - w_dim
|
194 |
+
|
195 |
+
freqs = []
|
196 |
+
for dim in [t_dim, h_dim, w_dim]:
|
197 |
+
freq = get_1d_rotary_pos_embed(
|
198 |
+
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
|
199 |
+
)
|
200 |
+
freqs.append(freq)
|
201 |
+
self.freqs = torch.cat(freqs, dim=1)
|
202 |
+
|
203 |
+
def forward(self, hidden_states: torch.Tensor, uv_hidden_states: torch.Tensor) -> torch.Tensor:
|
204 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
205 |
+
_, _, uv_num_frames, uv_height, uv_width = uv_hidden_states.shape
|
206 |
+
p_t, p_h, p_w = self.patch_size
|
207 |
+
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
208 |
+
uppf, upph, uppw = uv_num_frames // p_t, uv_height // p_h, uv_width // p_w
|
209 |
+
|
210 |
+
self.freqs = self.freqs.to(hidden_states.device)
|
211 |
+
freqs = self.freqs.split_with_sizes(
|
212 |
+
[
|
213 |
+
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
|
214 |
+
self.attention_head_dim // 6,
|
215 |
+
self.attention_head_dim // 6,
|
216 |
+
],
|
217 |
+
dim=1,
|
218 |
+
)
|
219 |
+
|
220 |
+
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
221 |
+
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
222 |
+
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
223 |
+
|
224 |
+
uv_freqs_f = freqs[0][ppf:ppf+uppf].view(uppf, 1, 1, -1).expand(uppf, upph, uppw, -1)
|
225 |
+
uv_freqs_h = freqs[1][:upph].view(1, upph, 1, -1).expand(uppf, upph, uppw, -1)
|
226 |
+
uv_freqs_w = freqs[2][:uppw].view(1, 1, uppw, -1).expand(uppf, upph, uppw, -1)
|
227 |
+
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
228 |
+
uv_freqs = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(1, 1, uppf * upph * uppw, -1)
|
229 |
+
return torch.cat([freqs, uv_freqs], dim=-2)
|
230 |
+
|
231 |
+
# def pseudo_code(freqs, mv_tokens_shape, uv_tokens_shape, dimmension):
|
232 |
+
# """
|
233 |
+
# Input:
|
234 |
+
# freqs: [S, D/2], S is the number of tokens, D is the dimension of tokens, 2 indicates Cos and Sin in original RoPE.
|
235 |
+
# mv_tokens_shape: (mv_num_frames, mv_height, mv_width)
|
236 |
+
# uv_tokens_shape: (uv_num_frames, uv_height, uv_width)
|
237 |
+
# dimension: the dimension of tokens
|
238 |
+
# Output:
|
239 |
+
# """
|
240 |
+
# mpf, mph, mpw = mv_tokens_shape # mv_num_frames, mv_height, mv_width
|
241 |
+
# upf, uph, upw = uv_tokens_shape # uv_num_frames, uv_height, uv_width
|
242 |
+
|
243 |
+
# # 1. To evenly split the freqs into 3 parts
|
244 |
+
# freqs = freqs.split_with_sizes(
|
245 |
+
# [
|
246 |
+
# dimmension // 2 - 2 * (dimmension // 6),
|
247 |
+
# dimmension // 6,
|
248 |
+
# dimmension // 6,
|
249 |
+
# ],
|
250 |
+
# dim=1,
|
251 |
+
# )
|
252 |
+
|
253 |
+
# # 2. In time dimension, the freqs for UV are subsequent to the freqs for MV
|
254 |
+
# freqs_f = freqs[0][:mpf].view(mpf, 1, 1, -1).expand(mpf, mph, mpw, -1)
|
255 |
+
# uv_freqs_f = freqs[0][mpf:mpf+upf].view(upf, 1, 1, -1).expand(upf, uph, upw, -1)
|
256 |
+
|
257 |
+
# # 3. The freqs in height and width dimension are the same for mv and uv
|
258 |
+
# freqs_h = freqs[1][:mph].view(1, mph, 1, -1).expand(mpf, mph, mpw, -1)
|
259 |
+
# uv_freqs_h = freqs[1][:uph].view(1, uph, 1, -1).expand(upf, uph, upw, -1)
|
260 |
+
# freqs_w = freqs[2][:mpw].view(1, 1, mpw, -1).expand(mpf, mph, mpw, -1)
|
261 |
+
# uv_freqs_w = freqs[2][:upw].view(1, 1, upw, -1).expand(upf, uph, upw, -1)
|
262 |
+
|
263 |
+
# # 4. rearrange three 1D RoPEs into 3D RoPE in channel dimension
|
264 |
+
# mv_rope = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(mpf * mph * mpw, -1)
|
265 |
+
# uv_rope = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(upf * uph * upw, -1)
|
266 |
+
# return torch.cat([mv_rope, uv_rope], dim=-2)
|
267 |
+
|
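As a concrete instance of the split described in the pseudo-code above (assuming attention_head_dim = 128, the head size of the 1.3B base model; the arithmetic below depends only on that dim):

attention_head_dim = 128
h_dim = w_dim = 2 * (attention_head_dim // 6)                        # 42 real channels each for height and width
t_dim = attention_head_dim - h_dim - w_dim                           # 44 real channels for the frame axis
# Complex-valued widths used by split_with_sizes on self.freqs:
t_half  = attention_head_dim // 2 - 2 * (attention_head_dim // 6)    # 22
hw_half = attention_head_dim // 6                                    # 21 each
assert t_half + 2 * hw_half == attention_head_dim // 2               # 64 complex pairs = 128 real channels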
268 |
+
class WanT2TexTransformerBlock(nn.Module):
|
269 |
+
def __init__(
|
270 |
+
self,
|
271 |
+
dim: int,
|
272 |
+
ffn_dim: int,
|
273 |
+
num_heads: int,
|
274 |
+
qk_norm: str = "rms_norm_across_heads",
|
275 |
+
cross_attn_norm: bool = False,
|
276 |
+
eps: float = 1e-6,
|
277 |
+
added_kv_proj_dim: Optional[int] = None,
|
278 |
+
addtional_qk_geo: bool = False,
|
279 |
+
):
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
# 1. Self-attention
|
283 |
+
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
284 |
+
self.attn1 = Attention(
|
285 |
+
query_dim=dim,
|
286 |
+
heads=num_heads,
|
287 |
+
kv_heads=num_heads,
|
288 |
+
dim_head=dim // num_heads,
|
289 |
+
qk_norm=qk_norm,
|
290 |
+
eps=eps,
|
291 |
+
bias=True,
|
292 |
+
cross_attention_dim=None,
|
293 |
+
out_bias=True,
|
294 |
+
processor=WanT2TexAttnProcessor2_0(),
|
295 |
+
)
|
296 |
+
|
297 |
+
# 2. Cross-attention
|
298 |
+
self.attn2 = Attention(
|
299 |
+
query_dim=dim,
|
300 |
+
heads=num_heads,
|
301 |
+
kv_heads=num_heads,
|
302 |
+
dim_head=dim // num_heads,
|
303 |
+
qk_norm=qk_norm,
|
304 |
+
eps=eps,
|
305 |
+
bias=True,
|
306 |
+
cross_attention_dim=None,
|
307 |
+
out_bias=True,
|
308 |
+
added_kv_proj_dim=added_kv_proj_dim,
|
309 |
+
added_proj_bias=True,
|
310 |
+
processor=WanT2TexAttnProcessor2_0(),
|
311 |
+
)
|
312 |
+
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
313 |
+
|
314 |
+
# 3. Feed-forward
|
315 |
+
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
316 |
+
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
317 |
+
|
318 |
+
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
319 |
+
|
320 |
+
self.geometry_caster = nn.Linear(dim, dim)
|
321 |
+
nn.init.zeros_(self.geometry_caster.weight.data)
|
322 |
+
nn.init.zeros_(self.geometry_caster.bias.data)
|
323 |
+
|
324 |
+
self.attnuv = Attention(
|
325 |
+
query_dim=dim,
|
326 |
+
heads=num_heads,
|
327 |
+
kv_heads=num_heads,
|
328 |
+
dim_head=dim // num_heads,
|
329 |
+
qk_norm=qk_norm,
|
330 |
+
eps=eps,
|
331 |
+
bias=True,
|
332 |
+
cross_attention_dim=None,
|
333 |
+
out_bias=True,
|
334 |
+
processor=WanT2TexAttnProcessor2_0(),
|
335 |
+
)
|
336 |
+
self.normuv2 = FP32LayerNorm(dim, eps, elementwise_affine=True)
|
337 |
+
self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
338 |
+
self.ffnuv = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
339 |
+
|
340 |
+
def forward(
|
341 |
+
self,
|
342 |
+
hidden_states: torch.Tensor,
|
343 |
+
encoder_hidden_states: torch.Tensor,
|
344 |
+
temb: torch.Tensor,
|
345 |
+
rotary_emb: torch.Tensor,
|
346 |
+
attn_bias: Optional[torch.Tensor] = None,
|
347 |
+
geometry_embedding: Optional[torch.Tensor] = None,
|
348 |
+
token_shape: Optional[Tuple[int, int, int, int, int, int]] = None,
|
349 |
+
) -> torch.Tensor:
|
350 |
+
post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width = token_shape
|
351 |
+
mv_temb, uv_temb = temb[:, :post_patch_num_frames], temb[:, post_patch_num_frames:]
|
352 |
+
mv_temb = repeat(mv_temb, "B F N D -> B N (F H W) D", H=post_patch_height, W=post_patch_width)
|
353 |
+
uv_temb = repeat(uv_temb, "B F N D -> B N (F H W) D", H=post_uv_height, W=post_uv_width)
|
354 |
+
dit_ssg = rearrange(self.scale_shift_table, "1 N D -> 1 N 1 D") + mv_temb.float()
|
355 |
+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = torch.unbind(dit_ssg, dim=1)
|
356 |
+
dit_ssg_uv = rearrange(self.scale_shift_table_uv, "1 N D -> 1 N 1 D") + uv_temb.float()
|
357 |
+
shift_msa_uv, scale_msa_uv, gate_msa_uv, c_shift_msa_uv, c_scale_msa_uv, c_gate_msa_uv = torch.unbind(dit_ssg_uv, dim=1)
|
358 |
+
|
359 |
+
geometry_embedding = self.geometry_caster(geometry_embedding)
|
360 |
+
|
361 |
+
n_mv, n_uv = post_patch_num_frames * post_patch_height * post_patch_width, post_uv_num_frames * post_uv_height * post_uv_width
|
362 |
+
assert hidden_states.shape[1] == n_mv + n_uv, f"hidden_states shape {hidden_states.shape} is not equal to {n_mv + n_uv}"
|
363 |
+
mv_hidden_states, uv_hidden_states = hidden_states[:, :n_mv], hidden_states[:, n_mv:]
|
364 |
+
|
365 |
+
# 1. Self-attention
|
366 |
+
mv_norm_hidden_states = (self.norm1(mv_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(mv_hidden_states)
|
367 |
+
uv_norm_hidden_states = (self.norm1(uv_hidden_states.float()) * (1 + scale_msa_uv) + shift_msa_uv).type_as(uv_hidden_states)
|
368 |
+
|
369 |
+
mv_attn_output = self.attn1(hidden_states=mv_norm_hidden_states, rotary_emb=rotary_emb[:, :, :n_mv], attention_mask=attn_bias, geometry_embedding=geometry_embedding[:, :n_mv])
|
370 |
+
mv_hidden_states = (mv_hidden_states.float() + mv_attn_output * gate_msa).type_as(mv_hidden_states)
|
371 |
+
uv_attn_output = self.attnuv(hidden_states=uv_norm_hidden_states, encoder_hidden_states=torch.cat([mv_hidden_states, uv_norm_hidden_states], dim=1),
|
372 |
+
rotary_emb=(rotary_emb[:, :, n_mv:], rotary_emb), geometry_embedding=(geometry_embedding[:, n_mv:], geometry_embedding))
|
373 |
+
uv_hidden_states = (uv_hidden_states.float() + uv_attn_output * gate_msa_uv).type_as(uv_hidden_states)
|
374 |
+
|
375 |
+
# 2. Cross-attention
|
376 |
+
mv_norm_hidden_states = self.norm2(mv_hidden_states.float()).type_as(mv_hidden_states)
|
377 |
+
uv_norm_hidden_states = self.normuv2(uv_hidden_states.float()).type_as(uv_hidden_states)
|
378 |
+
attn_output = self.attn2(hidden_states=torch.cat([mv_norm_hidden_states, uv_norm_hidden_states], dim=1), encoder_hidden_states=encoder_hidden_states)
|
379 |
+
mv_attn_output, uv_attn_output = attn_output[:, :n_mv], attn_output[:, n_mv:]
|
380 |
+
mv_hidden_states.add_(mv_attn_output)
|
381 |
+
uv_hidden_states.add_(uv_attn_output)
|
382 |
+
|
383 |
+
# 3. Feed-forward
|
384 |
+
mv_norm_hidden_states = (self.norm3(mv_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
385 |
+
mv_hidden_states
|
386 |
+
)
|
387 |
+
uv_norm_hidden_states = (self.norm3(uv_hidden_states.float()) * (1 + c_scale_msa_uv) + c_shift_msa_uv).type_as(
|
388 |
+
uv_hidden_states
|
389 |
+
)
|
390 |
+
ff_output = self.ffn(mv_norm_hidden_states)
|
391 |
+
mv_hidden_states = (mv_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(mv_hidden_states)
|
392 |
+
ff_output_uv = self.ffnuv(uv_norm_hidden_states)
|
393 |
+
uv_hidden_states = (uv_hidden_states.float() + ff_output_uv.float() * c_gate_msa_uv).type_as(uv_hidden_states)
|
394 |
+
hidden_states = torch.cat([mv_hidden_states, uv_hidden_states], dim=1)
|
395 |
+
|
396 |
+
return hidden_states
|
397 |
+
|
398 |
+
|
399 |
+
class WanT2TexTransformer3DModel(WanTransformer3DModel):
|
400 |
+
"""
|
401 |
+
3D Transformer model for T2Tex.
|
402 |
+
"""
|
403 |
+
def __init__(self, original_model, use_causal_mask=False, addtional_qk_geo=False, randomly_init=False, **kwargs):
|
404 |
+
super(WanT2TexTransformer3DModel, self).__init__(**original_model.config)
|
405 |
+
if not randomly_init:
|
406 |
+
self.load_state_dict(original_model.state_dict(), strict=True)
|
407 |
+
self.addtional_qk_geo = addtional_qk_geo
|
408 |
+
if addtional_qk_geo:
|
409 |
+
raise ValueError("addtional_qk_geo did not work")
|
410 |
+
warn("addtional_qk_geo is set to True; this will drastically increase memory usage and slow down training without a significant performance gain.")
|
411 |
+
|
412 |
+
# 1. Patch & position embedding
|
413 |
+
self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len, addtional_qk_geo=addtional_qk_geo)
|
414 |
+
self.use_normal, self.use_position = kwargs.get("use_normal", True), kwargs.get("use_position", True)
|
415 |
+
if self.use_normal:
|
416 |
+
self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
|
417 |
+
# torch.nn.init.zeros_(self.norm_patch_embedding.weight.data)
|
418 |
+
# torch.nn.init.zeros_(self.norm_patch_embedding.bias.data)
|
419 |
+
if self.use_position:
|
420 |
+
self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
|
421 |
+
# torch.nn.init.zeros_(self.pos_patch_embedding.weight.data)
|
422 |
+
# torch.nn.init.zeros_(self.pos_patch_embedding.bias.data)
|
423 |
+
|
424 |
+
# 2. Condition embeddings
|
425 |
+
inner_dim = original_model.config.num_attention_heads * original_model.config.attention_head_dim
|
426 |
+
self.condition_embedder = WanTimeTaskTextImageEmbedding(
|
427 |
+
original_model=self.condition_embedder,
|
428 |
+
dim=inner_dim,
|
429 |
+
time_freq_dim=original_model.config.freq_dim,
|
430 |
+
time_proj_dim=inner_dim * 6,
|
431 |
+
text_embed_dim=original_model.config.text_dim,
|
432 |
+
image_embed_dim=original_model.config.image_dim,
|
433 |
+
randomly_init=randomly_init,
|
434 |
+
)
|
435 |
+
|
436 |
+
# 3. Transformer blocks
|
437 |
+
self.use_causal_mask = use_causal_mask
|
438 |
+
self.num_attention_heads = original_model.config.num_attention_heads
|
439 |
+
|
440 |
+
block = WanT2TexTransformerBlock(
|
441 |
+
inner_dim,
|
442 |
+
original_model.config.ffn_dim,
|
443 |
+
original_model.config.num_attention_heads,
|
444 |
+
original_model.config.qk_norm,
|
445 |
+
original_model.config.cross_attn_norm,
|
446 |
+
original_model.config.eps,
|
447 |
+
original_model.config.added_kv_proj_dim,
|
448 |
+
)
|
449 |
+
self.blocks = None
|
450 |
+
self.blocks = nn.ModuleList(
|
451 |
+
[
|
452 |
+
copy.deepcopy(block)
|
453 |
+
for _ in range(original_model.config.num_layers)
|
454 |
+
]
|
455 |
+
)
|
456 |
+
self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
457 |
+
if not randomly_init:
|
458 |
+
self.scale_shift_table_uv.data.copy_(self.scale_shift_table.data)
|
459 |
+
self.blocks.load_state_dict(original_model.blocks.state_dict(), strict=False)
            for block in self.blocks:
                block.attnuv.load_state_dict(block.attn1.state_dict())
                block.scale_shift_table_uv.data.copy_(block.scale_shift_table.data)
                block.normuv2.load_state_dict(block.norm2.state_dict())
                block.ffnuv.load_state_dict(block.ffn.state_dict())

        # 4. Output norm & projection
        pass

    @cache
    def get_attention_bias(self, mv_length, uv_length):
        total_len = mv_length + uv_length
        attention_mask = torch.ones((total_len, total_len), dtype=torch.bool)
        uv_start = mv_length
        attention_mask[:uv_start, uv_start:] = False
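        # Rows index queries and columns index keys, with MV tokens first: MV queries may not
        # attend to UV keys, while UV queries can attend to the full sequence.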

        attention_mask = repeat(attention_mask, "s l -> 1 h s l", h=self.num_attention_heads)
        # Build an additive float bias: 0 where attention is allowed, -inf where it is masked.
        attention_bias = torch.zeros_like(attention_mask, dtype=torch.float32)
        attention_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        attention_bias = attention_bias.to("cuda").contiguous()
        return attention_bias

    def forward(
        self,
        hidden_states: Tuple[torch.Tensor, torch.Tensor],
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        # task_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        use_qk_geometry: Optional[bool] = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                raise NotImplementedError()

        assert timestep.ndim == 2, "Use Diffusion Forcing to set a separate timestep for each frame."
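        # `timestep` has shape [B, F]: one timestep per latent frame, MV frames first,
        # followed by the UV frames (see the split of `temb` below).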

        mv_hidden_states, uv_hidden_states = hidden_states

        batch_size, num_channels, num_frames, height, width = mv_hidden_states.shape
        _, _, uv_num_frames, uv_height, uv_width = uv_hidden_states.shape

        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w
        post_uv_num_frames = uv_num_frames // p_t
        post_uv_height = uv_height // p_h
        post_uv_width = uv_width // p_w

        rotary_emb = self.rope(mv_hidden_states, uv_hidden_states)
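        # Rotary position embeddings are computed jointly from the MV and UV latent grids
        # and cover the concatenated token sequence processed by the blocks below.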

        # Patchify
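        # Latent channels are stacked along dim=1 as [RGB | position | normal] (depending on
        # use_normal / use_position). The RGB chunk goes through the standard patch embedding,
        # while the geometry chunks are embedded separately into `geometry_embedding`.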
        if self.use_normal and self.use_position:
            mv_rgb_hidden_states, mv_pos_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 3, dim=1)
            uv_rgb_hidden_states, uv_pos_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 3, dim=1)
            mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states) + self.norm_patch_embedding(mv_norm_hidden_states)
            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
        elif self.use_normal:
            mv_rgb_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
            uv_rgb_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
            mv_geometry_embedding = self.norm_patch_embedding(mv_norm_hidden_states)
            uv_geometry_embedding = self.norm_patch_embedding(uv_norm_hidden_states)
        elif self.use_position:
            mv_rgb_hidden_states, mv_pos_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
            uv_rgb_hidden_states, uv_pos_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
            mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states)
            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states)
        else:
            raise ValueError("use_normal and use_position are both False; please set at least one of them to True.")

        mv_hidden_states = self.patch_embedding(mv_rgb_hidden_states)
        uv_hidden_states = self.patch_embedding(uv_rgb_hidden_states)
        if use_qk_geometry:
            mv_geometry_embedding = mv_geometry_embedding.flatten(2).transpose(1, 2)
            uv_geometry_embedding = uv_geometry_embedding.flatten(2).transpose(1, 2)  # [B, F*H*W, C]
            geometry_embedding = torch.cat([mv_geometry_embedding, uv_geometry_embedding], dim=1)
        else:
            raise NotImplementedError("please set use_qk_geometry to True")
            # geometry_embedding = None
            # mv_hidden_states = mv_hidden_states + mv_geometry_embedding
            # uv_hidden_states = uv_hidden_states + uv_geometry_embedding

        mv_hidden_states = mv_hidden_states.flatten(2).transpose(1, 2)
        uv_hidden_states = uv_hidden_states.flatten(2).transpose(1, 2)  # [B, F*H*W, C]
        hidden_states = torch.cat([mv_hidden_states, uv_hidden_states], dim=1)  # [B, F*H*W, C]
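        # MV and UV tokens form a single sequence, so every transformer block runs joint
        # self-attention across the multi-view tokens and the UV-space tokens.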

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        # temb [B, F, 6*D], timestep_proj [B, F, 6*D], used to be [B, 6*D]
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))  # [B, F, 6*D] -> [B, F, 6, D]

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # # Get attention bias
        # if self.use_causal_mask:
        #     # This may be gainless, because the patch embedding is not causal, which will leak information to MV
        #     attn_bias = self.get_attention_bias(post_patch_num_frames * post_patch_height * post_patch_width,
        #                                         post_uv_num_frames * post_uv_height * post_uv_width)
        # else:
        attn_bias = None

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb,
                    attn_bias, geometry_embedding, (post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width)
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb,
                                      attn_bias=attn_bias, geometry_embedding=geometry_embedding,
                                      token_shape=(post_patch_num_frames, post_patch_height, post_patch_width, post_uv_num_frames, post_uv_height, post_uv_width))

        # 5. Output norm, projection & unpatchify
        # [B, 2, D] chunk into [B, 1, D] and [B, 1, D], D is 1536
        inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
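        # `temb` carries one modulation vector per latent frame: the first
        # post_patch_num_frames entries modulate the MV tokens, and the remaining entries
        # modulate the UV tokens via their own scale/shift table (scale_shift_table_uv).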
        mv_temb, uv_temb = temb[:, :post_patch_num_frames], temb[:, post_patch_num_frames:]
        mv_temb = repeat(mv_temb, "B F D -> B 1 (F H W) D", H=post_patch_height, W=post_patch_width)
        uv_temb = repeat(uv_temb, "B F D -> B 1 (F H W) D", H=post_uv_height, W=post_uv_width)
        shift, scale = (self.scale_shift_table.view(1, 2, 1, inner_dim) + mv_temb).chunk(2, dim=1)
        shift_uv, scale_uv = (self.scale_shift_table_uv.view(1, 2, 1, inner_dim) + uv_temb).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.squeeze(1).to(hidden_states.device)
        scale = scale.squeeze(1).to(hidden_states.device)
        shift_uv = shift_uv.squeeze(1).to(hidden_states.device)
        scale_uv = scale_uv.squeeze(1).to(hidden_states.device)

        # Unpatchify
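        # Split the joint sequence back into MV and UV tokens, apply the final AdaLN-style
        # modulation per branch, and fold the patches back into dense latent grids.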
        uv_token_length = post_uv_num_frames * post_uv_height * post_uv_width
        mv_token_length = post_patch_num_frames * post_patch_height * post_patch_width
        assert uv_token_length + mv_token_length == hidden_states.shape[1]
        uv_hidden_states = hidden_states[:, mv_token_length:]
        mv_hidden_states = hidden_states[:, :mv_token_length]

        mv_hidden_states = (self.norm_out(mv_hidden_states.float()) * (1 + scale) + shift).type_as(mv_hidden_states)
        uv_hidden_states = (self.norm_out(uv_hidden_states.float()) * (1 + scale_uv) + shift_uv).type_as(uv_hidden_states)
        mv_hidden_states = self.proj_out(mv_hidden_states)
        uv_hidden_states = self.proj_out(uv_hidden_states)

        mv_hidden_states = mv_hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        mv_hidden_states = mv_hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        mv_output = mv_hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
        uv_hidden_states = uv_hidden_states.reshape(
            batch_size, post_uv_num_frames, post_uv_height, post_uv_width, p_t, p_h, p_w, -1
        )
        uv_hidden_states = uv_hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        uv_output = uv_hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

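        # NOTE: `return_dict` is accepted for interface compatibility but is not used here;
        # the (mv_output, uv_output) pair is always returned wrapped in a one-element tuple.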
        return ((mv_output, uv_output),)
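

# Usage sketch (illustrative only; names and argument values are placeholders):
#   base = WanTransformer3DModel.from_pretrained(...)              # pretrained Wan backbone
#   model = WanT2TexTransformer3DModel(base, use_causal_mask=False)
#   out, = model(
#       hidden_states=(mv_latents, uv_latents),                    # each [B, C, F, H, W]
#       timestep=per_frame_timesteps,                               # [B, F_mv + F_uv]
#       encoder_hidden_states=text_embeds,
#       use_qk_geometry=True,
#   )
#   mv_pred, uv_pred = out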