Spaces: Running on Zero
Commit · 6d4bcdf
Parent(s): 1d5bb62
init space 2
Files changed:
- .gitignore +1 -0
- README.md +23 -0
- THIRD_PARTY_LICENSES.md +127 -0
- app.py +183 -222
- examples/shoe.glb +3 -0
- packages.txt +1 -0
- requirements.txt +24 -0
- utils/__init__.py +67 -0
- utils/file_utils.py +15 -0
- utils/image_generation.py +38 -13
- utils/mesh_utils.py +26 -6
- utils/rasterize.py +13 -5
- utils/render_utils.py +36 -4
- utils/texture_generation.py +55 -98
- wan/pipeline_wan_t2tex_extra.py +30 -7
- wan/wan_t2tex_transformer_3d_extra.py +84 -114
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__/
README.md
ADDED
@@ -0,0 +1,23 @@
+---
+title: SeqTex
+emoji: 🗺️
+colorFrom: yellow
+colorTo: green
+sdk: gradio
+sdk_version: 5.34.2
+python_version: 3.12
+models:
+- Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+- VAST-AI/SeqTex-Transformer
+- black-forest-labs/FLUX.1-dev
+- Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0
+- madebyollin/sdxl-vae-fp16-fix
+- stabilityai/stable-diffusion-xl-base-1.0
+- xinsir/controlnet-union-sdxl-1.0
+app_file: app.py
+pinned: false
+license: mit
+short_description: SeqTex generates texture based on textual conditions
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
THIRD_PARTY_LICENSES.md
ADDED
@@ -0,0 +1,127 @@
+
+# Third Party Licenses
+
+This project uses third-party libraries that are subject to their own licenses.
+
+## nvdiffrast
+
+**Project:** https://github.com/NVlabs/nvdiffrast
+**License:** NVIDIA Source Code License (1-Way Commercial)
+**Usage:** Non-commercial use (research or evaluation purposes only)
+
+```text
+Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
+
+This work is made available under the Nvidia Source Code License (1-Way Commercial).
+The Work and any derivative works thereof only may be used or intended for use
+non-commercially. "Non-commercially" means for research or evaluation purposes only
+and not for any direct or indirect monetary gain.
+
+Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt
+```
+
+**Key Points:**
+- ✅ Research/Academic Use: Permitted
+- ❌ Commercial Use: Requires separate licensing from NVIDIA
+- 📞 Commercial Licensing: https://www.nvidia.com/en-us/research/inquiries/
+
+## Wan Team Libraries
+
+**Project:** Various components in `wan/` directory
+**License:** Apache License 2.0
+**Copyright:** Copyright (c) 2024 Wan Team
+
+```text
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
+
+## Hugging Face Diffusers
+
+**Project:** https://github.com/huggingface/diffusers
+**License:** Apache License 2.0
+**Copyright:** Copyright 2024 The HuggingFace Team
+
+```text
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
+
+## Hugging Face Transformers
+
+**Project:** https://github.com/huggingface/transformers
+**License:** Apache License 2.0
+**Copyright:** Copyright 2024 The HuggingFace Team
+
+```text
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
+
+## PEFT (Parameter-Efficient Fine-Tuning)
+
+**Project:** https://github.com/huggingface/peft
+**License:** Apache License 2.0
+**Copyright:** Copyright 2024 The HuggingFace Team
+
+```text
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
+
+## Other Dependencies
+
+The following dependencies are used under their respective licenses:
+
+- **PyTorch & TorchVision**: BSD-3-Clause License
+- **Einops**: MIT License
+- **OmegaConf**: BSD-3-Clause License
+- **Trimesh**: MIT License
+- **Gradio**: Apache License 2.0
+- **OpenCV**: BSD-3-Clause License
+- **NumPy**: BSD-3-Clause License
+- **ImageIO**: BSD-2-Clause License
+
+## Individual Contributors
+
+**Qi Xin** - Condition Transformer implementation in `utils/controlnet_union.py`
+- Copyright by Qi Xin (2024/07/06)
+- Condition Transformer component for fusing single/multi conditions with input image
+
+For the complete list of dependencies and their licenses, please refer to the respective package repositories.
app.py
CHANGED
@@ -1,231 +1,192 @@
+import os
+import gradio as gr
+
+from utils import tensor_to_pil
+from utils.image_generation import generate_image_condition, get_flux_pipe, get_sdxl_pipe
-import
-import
-
-from
-from utils.image_generation import generate_image_condition
 from utils.mesh_utils import Mesh
 from utils.render_utils import render_views
-from utils.texture_generation import generate_texture
-
-import gradio as gr
-from gradio_litmodel3d import LitModel3D
+from utils.texture_generation import generate_texture, get_seqtex_pipe
 
 EXAMPLES = [
     ["examples/birdhouse.glb", True, False, False, False, 42, "First View", "SDXL", False, "A rustic birdhouse featuring a snow-covered roof, wood textures, and two decorative cardinal birds. It has a circular entryway and conveys a winter-themed aesthetic."],
-    ["examples/
+    ["examples/shoe.glb", True, False, False, False, 42, "Second View", "SDXL", False, "Modern sneaker exhibiting a mesh upper and wavy rubber outsole. Features include lacing for adjustability and padded components for comfort. Normal maps emphasize geometric detail."],
+    # ["examples/mario.glb", False, False, False, True, 6666, "Third View", "FLUX", True, "Mario, a cartoon character wearing a red cap and blue overalls, with brown hair and a mustache, and white gloves, in a fighting pose. The clothes he wears are not in a reflection mode."],
 ]
+LOAD_FIRST = True
 
-    """
-    Convert tensor to PIL Image.
-    :param tensor: torch.Tensor, shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
-    :param mask: torch.Tensor, shape same as tensor, effective when C=3
-    :return: PIL.Image
-    """
-    # Move to cpu
-    tensor = tensor.detach()
-    if tensor.is_cuda:
-        tensor = tensor.cpu()
-    if mask is not None and mask.is_cuda:
-        mask = mask.cpu()
-
-    # Convert to float32
-    tensor = tensor.float()
-    if mask is not None:
-        mask = mask.float()
-
-    if normalize:
-        tensor = (tensor + 1.0) / 2.0
-        tensor = torch.clamp(tensor, 0.0, 1.0)
-    if mask is not None:
-        if mask.shape[-1] not in [1, 3]:
-            mask = mask.unsqueeze(-1)
-        tensor = torch.cat([tensor, mask], dim=-1)
-
-    shape = tensor.shape
-    # 4D: (Nv, H, W, C) or (Nv, C, H, W)
-    if len(shape) == 4:
-        Nv = shape[0]
-        if shape[-1] in [3, 4]:  # (Nv, H, W, C)
-            tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
-        else:  # (Nv, C, H, W)
-            tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
-    # 3D: (H, W, C) or (C, H, W)
-    elif len(shape) == 3:
-        if shape[-1] in [3, 4]:  # (H, W, C)
-            tensor = rearrange(tensor, 'h w c -> h w c')
-        else:  # (C, H, W)
-            tensor = rearrange(tensor, 'c h w -> h w c')
-    else:
-        raise ValueError(f"Unsupported tensor shape: {shape}")
-
-    # Convert to numpy
-    np_img = (tensor.numpy() * 255).round().astype(np.uint8)
-
-    # Create PIL Image
-    if np_img.shape[2] == 3:
-        return Image.fromarray(np_img, mode="RGB")
-    elif np_img.shape[2] == 4:
-        return Image.fromarray(np_img, mode="RGBA")
-    else:
-        raise ValueError("Only support 3 or 4 channel images.")
-
-    gr.Markdown("""
-    ## 🚀 Welcome to SeqTex!
-    **SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes using image prompts (here we use an image generator to obtain them from textual prompts).
-
-    Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
-    """)
-
-    gr.Markdown("---")
-
-    gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
-    gr.Markdown("""
-    **📋 How to prepare your 3D mesh:**
-    - Upload your 3D mesh in **.obj** or **.glb** format
-    - **💡 Pro Tip**:
-        - For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
-        - Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
-    - **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
-    """)
-    position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix = gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State()
-
-    # fixed_texture_map = Image.open("image.webp").convert("RGB")
-    # Step 1
-    with gr.Row():
-        with gr.Column():
-            mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
-            # uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
-
-            gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
-            y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
-            y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
-            z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
-            upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
-            Mesh.
-    """)
-    gr.
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
+    gr.Markdown("# 🎨 SeqTex: Generate Mesh Textures in Video Sequence")
+
+    gr.Markdown("""
+    ## 🚀 Welcome to SeqTex!
+    **SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes using image prompts (here we use an image generator to obtain them from textual prompts).
+
+    Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
+    """)
+
+    gr.Markdown("---")
+
+    gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
+    gr.Markdown("""
+    **📋 How to prepare your 3D mesh:**
+    - Upload your 3D mesh in **.obj** or **.glb** format
+    - **💡 Pro Tip**:
+        - For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
+        - Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
+    - **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
+    """)
+    position_map_tensor_path = gr.State()
+    normal_map_tensor_path = gr.State()
+    position_images_tensor_path = gr.State()
+    normal_images_tensor_path = gr.State()
+    mask_images_tensor_path = gr.State()
+    w2c_tensor_path = gr.State()
+    mesh = gr.State()
+    mvp_matrix_tensor_path = gr.State()
+
+    # fixed_texture_map = Image.open("image.webp").convert("RGB")
+    # Step 1
+    with gr.Row():
+        with gr.Column():
+            mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
+            # uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
+
+            gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
+            y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
+            y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
+            z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
+            upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
+            step1_button = gr.Button("🔄 Process Mesh & Generate Views", variant="primary")
+            step1_progress = gr.Textbox(label="📊 Processing Status", interactive=False)
+
+        with gr.Column():
+            model_input = gr.Model3D(label="📐 Processed 3D Model", height=500)
+
+    with gr.Row(equal_height=True):
+        rgb_views = gr.Image(label="📷 Generated Views", type="pil", scale=3)
+        position_map = gr.Image(label="🗺️ Position Map", type="pil", scale=1)
+        normal_map = gr.Image(label="🧭 Normal Map", type="pil", scale=1)
+
+    step1_button.click(
+        Mesh.process,
+        inputs=[mesh_upload, gr.State("xAtlas"), y2z, y2x, z2x, upside_down],
+        outputs=[position_map_tensor_path, normal_map_tensor_path, position_images_tensor_path, normal_images_tensor_path, mask_images_tensor_path, w2c_tensor_path, mesh, mvp_matrix_tensor_path, step1_progress]
+    ).success(
+        tensor_to_pil,
+        inputs=[normal_images_tensor_path, mask_images_tensor_path],
+        outputs=[rgb_views]
+    ).success(
+        tensor_to_pil,
+        inputs=[position_map_tensor_path],
+        outputs=[position_map]
+    ).success(
+        tensor_to_pil,
+        inputs=[normal_map_tensor_path],
+        outputs=[normal_map]
+    ).success(
+        Mesh.export,
+        inputs=[mesh, gr.State(None), gr.State(None)],
+        outputs=[model_input]
+    )
+
+    # Step 2
+    gr.Markdown("---")
+    gr.Markdown("## 👁️ Step 2: Select View & Generate Image Condition")
+    gr.Markdown("""
+    **📋 How to generate image condition:**
+    - Your mesh will be rendered from **four viewpoints** (front, back, left, right)
+    - Choose **one view** as your image condition
+    - Enter a **descriptive text prompt** for the desired texture
+    - Select your preferred AI model:
+        - <span style="color:#27ae60; font-weight:bold;">🎯 SDXL</span>: Fast generation with depth + normal control, better details (often suffers from wrong highlights)
+        - <span style="color:#3498db; font-weight:bold;">⚡ FLUX</span>: ~~High-quality generation with depth control (slower due to CPU offloading). Works better with **Edge Refinement**~~ (Not supported due to the memory limit of HF Space. You can try it locally)
+    """)
+    with gr.Row():
+        with gr.Column():
+            img_condition_seed = gr.Number(label="🎲 Random Seed", minimum=0, maximum=9999, step=1, value=42, info="Change for different results")
+            selected_view = gr.Radio(["First View", "Second View", "Third View", "Fourth View"], label="📐 Camera View", value="First View", info="Choose which viewpoint to use as reference")
+            with gr.Row():
+                # model_choice = gr.Radio(["SDXL", "FLUX"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing")
+                model_choice = gr.Radio(["SDXL"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing (Not supported due to the memory limit of HF Space)")
+                edge_refinement = gr.Checkbox(label="✨ Edge Refinement", value=True, info="Smooth boundary artifacts (recommended for reducing highlights at the boundary)")
+            text_prompt = gr.Textbox(label="💬 Texture Description", placeholder="Describe the desired texture appearance (e.g., 'rustic wooden surface with weathered paint')", lines=2)
+            step2_button = gr.Button("🎯 Generate Image Condition", variant="primary")
+            step2_progress = gr.Textbox(label="📊 Generation Status", interactive=False)
+
+        with gr.Column():
+            condition_image = gr.Image(label="🖼️ Generated Image Condition", type="pil")  # , interactive=False
+
+    step2_button.click(
+        generate_image_condition,
+        inputs=[position_images_tensor_path, normal_images_tensor_path, mask_images_tensor_path, w2c_tensor_path, text_prompt, selected_view, img_condition_seed, model_choice, edge_refinement],
+        outputs=[condition_image, step2_progress],
+    )
+
+    # Step 3
+    gr.Markdown("---")
+    gr.Markdown("## 🎨 Step 3: Generate Final Texture")
+    gr.Markdown("""
+    **📋 How to generate final texture:**
+    - The **SeqTex pipeline** will create a complete texture map for your model
+    - View the results from multiple angles and download your textured 3D model (the viewport is a little bit dark)
+    """)
+    texture_map_tensor_path = gr.State()
+    with gr.Row():
+        with gr.Column(scale=1):
+            step3_button = gr.Button("🎨 Generate Final Texture", variant="primary")
+            step3_progress = gr.Textbox(label="📊 Texture Generation Status", interactive=False)
+            texture_map = gr.Image(label="🏆 Generated Texture Map", interactive=False)
+        with gr.Column(scale=2):
+            rendered_imgs = gr.Image(label="🖼️ Final Rendered Views")
+            mv_branch_imgs = gr.Image(label="🖼️ SeqTex Direct Output")
+        with gr.Column(scale=1.5):
+            model_display = gr.Model3D(label="🏆 Final Textured Model", height=500)
+            # model_display = LitModel3D(label="Model with Texture",
+            #                            exposure=30.0,
+            #                            height=500)
+
+    step3_button.click(
+        generate_texture,
+        inputs=[position_map_tensor_path, normal_map_tensor_path, position_images_tensor_path, normal_images_tensor_path, condition_image, text_prompt, selected_view],
+        outputs=[texture_map_tensor_path, texture_map, mv_branch_imgs, step3_progress],
+    ).success(
+        render_views,
+        inputs=[mesh, texture_map_tensor_path, mvp_matrix_tensor_path],
+        outputs=[rendered_imgs]
+    ).success(
+        Mesh.export,
+        inputs=[mesh, gr.State(None), texture_map],
+        outputs=[model_display]
+    )
+
+    # Add example inputs for user convenience
+    gr.Markdown("---")
+    gr.Markdown("## 🚀 Try Our Examples")
+    gr.Markdown("**Quick Start**: Click on any example below to see SeqTex in action with pre-configured settings!")
+    gr.Examples(
+        examples=EXAMPLES,
+        inputs=[mesh_upload, y2z, y2x, z2x, upside_down, img_condition_seed, selected_view, model_choice, edge_refinement, text_prompt],
+        cache_examples=False
+    )
+
+    # Acknowledgments
+    gr.Markdown("---")
+    gr.Markdown("## 🙏 Acknowledgments")
+    gr.Markdown("""
+    **Special thanks to [Toshihiro Hayashi](mailto:[email protected])** for his valuable support and assistance in fixing bugs for this demo.
+    """)
+
+if LOAD_FIRST is True:
+    import gc
+    get_seqtex_pipe()
+    print("SeqTex pipeline loaded successfully.")
+    get_sdxl_pipe()
+    print("SDXL pipeline loaded successfully.")
+    # get_flux_pipe()
+    # Note: FLUX pipeline is available in code but not loaded due to GPU memory constraints on HF Space
+    print("Note: FLUX and other models are available for local deployment.")
+    gc.collect()
+
+assert os.environ["OPENCV_IO_ENABLE_OPENEXR"] == "1", "OpenEXR support is required for this demo."
+demo.launch(server_name="0.0.0.0")
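The Step 1–3 wiring above relies on Gradio's event chaining: each `.click(...)` returns a dependency whose `.success(...)` runs the next handler only if the previous one finished without raising, and large tensors travel between steps as file paths held in `gr.State`. Below is a minimal, self-contained sketch of the same chaining pattern; the handler functions and labels are illustrative, not part of this Space.

```python
import gradio as gr

def slow_step(x):
    # stand-in for heavy work; the first output feeds the next handler via gr.State
    return f"processed: {x}", "step 1 done"

def follow_up(processed):
    return f"rendered from [{processed}]"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    status = gr.Textbox(label="Status", interactive=False)
    intermediate = gr.State()   # holds data (or a file path) between steps
    out = gr.Textbox(label="Result")
    btn = gr.Button("Run")

    # .success() only fires if the previous handler did not raise,
    # mirroring the step1_button -> tensor_to_pil -> Mesh.export chain in app.py
    btn.click(slow_step, inputs=[inp], outputs=[intermediate, status]) \
       .success(follow_up, inputs=[intermediate], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```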
examples/shoe.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3945c77a6f98eb18aae9c97f253e4c6b06daf83194c21f91ea4b955756bace7e
+size 8842904
packages.txt
ADDED
@@ -0,0 +1 @@
+ninja-build
requirements.txt
ADDED
@@ -0,0 +1,24 @@
+torchvision==0.21.0
+
+einops
+omegaconf==2.3.0
+jaxtyping
+typeguard
+imageio
+trimesh==4.6.4
+peft==0.14.0
+diffusers==0.33.1
+bitsandbytes
+transformers==4.52.4
+ftfy
+accelerate
+sentencepiece
+ipdb
+clean-fid
+apex==0.9.10.dev0
+xatlas
+gradio_litmodel3d
+spaces
+numpy
+opencv-python
+https://huggingface.co/spaces/VAST-AI/MV-Adapter-Img2Texture/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
utils/__init__.py
CHANGED
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+
+
+def tensor_to_pil(tensor, mask=None, normalize: bool = True):
+    """
+    Convert tensor to PIL Image.
+    :param tensor: torch.Tensor or str (file path to tensor), shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
+    :param mask: torch.Tensor or str (file path to tensor), shape same as tensor, effective when C=3
+    :return: PIL.Image
+    """
+    # If input is a file path, load the tensor
+    if isinstance(tensor, str):
+        from utils.file_utils import load_tensor_from_file
+        tensor = load_tensor_from_file(tensor, map_location="cpu")
+    if mask is not None and isinstance(mask, str):
+        from utils.file_utils import load_tensor_from_file
+        mask = load_tensor_from_file(mask, map_location="cpu")
+    # Move to cpu
+    tensor = tensor.detach()
+    if tensor.is_cuda:
+        tensor = tensor.cpu()
+    if mask is not None and mask.is_cuda:
+        mask = mask.cpu()
+
+    # Convert to float32
+    tensor = tensor.float()
+    if mask is not None:
+        mask = mask.float()
+
+    if normalize:
+        tensor = (tensor + 1.0) / 2.0
+        tensor = torch.clamp(tensor, 0.0, 1.0)
+    if mask is not None:
+        if mask.shape[-1] not in [1, 3]:
+            mask = mask.unsqueeze(-1)
+        tensor = torch.cat([tensor, mask], dim=-1)
+
+    shape = tensor.shape
+    # 4D: (Nv, H, W, C) or (Nv, C, H, W)
+    if len(shape) == 4:
+        Nv = shape[0]
+        if shape[-1] in [3, 4]:  # (Nv, H, W, C)
+            tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
+        else:  # (Nv, C, H, W)
+            tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
+    # 3D: (H, W, C) or (C, H, W)
+    elif len(shape) == 3:
+        if shape[-1] in [3, 4]:  # (H, W, C)
+            tensor = rearrange(tensor, 'h w c -> h w c')
+        else:  # (C, H, W)
+            tensor = rearrange(tensor, 'c h w -> h w c')
+    else:
+        raise ValueError(f"Unsupported tensor shape: {shape}")
+
+    # Convert to numpy
+    np_img = (tensor.numpy() * 255).round().astype(np.uint8)
+
+    # Create PIL Image
+    if np_img.shape[2] == 3:
+        return Image.fromarray(np_img, mode="RGB")
+    elif np_img.shape[2] == 4:
+        return Image.fromarray(np_img, mode="RGBA")
+    else:
+        raise ValueError("Only support 3 or 4 channel images.")
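For reference, a small usage sketch of `tensor_to_pil` as defined above, assuming a random multi-view tensor in `[-1, 1]`; the shapes are illustrative only.

```python
import torch
from utils import tensor_to_pil  # assumes this Space's package layout

# Four views, 64x64, RGB, values in [-1, 1] as a renderer would produce
views = torch.rand(4, 64, 64, 3) * 2.0 - 1.0
mask = torch.ones(4, 64, 64)      # per-view alpha mask; with C=3 this yields an RGBA image

img = tensor_to_pil(views, mask=mask, normalize=True)
print(img.mode, img.size)         # RGBA, views concatenated horizontally: (256, 64)
img.save("views_preview.png")
```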
utils/file_utils.py
ADDED
@@ -0,0 +1,15 @@
+import os
+import uuid
+import torch
+from gradio.utils import get_upload_folder
+
+def save_tensor_to_file(tensor, prefix="tensor"):
+    upload_dir = get_upload_folder()
+    os.makedirs(upload_dir, exist_ok=True)
+    path = os.path.join(upload_dir, f"{prefix}_{uuid.uuid4().hex}.pt")
+    torch.save(tensor, path)
+    return path
+
+def load_tensor_from_file(path, map_location=None):
+    # Use weights_only=True for security and to suppress FutureWarning (only tensors are loaded in this app)
+    return torch.load(path, map_location=map_location, weights_only=True)
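A quick round-trip sketch of the two helpers above; the saved path lands in Gradio's upload folder, so the exact location is environment-dependent.

```python
import torch
from utils.file_utils import save_tensor_to_file, load_tensor_from_file

t = torch.randn(4, 512, 512, 3)
path = save_tensor_to_file(t, prefix="normal_images")   # .../normal_images_<uuid>.pt
restored = load_tensor_from_file(path, map_location="cpu")
assert torch.equal(t, restored)
```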
utils/image_generation.py
CHANGED
@@ -1,6 +1,8 @@
+import os
 import threading
 
 import cv2
+import gradio as gr
 import numpy as np
 import spaces
 import torch
@@ -12,12 +14,11 @@ from einops import rearrange
 from PIL import Image
 from torchvision.transforms import ToPILImage
 
-import gradio as gr
-
 from .controlnet_union import ControlNetModel_Union
 from .pipeline_controlnet_union_sd_xl import \
     StableDiffusionXLControlNetUnionPipeline
 from .render_utils import get_silhouette_image
+from utils.file_utils import load_tensor_from_file
 
 IMG_PIPE = None
 IMG_PIPE_LOCK = threading.Lock()
@@ -26,8 +27,9 @@ FLUX_PIPE = None
 FLUX_PIPE_LOCK = threading.Lock()
 FLUX_SUFFIX = None
 FLUX_NEGATIVE = None
+CPU_OFFLOAD = False
 
-def lazy_get_flux_pipe():
+def get_flux_pipe():
     """
     Lazy load the FLUX pipeline with ControlNet for image generation.
     """
@@ -44,16 +46,21 @@ def lazy_get_flux_pipe():
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
 
        controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
+       assert os.environ["SEQTEX_SPACE_TOKEN"] != "", "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to black-forest-labs/FLUX.1-dev."
        FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
-           torch_dtype=torch.bfloat16
+           torch_dtype=torch.bfloat16,
+           token=os.environ["SEQTEX_SPACE_TOKEN"]
        )
        # Use model CPU offload for better GPU utilization during inference
+       if CPU_OFFLOAD:
+           FLUX_PIPE.enable_model_cpu_offload()
+       else:
+           FLUX_PIPE.to("cuda")
        return FLUX_PIPE
 
-def lazy_get_sdxl_pipe():
+def get_sdxl_pipe():
     """
     Lazy load the SDXL pipeline with ControlNet for image generation.
     """
@@ -74,8 +81,11 @@ def lazy_get_sdxl_pipe():
            torch_dtype=torch.float16,
            scheduler=eulera_scheduler,
        )
-       #
+       # Use model CPU offload for better GPU utilization during inference
+       if CPU_OFFLOAD:
+           IMG_PIPE.enable_model_cpu_offload()
+       else:
+           IMG_PIPE.to("cuda")
        return IMG_PIPE
 
 
@@ -94,11 +104,11 @@ def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, e
    :return: Generated image condition (e.g., PIL Image).
    """
    progress(0.1, desc="Loading SDXL pipeline...")
-   pipeline =
+   pipeline = get_sdxl_pipe()
    progress(0.3, desc="SDXL pipeline loaded successfully.")
 
-   positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft
-   negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+   positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft diffuse lighting, uniform lighting, flat lighting, even illumination, matte surface, low contrast, uniform color, foreground"
+   negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, harsh lighting, high contrast, bright highlights, specular reflections, shiny surface, glossy, reflective, strong shadows, dramatic lighting, spotlight, direct sunlight, glare, bloom, lens flare'
 
    img_generation_resolution = 1024  # SDXL performs better at 1024x1024
    image = pipeline(prompt=[positive_prompt]*1,
@@ -154,7 +164,7 @@ def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refineme
    :return: Generated image condition (PIL Image).
    """
    progress(0.1, desc="Loading FLUX pipeline...")
-   pipeline =
+   pipeline = get_flux_pipe()
    progress(0.3, desc="FLUX pipeline loaded successfully.")
 
    # Enhanced prompt for better results
@@ -262,7 +272,7 @@ def refine_image_edges(rgb_tensor, mask_tensor):
 
    return refined_rgb_tensor
 
-@spaces.GPU(
+@spaces.GPU()
 def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
    """
    Generate the image condition based on the selected view's silhouette and text prompt.
@@ -278,6 +288,20 @@ def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_pr
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :return: Generated condition image and status message.
    """
+   # If any input is a file path, load the tensor from file
+   if isinstance(position_imgs, str):
+       position_imgs = load_tensor_from_file(position_imgs, map_location="cuda")
+   if isinstance(normal_imgs, str):
+       normal_imgs = load_tensor_from_file(normal_imgs, map_location="cuda")
+   if isinstance(mask_imgs, str):
+       mask_imgs = load_tensor_from_file(mask_imgs, map_location="cuda")
+   if isinstance(w2c, str):
+       w2c = load_tensor_from_file(w2c, map_location="cuda")
+
+   position_imgs = position_imgs.to("cuda")
+   normal_imgs = normal_imgs.to("cuda")
+   mask_imgs = mask_imgs.to("cuda")
+   w2c = w2c.to("cuda")
 
    progress(0, desc="Handling geometry information...")
    silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
@@ -291,6 +315,7 @@ def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_pr
        return condition, "SDXL condition generated successfully."
    elif model == "FLUX":
        # FLUX only supports depth control, not normal
+       raise NotImplementedError("FLUX model not supported in HF space, please delete it and use it locally")
        condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
        return condition, "FLUX condition generated successfully (depth-only control)."
    else:
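The new `CPU_OFFLOAD` flag switches both pipelines between full-GPU residency and diffusers' model CPU offloading. A generic, hedged sketch of the same toggle follows; the checkpoint id is only an example, not a statement about what this module loads.

```python
import torch
from diffusers import DiffusionPipeline

CPU_OFFLOAD = False  # set True on memory-constrained hosts

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # example checkpoint
    torch_dtype=torch.float16,
)
if CPU_OFFLOAD:
    # submodules are moved to the GPU on demand and back to CPU afterwards
    pipe.enable_model_cpu_offload()
else:
    pipe.to("cuda")
```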
utils/mesh_utils.py
CHANGED
@@ -1,16 +1,16 @@
-import os
 import tempfile
 
+import gradio as gr
 import numpy as np
+import spaces
 import torch
 import trimesh
 import xatlas
 from PIL import Image
 
-import gradio as gr
-
 from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
                            render_geo_views_tensor, render_views, setup_lights)
+from utils.file_utils import save_tensor_to_file
 
 
 class Mesh:
@@ -19,7 +19,7 @@ class Mesh:
        Initialize the Mesh object with a mesh file path.
        :param mesh_path: Path to the mesh file (e.g., .obj or .glb).
        """
-       self.
+       self._device = device
        if mesh_path is not None:
            # Initialize _parts dictionary to store all parts
            self._parts = {}
@@ -91,11 +91,16 @@ class Mesh:
            progress(0.4, f"The model has SINGLE UV parameterization, no need to reparameterize.")
            self._vmapping = None  # No vmapping needed when not reparameterizing
 
+    @property
+    def device(self):
+        return self._device
+
    def to(self, device):
        """
        Move the mesh data to the specified device.
        :param device: The target device (e.g., 'cuda' or 'cpu').
        """
+       self._device = device
        self._v_pos = self._v_pos.to(device)
        self._t_pos_idx = self._t_pos_idx.to(device)
        if self._v_tex is not None:
@@ -104,6 +109,7 @@ class Mesh:
        if hasattr(self, '_vmapping') and self._vmapping is not None:
            self._vmapping = self._vmapping.to(device)
        self._v_normal = self._v_normal.to(device)
+       return self
 
    @property
    def has_multi_parts(self):
@@ -404,7 +410,12 @@ class Mesh:
                texture_map = Image.fromarray(texture_map)
            assert type(texture_map) is Image.Image, "texture_map should be a PIL.Image"
            texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
-           material = trimesh.visual.
+           material = trimesh.visual.material.PBRMaterial(
+               baseColorTexture=texture_map,
+               baseColorFactor=[255, 255, 255, 255],  # set to white to avoid tinting the texture
+               metallicFactor=0.0,
+               roughnessFactor=1.0
+           )
        else:
            default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
            material = trimesh.visual.texture.SimpleMaterial(image=default_texture)
@@ -445,6 +456,7 @@ class Mesh:
        return save_path
 
    @classmethod
+    @spaces.GPU(duration=30)
    def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
        """
        Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
@@ -486,7 +498,15 @@ class Mesh:
        position_map, normal_map = render_geo_map(mesh)
 
        progress(1, f"Mesh processing completed.")
+       position_map_path = save_tensor_to_file(position_map, prefix="position_map")
+       normal_map_path = save_tensor_to_file(normal_map, prefix="normal_map")
+       position_images_path = save_tensor_to_file(position_images, prefix="position_images")
+       normal_images_path = save_tensor_to_file(normal_images, prefix="normal_images")
+       mask_images_path = save_tensor_to_file(mask_images.squeeze(-1), prefix="mask_images")
+       w2c_path = save_tensor_to_file(w2c, prefix="w2c")
+       mvp_matrix_path = save_tensor_to_file(mvp_matrix, prefix="mvp_matrix")
+       # Return mesh instance as is
+       return position_map_path, normal_map_path, position_images_path, normal_images_path, mask_images_path, w2c_path, mesh.to("cpu"), mvp_matrix_path, "Mesh processing completed."
 
 
 if __name__ == '__main__':
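The export change swaps `SimpleMaterial` for a PBR material so the baked texture is not tinted by a base color. Below is a standalone sketch of that trimesh usage on a toy quad; the geometry, UV layout, and output path are illustrative.

```python
import numpy as np
import trimesh
from PIL import Image

# Toy quad with a trivial UV layout
vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float64)
faces = np.array([[0, 1, 2], [0, 2, 3]])
uv = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float64)

texture = Image.new("RGB", (256, 256), (180, 120, 60))
material = trimesh.visual.material.PBRMaterial(
    baseColorTexture=texture,
    baseColorFactor=[255, 255, 255, 255],  # white so the texture is shown untinted
    metallicFactor=0.0,
    roughnessFactor=1.0,
)
mesh = trimesh.Trimesh(
    vertices=vertices,
    faces=faces,
    visual=trimesh.visual.TextureVisuals(uv=uv, material=material),
)
mesh.export("textured_quad.glb")
```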
utils/rasterize.py
CHANGED
@@ -1,17 +1,25 @@
+# This file uses nvdiffrast library, which is licensed under the NVIDIA Source Code License (1-Way Commercial).
+# nvdiffrast is available for non-commercial use (research or evaluation purposes only).
+# For commercial use, please contact NVIDIA for licensing: https://www.nvidia.com/en-us/research/inquiries/
+#
+# nvdiffrast copyright: Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
+# Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt
+
+from typing import Tuple, Union
+
 import nvdiffrast.torch as dr
 import torch
-
-from torch import Tensor
 from jaxtyping import Float, Integer
-from
+from torch import Tensor
+
 
 class NVDiffRasterizerContext:
-    def __init__(self, context_type: str, device
+    def __init__(self, context_type: str, device) -> None:
        self.device = device
        self.ctx = self.initialize_context(context_type, device)
 
    def initialize_context(
-        self, context_type: str, device
+        self, context_type: str, device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
utils/render_utils.py
CHANGED
@@ -3,6 +3,7 @@ from functools import cache
 from typing import Dict, Union
 
 import numpy as np
+import spaces
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -15,8 +16,22 @@ from .rasterize import (NVDiffRasterizerContext,
                         rasterize_position_and_normal_maps,
                         render_geo_from_mesh,
                         render_rgb_from_texture_mesh_with_mask)
+from utils.file_utils import load_tensor_from_file
 
-
+# Global variable to store the singleton context
+_CTX_INSTANCE = None
+
+@spaces.GPU
+def get_rasterizer_context():
+    """
+    Get the NVDiffRasterizer context using singleton pattern.
+    This ensures only one context is created and reused across the application.
+    """
+    global _CTX_INSTANCE
+    if _CTX_INSTANCE is None:
+        # Use string 'cuda' instead of torch.device to avoid early CUDA initialization
+        _CTX_INSTANCE = NVDiffRasterizerContext('cuda', 'cuda')
+    return _CTX_INSTANCE
 
 def setup_lights():
    """
@@ -24,6 +39,7 @@ def setup_lights():
    """
    raise NotImplementedError("setup_lights function is not implemented yet.")
 
+@spaces.GPU
 def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
    """
    Render the RGB color images of the mesh. The background will be transparent.
@@ -34,11 +50,22 @@ def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) ->
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A concatenated PIL Image.
    """
+   # If texture or mvp_matrix is a file path, load the tensor from file
+   if isinstance(texture, str):
+       texture = load_tensor_from_file(texture, map_location="cuda")
+   if isinstance(mvp_matrix, str):
+       mvp_matrix = load_tensor_from_file(mvp_matrix, map_location="cuda")
+   mesh = mesh.to("cuda")
+   texture = texture.to("cuda")
+   mvp_matrix = mvp_matrix.to("cuda")
+
+   print("Trying to render views...")
+   ctx = get_rasterizer_context()
    if texture.shape[-1] != 3:
        texture = texture.permute(1, 2, 0)
    image_height, image_width = img_size
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
+       ctx, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device=texture.device))
 
    if mvp_matrix.shape[0] == 0:
        return None
@@ -65,20 +92,24 @@ def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) ->
 
    return concatenated_image
 
+@spaces.GPU
 def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor]:
    """
    render the geometry information including position and normal from views that mvp matrix implies.
    """
+   ctx = get_rasterizer_context()
    image_height, image_width = img_size
-   position_images, normal_images, mask_images = render_geo_from_mesh(
+   position_images, normal_images, mask_images = render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width)
    return position_images, normal_images, mask_images
 
+@spaces.GPU
 def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the geometry information including position and normal from UV parameterization.
    """
+   ctx = get_rasterizer_context()
    map_height, map_width = map_size
-   position_images, normal_images, mask = rasterize_position_and_normal_maps(
+   position_images, normal_images, mask = rasterize_position_and_normal_maps(ctx, mesh, map_height, map_width)
    # out_imgs = []
    # if mask.ndim == 4:
    #     mask = mask[0]
@@ -318,6 +349,7 @@ def _get_depth_noraml_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda
 
    return depth_map, normal_map, mask
 
+@spaces.GPU
 def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image]:
    """
    Get the silhouette image based on geometry image.
CHANGED
@@ -11,13 +11,15 @@ from diffusers.models import AutoencoderKLWan
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from einops import rearrange
 from jaxtyping import Float
-from peft import LoraConfig
 from PIL import Image
 from torch import Tensor
 
 from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
 from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
 
+from . import tensor_to_pil
+from utils.file_utils import save_tensor_to_file, load_tensor_from_file
+
 TEX_PIPE = None
 VAE = None
 LATENTS_MEAN, LATENTS_STD = None, None
@@ -26,14 +28,8 @@ TEX_PIPE_LOCK = threading.Lock()
 @dataclass
 class Config:
     video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
-    min_noise_level_index: int = 15 #
-
-    use_causal_mask: bool = False
-    addtional_qk_geometry: bool = False
-    use_normal: bool = True
-    use_position: bool = True
-    randomly_init: bool = True # we load the weights from a corresponding ckpt
+    seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
+    min_noise_level_index: int = 15 # refer to paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)
 
     num_views: int = 4
     uv_num_views: int = 1
@@ -47,45 +43,12 @@ class Config:
     eval_num_inference_steps: int = 30
     eval_seed: int = 42
 
-    lora_rank: int = 128
-    lora_alpha: int = 64
-
 cfg = Config()
 
-def load_model_weights(model_path, map_location="cpu"):
-    """
-    Load model weights from either a URL or local file path.
-
-    Args:
-        model_path (str): Path to model weights, can be URL or local file path
-        map_location (str): Device to map the model to
-
-    Returns:
-        Dict: Loaded state dictionary
-    """
-    # Check if the path is a URL
-    parsed_url = urlparse(model_path)
-    if parsed_url.scheme in ('http', 'https'):
-        # Load from URL using torch.hub
-        try:
-            state_dict = torch.hub.load_state_dict_from_url(
-                model_path,
-                map_location=map_location,
-                progress=True
-            )
-            return state_dict
-        except Exception as e:
-            gr.Warning(f"Failed to load from URL: {e}")
-            raise e
-    else:
-        # Load from local file path
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"Local model file not found: {model_path}")
-        return torch.load(model_path, map_location=map_location)
-
-def lazy_get_seqtex_pipe():
+def get_seqtex_pipe():
     """
     Lazy load the SeqTex pipeline for texture generation.
+    Must be called within @spaces.GPU context.
     """
     global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
     if TEX_PIPE is not None:
@@ -95,42 +58,31 @@ def lazy_get_seqtex_pipe():
         if TEX_PIPE is not None:
             return TEX_PIPE
 
-        #
-
-
-
-        transformer = WanT2TexTransformer3DModel(
-            TEX_PIPE.transformer,
-            use_causal_mask=cfg.use_causal_mask,
-            addtional_qk_geo=cfg.addtional_qk_geometry,
-            use_normal=cfg.use_normal,
-            use_position=cfg.use_position,
-            randomly_init=cfg.randomly_init,
+        # Load transformer with auto-configured LoRA adapter first
+        transformer = WanT2TexTransformer3DModel.from_pretrained(
+            cfg.seqtex_transformer_path,
+            token=os.environ["SEQTEX_SPACE_TOKEN"]
         )
-
-
-
-
-
-
-
-        )
+
+        assert os.environ["SEQTEX_SPACE_TOKEN"] != "", "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to VAST-AI/SeqTex-Transformer."
+        # Pipeline - pass the pre-loaded transformer to avoid re-loading
+        TEX_PIPE = WanT2TexPipeline.from_pretrained(
+            cfg.video_base_name,
+            transformer=transformer,
+            torch_dtype=torch.bfloat16
         )
-
-        state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
-        transformer.load_state_dict(state_dict, strict=True)
-        TEX_PIPE.transformer = transformer
+        del(transformer)
 
-        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
+        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
         TEX_PIPE.vae = VAE
 
-        # Some useful parameters
+        # Some useful parameters - delay CUDA initialization until GPU context
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
-        ).to(
+        ).to(torch.float32)
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
-        ).to(
+        ).to(torch.float32)
 
        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
@@ -141,10 +93,8 @@ def lazy_get_seqtex_pipe():
        setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
-        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
-
-        TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32) # use float32 for inference
-        return TEX_PIPE
+        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
+        return TEX_PIPE.to("cuda")
 
 @torch.amp.autocast('cuda', dtype=torch.float32)
 def encode_images(
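`get_seqtex_pipe` keeps a single pipeline instance per process and is guarded by the `TEX_PIPE_LOCK` seen in the hunk context. A sketch of that lazy, double-checked loading pattern in isolation (`build_pipeline` is a stand-in for the `from_pretrained` calls above, not a real function in this repo):

```python
import threading

TEX_PIPE = None
TEX_PIPE_LOCK = threading.Lock()

def build_pipeline():
    return object()  # stand-in for WanT2TexPipeline.from_pretrained(...)

def get_pipe():
    global TEX_PIPE
    if TEX_PIPE is not None:      # fast path: already loaded, no locking needed
        return TEX_PIPE
    with TEX_PIPE_LOCK:           # slow path: serialize the expensive load
        if TEX_PIPE is None:      # re-check, another thread may have loaded it
            TEX_PIPE = build_pipeline()
    return TEX_PIPE
```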
@@ -157,6 +107,11 @@ def encode_images(
     :param encode_as_first: Whether to encode all frames as the first frame.
     :return: Encoded latents with shape [B, C', F, H/8, W/8].
     """
+    global VAE, LATENTS_MEAN, LATENTS_STD
+    VAE = VAE.to("cuda").requires_grad_(False)
+    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
+    LATENTS_STD = LATENTS_STD.to("cuda")
+
     if images.min() < - 0.1:
         # images are in [-1, 1] range
         images = (images + 1.0) / 2.0 # Normalize to [0, 1] range
@@ -171,19 +126,6 @@ def encode_images(
 
     return latents
 
-# @torch.no_grad()
-# @torch.amp.autocast('cuda', dtype=torch.float32)
-# def decode_images(self, latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
-#     if decode_as_first:
-#         F = latents.shape[2]
-#         latents = latents.to(self.vae.dtype)
-#         latents = latents / self.latents_std + self.latents_mean
-#         latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
-#         images = self.vae.decode(latents, return_dict=False)[0]
-#         images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
-#     else:
-#         raise NotImplementedError("Currently only support decode as first frame.")
-#     return images
 @torch.amp.autocast('cuda', dtype=torch.float32)
 def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
     """
@@ -192,6 +134,11 @@ def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
     :param decode_as_first: Whether to decode all frames as the first frame.
     :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
     """
+    global VAE, LATENTS_MEAN, LATENTS_STD
+    VAE = VAE.to("cuda").requires_grad_(False)
+    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
+    LATENTS_STD = LATENTS_STD.to("cuda")
+
     if decode_as_first:
         F = latents.shape[2]
         latents = latents.to(VAE.dtype)
@@ -207,6 +154,7 @@ def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
     """
     Convert a PIL Image to a tensor. If Image is RGBA, mask it with black background using a-channel mask.
     :param image: PIL Image to convert. [0, 255]
+    :param device: Target device for the tensor.
     :return: Tensor representation of the image. [0.0, 1.0], still [H, W, C]
     """
     # Convert to RGBA to ensure alpha channel exists
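`encode_images` and `decode_images` both normalize latents with the Wan VAE's per-channel statistics; note that `LATENTS_STD` is stored as the reciprocal of the configured std, so encoding multiplies by it and decoding divides. A hedged, self-contained sketch of that round trip with random stand-in values (shapes follow the `[B, C, F, H, W]` convention used above):

```python
import torch

z_dim = 16
latents_mean = torch.randn(1, z_dim, 1, 1, 1)                  # stand-in for VAE.config.latents_mean
latents_std_inv = 1.0 / (torch.rand(1, z_dim, 1, 1, 1) + 0.5)  # stored as a reciprocal, like LATENTS_STD above

raw_latents = torch.randn(2, z_dim, 4, 32, 32)                 # [B, C, F, H, W]

# Encoding direction: normalize raw VAE latents for the diffusion transformer.
normalized = (raw_latents - latents_mean) * latents_std_inv

# Decoding direction: undo the normalization before VAE.decode, matching
# `latents / self.latents_std + self.latents_mean` from the removed comment block.
restored = normalized / latents_std_inv + latents_mean
assert torch.allclose(restored, raw_latents, atol=1e-5)
```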
@@ -217,25 +165,33 @@ def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
     # Blend with black background using alpha mask
     rgb = rgb * alpha
     rgb = rgb.astype(np.float32) / 255.0 # Normalize to [0, 1]
-    tensor = torch.from_numpy(rgb)
+    tensor = torch.from_numpy(rgb)
+    if device != "cpu":
+        tensor = tensor.to(device)
     return tensor
 
-@spaces.GPU(duration=
-@torch.cuda.amp.autocast(dtype=torch.float32)
-@torch.inference_mode
+@spaces.GPU(duration=90)
 @torch.no_grad
-def generate_texture(position_map, normal_map, position_images, normal_images, condition_image,
+@torch.inference_mode
+def generate_texture(position_map_path, normal_map_path, position_images_path, normal_images_path, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
     """
     Use SeqTex to generate texture for the mesh based on the image condition.
-    :param
-    :param
+    :param position_images_path: File path to position images tensor
+    :param normal_images_path: File path to normal images tensor
     :param condition_image: Image condition generated from the selected view.
     :param text_prompt: Text prompt for texture generation.
     :param selected_view: The view selected for generating the image condition.
-    :return:
+    :return: File paths of generated texture map and multi-view frames, and PIL images
     """
+    position_map = load_tensor_from_file(position_map_path, map_location=device)
+    normal_map = load_tensor_from_file(normal_map_path, map_location=device)
+    position_images = load_tensor_from_file(position_images_path, map_location=device)
+    normal_images = load_tensor_from_file(normal_images_path, map_location=device)
+
     progress(0, desc="Loading SeqTex pipeline...")
-    tex_pipe =
+    tex_pipe = get_seqtex_pipe()
+    # assert tex_pipe is in gpu
+    assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
     progress(0.2, desc="SeqTex pipeline loaded successfully.")
     view_id_map = {
         "First View": 0,
@@ -306,4 +262,5 @@ def generate_texture(position_map, normal_map, position_images, normal_images, c
     uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
 
     progress(1, desc="Texture generated successfully.")
-
+    uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
+    return uv_map_pred_path, tensor_to_pil(uv_map_pred, normalize=False), tensor_to_pil(mv_out, normalize=False), "Step 3: Texture generated successfully."
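`generate_texture` now receives file paths rather than live tensors because each `@spaces.GPU` call runs in its own short-lived GPU context, and CUDA tensors cannot simply be handed across that boundary from the CPU-side Gradio callbacks. A sketch of the save/load round trip this relies on; the two helpers here are simplified stand-ins for `save_tensor_to_file` / `load_tensor_from_file` in `utils/file_utils.py`, whose real implementations may differ:

```python
import os
import tempfile
import torch

def save_tensor_to_file(tensor: torch.Tensor, prefix: str = "tensor") -> str:
    # Persist on CPU so the file can be reloaded inside a different GPU context.
    fd, path = tempfile.mkstemp(prefix=prefix + "_", suffix=".pt")
    os.close(fd)
    torch.save(tensor.detach().cpu(), path)
    return path

def load_tensor_from_file(path: str, map_location: str = "cpu") -> torch.Tensor:
    return torch.load(path, map_location=map_location)

# CPU stage: geometry rendering writes its outputs to disk and returns paths.
position_map_path = save_tensor_to_file(torch.rand(1024, 1024, 3), prefix="position_map")

# GPU stage: the @spaces.GPU-decorated function reloads them onto its device.
position_map = load_tensor_from_file(position_map_path, map_location="cpu")
print(position_map.shape)
```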
wan/pipeline_wan_t2tex_extra.py
CHANGED
@@ -1,19 +1,39 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
-from typing import Any, Callable, Dict, List, Optional,
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-
+import gradio as gr
 import regex as re
 import torch
-from diffusers.
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
-from diffusers.
+from diffusers.pipelines.wan.pipeline_wan import WanPipeline
 from diffusers.utils.torch_utils import randn_tensor
+from einops import rearrange
+from jaxtyping import Float
 from torch import Tensor
 from transformers import AutoTokenizer, UMT5EncoderModel
-
-import gradio as gr
+
 
 def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
+    # Ensure device is available before using it
+    if isinstance(device, str) and device.startswith("cuda"):
+        if not torch.cuda.is_available():
+            device = "cpu"
+
     sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
     schedule_timesteps = scheduler.timesteps.to(device)
     timesteps = timesteps.to(device)
@@ -83,6 +103,8 @@ class WanT2TexPipeline(WanPipeline):
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
+        device: Optional[str] = "cuda",
+
         attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -191,7 +213,8 @@ class WanT2TexPipeline(WanPipeline):
         self._current_timestep = None
         self._interrupt = False
 
-        device =
+        device = torch.device(device) if isinstance(device, str) else device
+        self.to(device)
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
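`get_sigmas` converts a batch of scheduler timesteps into their flow-matching sigmas; the lines shown fetch the sigma table and the timestep schedule, and the usual next step (not visible in this hunk) is an index lookup of each timestep in that schedule. A small sketch of that lookup under that assumption, using a default `FlowMatchEulerDiscreteScheduler`:

```python
import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

timesteps = scheduler.timesteps[:4]          # a few timesteps from the schedule

# Find each timestep's position in the schedule and pick the matching sigma.
schedule_timesteps = scheduler.timesteps
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigmas = scheduler.sigmas[step_indices]
print(sigmas.shape)                          # torch.Size([4])
```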
wan/wan_t2tex_transformer_3d_extra.py
CHANGED
@@ -13,30 +13,23 @@
 # limitations under the License.
 
 import copy
-import math
-from typing import Any, Dict, Optional, Tuple, Union
 from functools import cache
+from typing import Any, Dict, Optional, Tuple, Union
 
-from einops import rearrange, repeat
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
 from diffusers.models import WanTransformer3DModel
 from diffusers.models.attention import FeedForward
 from diffusers.models.attention_processor import Attention
-from diffusers.models.
-from diffusers.models.embeddings import (PixArtAlphaTextProjection,
-                                         TimestepEmbedding, Timesteps,
-                                         get_1d_rotary_pos_embed)
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
 from diffusers.models.normalization import FP32LayerNorm
 from diffusers.models.transformers.transformer_wan import \
     WanTimeTextImageEmbedding
 from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
                              unscale_lora_layers)
+from einops import rearrange, repeat
+from peft import LoraConfig
 
 
 class WanT2TexAttnProcessor2_0:
@@ -228,43 +221,6 @@ class WanRotaryPosEmbed(nn.Module):
         uv_freqs = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(1, 1, uppf * upph * uppw, -1)
         return torch.cat([freqs, uv_freqs], dim=-2)
 
-    # def pseudo_code(freqs, mv_tokens_shape, uv_tokens_shape, dimmension):
-    #     """
-    #     Input:
-    #         freqs: [S, D/2], S is the number of tokens, D is the dimension of tokens, 2 indicates Cos and Sin in original RoPE.
-    #         mv_tokens_shape: (mv_num_frames, mv_height, mv_width)
-    #         uv_tokens_shape: (uv_num_frames, uv_height, uv_width)
-    #         dimension: the dimension of tokens
-    #     Output:
-    #     """
-    #     mpf, mph, mpw = mv_tokens_shape # mv_num_frames, mv_height, mv_width
-    #     upf, uph, upw = uv_tokens_shape # uv_num_frames, uv_height, uv_width
-
-    #     # 1. To evenly split the freqs into 3 parts
-    #     freqs = freqs.split_with_sizes(
-    #         [
-    #             dimmension // 2 - 2 * (dimmension // 6),
-    #             dimmension // 6,
-    #             dimmension // 6,
-    #         ],
-    #         dim=1,
-    #     )
-
-    #     # 2. In time dimension, the freqs for UV are subsequent to the freqs for MV
-    #     freqs_f = freqs[0][:mpf].view(mpf, 1, 1, -1).expand(mpf, mph, mpw, -1)
-    #     uv_freqs_f = freqs[0][mpf:mpf+upf].view(upf, 1, 1, -1).expand(upf, uph, upw, -1)
-
-    #     # 3. The freqs in height and width dimension are the same for mv and uv
-    #     freqs_h = freqs[1][:mph].view(1, mph, 1, -1).expand(mpf, mph, mpw, -1)
-    #     uv_freqs_h = freqs[1][:uph].view(1, uph, 1, -1).expand(upf, uph, upw, -1)
-    #     freqs_w = freqs[2][:mpw].view(1, 1, mpw, -1).expand(mpf, mph, mpw, -1)
-    #     uv_freqs_w = freqs[2][:upw].view(1, 1, upw, -1).expand(upf, uph, upw, -1)
-
-    #     # 4. rearrange three 1D RoPEs into 3D RoPE in channel dimension
-    #     mv_rope = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(mpf * mph * mpw, -1)
-    #     uv_rope = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(upf * uph * upw, -1)
-    #     return torch.cat([mv_rope, uv_rope], dim=-2)
-
 class WanT2TexTransformerBlock(nn.Module):
     def __init__(
         self,
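The removed `pseudo_code` comment documents how the 1D rotary table is carved into frame, height, and width parts, with the UV frame positions appended after the multi-view frames while the height and width tables are shared. A standalone shape-only sketch of that layout (all sizes below are illustrative, not the model's real token counts):

```python
import torch

head_dim = 128
max_seq_len = 64
freqs = torch.randn(max_seq_len, head_dim // 2)   # stand-in for the 1D RoPE table

# Split the channel budget into (frame, height, width) parts, as in the pseudo-code.
dim = head_dim
freqs_f, freqs_h, freqs_w = freqs.split_with_sizes(
    [dim // 2 - 2 * (dim // 6), dim // 6, dim // 6], dim=1
)

mpf, mph, mpw = 4, 4, 4          # multi-view tokens: frames, height, width
upf, uph, upw = 1, 8, 8          # UV tokens

# UV frames continue after the MV frames in time; height/width tables are shared.
mv_f = freqs_f[:mpf].reshape(mpf, 1, 1, -1).expand(mpf, mph, mpw, -1)
uv_f = freqs_f[mpf:mpf + upf].reshape(upf, 1, 1, -1).expand(upf, uph, upw, -1)
mv_h = freqs_h[:mph].reshape(1, mph, 1, -1).expand(mpf, mph, mpw, -1)
uv_h = freqs_h[:uph].reshape(1, uph, 1, -1).expand(upf, uph, upw, -1)
mv_w = freqs_w[:mpw].reshape(1, 1, mpw, -1).expand(mpf, mph, mpw, -1)
uv_w = freqs_w[:upw].reshape(1, 1, upw, -1).expand(upf, uph, upw, -1)

# Recombine the three 1D tables into a 3D RoPE along the channel dimension.
mv_rope = torch.cat([mv_f, mv_h, mv_w], dim=-1).reshape(mpf * mph * mpw, -1)
uv_rope = torch.cat([uv_f, uv_h, uv_w], dim=-1).reshape(upf * uph * upw, -1)
rope = torch.cat([mv_rope, uv_rope], dim=-2)
print(rope.shape)   # (mpf*mph*mpw + upf*uph*upw, head_dim // 2)
```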
@@ -400,71 +356,104 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
     """
     3D Transformer model for T2Tex.
     """
-    def __init__(self,
-
-
-
-
-
-
-
-
+    def __init__(self,
+                 patch_size: Tuple[int] = (1, 2, 2),
+                 num_attention_heads: int = 40,
+                 attention_head_dim: int = 128,
+                 in_channels: int = 16,
+                 out_channels: int = 16,
+                 text_dim: int = 4096,
+                 freq_dim: int = 256,
+                 ffn_dim: int = 13824,
+                 num_layers: int = 40,
+                 cross_attn_norm: bool = True,
+                 qk_norm: Optional[str] = "rms_norm_across_heads",
+                 eps: float = 1e-6,
+                 image_dim: Optional[int] = None,
+                 added_kv_proj_dim: Optional[int] = None,
+                 rope_max_seq_len: int = 1024,
+                 **kwargs
+    ):
+        super(WanT2TexTransformer3DModel, self).__init__(
+            patch_size=patch_size,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            text_dim=text_dim,
+            freq_dim=freq_dim,
+            ffn_dim=ffn_dim,
+            num_layers=num_layers,
+            cross_attn_norm=cross_attn_norm,
+            qk_norm=qk_norm,
+            eps=eps,
+            image_dim=image_dim,
+            added_kv_proj_dim=added_kv_proj_dim,
+            rope_max_seq_len=rope_max_seq_len
+        )
         # 1. Patch & position embedding
-        self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len
-        self.
-
-        self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
-        # torch.nn.init.zeros_(self.norm_patch_embedding.weight.data)
-        # torch.nn.init.zeros_(self.norm_patch_embedding.bias.data)
-        if self.use_position:
-            self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
-            # torch.nn.init.zeros_(self.pos_patch_embedding.weight.data)
-            # torch.nn.init.zeros_(self.pos_patch_embedding.bias.data)
+        self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len)
+        self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
+        self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
 
         # 2. Condition embeddings
-        inner_dim =
+        inner_dim = num_attention_heads * attention_head_dim
         self.condition_embedder = WanTimeTaskTextImageEmbedding(
             original_model=self.condition_embedder,
             dim=inner_dim,
-            time_freq_dim=
+            time_freq_dim=freq_dim,
             time_proj_dim=inner_dim * 6,
-            text_embed_dim=
-            image_embed_dim=
-            randomly_init=randomly_init,
+            text_embed_dim=text_dim,
+            image_embed_dim=image_dim,
         )
 
         # 3. Transformer blocks
-        self.
-        self.num_attention_heads = original_model.config.num_attention_heads
+        self.num_attention_heads = num_attention_heads
 
         block = WanT2TexTransformerBlock(
             inner_dim,
-
-
-
-
-
-
+            ffn_dim,
+            num_attention_heads,
+            qk_norm,
+            cross_attn_norm,
+            eps,
+            added_kv_proj_dim,
         )
         self.blocks = None
         self.blocks = nn.ModuleList(
             [
                 copy.deepcopy(block)
-                for _ in range(
+                for _ in range(num_layers)
             ]
         )
         self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
-
-
-
-        for block in self.blocks:
-            block.attnuv.load_state_dict(block.attn1.state_dict())
-            block.scale_shift_table_uv.data.copy_(block.scale_shift_table.data)
-            block.normuv2.load_state_dict(block.norm2.state_dict())
-            block.ffnuv.load_state_dict(block.ffn.state_dict())
+
+        # 4. Auto-configure LoRA adapter for SeqTex
+        self.configure_lora_adapter()
 
-
-
+    def configure_lora_adapter(self, lora_rank: int = 128, lora_alpha: int = 64):
+        """
+        Configure LoRA adapter with custom settings or auto-configuration.
+
+        Args:
+            lora_rank (int, optional): LoRA rank parameter, default (128)
+            lora_alpha (int, optional): LoRA alpha parameter, default (64)
+        """
+        # Get parameters from args, environment variables, or defaults
+        target_modules = [
+            "attn1.to_q", "attn1.to_k", "attn1.to_v",
+            "attn1.to_out.0", "attn1.to_out.2",
+            "ffn.net.0.proj", "ffn.net.2"
+        ]
+
+        lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights=True,
+            target_modules=target_modules,
+        )
+
+        self.add_adapter(lora_config)
 
     @cache
     def get_attention_bias(self, mv_length, uv_length):
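`configure_lora_adapter` attaches LoRA weights to the self-attention and FFN projections through diffusers' PEFT integration (`add_adapter`). A minimal standalone sketch of the same `LoraConfig` idea applied to a toy module with `peft.inject_adapter_in_model`; the toy layer names only mirror the `target_modules` patterns above, and this is not the SeqTex model:

```python
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

class ToyBlock(nn.Module):
    # Layer names chosen to match the target_modules patterns used above.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x) + self.to_v(x)

lora_config = LoraConfig(
    r=8,                       # small rank for the toy example (SeqTex uses rank 128, alpha 64)
    lora_alpha=4,
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v"],
)

model = inject_adapter_in_model(lora_config, ToyBlock())
lora_params = [n for n, _ in model.named_parameters() if "lora_" in n]
print(len(lora_params))        # two extra (A and B) matrices per targeted Linear
```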
@@ -521,23 +510,10 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
         rotary_emb = self.rope(mv_hidden_states, uv_hidden_states)
 
         # Patchify
-
-
-
-
-            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
-        elif self.use_normal:
-            mv_rgb_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
-            uv_rgb_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
-            mv_geometry_embedding = self.norm_patch_embedding(mv_norm_hidden_states)
-            uv_geometry_embedding = self.norm_patch_embedding(uv_norm_hidden_states)
-        elif self.use_position:
-            mv_rgb_hidden_states, mv_pos_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
-            uv_rgb_hidden_states, uv_pos_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
-            mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states)
-            uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states)
-        else:
-            raise ValueError("use_normal and use_position are both False, please set at least one of them to True.")
+        mv_rgb_hidden_states, mv_pos_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 3, dim=1)
+        uv_rgb_hidden_states, uv_pos_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 3, dim=1)
+        mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states) + self.norm_patch_embedding(mv_norm_hidden_states)
+        uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
 
         mv_hidden_states = self.patch_embedding(mv_rgb_hidden_states)
         uv_hidden_states = self.patch_embedding(uv_rgb_hidden_states)
@@ -564,12 +540,6 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
 
-        # # Get attention bias
-        # if self.use_causal_mask:
-        #     # This may be gainless, because the patch embedding is not causal, which will leak information to MV
-        #     attn_bias = self.get_attention_bias(post_patch_num_frames * post_patch_height * post_patch_width,
-        #                                         post_uv_num_frames * post_uv_height * post_uv_width)
-        # else:
         attn_bias = None
 
         # 4. Transformer blocks
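In the new patchify path, the latent input carries RGB, position, and normal groups concatenated along one axis; `torch.chunk(..., 3, dim=1)` splits them back apart, each branch gets its own patch embedding, and the two geometry embeddings are summed. A shape-only sketch of that pattern (the `Conv3d` layers stand in for `patch_embedding` and its copies; sizes are illustrative):

```python
import torch
import torch.nn as nn

B, C, F, H, W = 1, 16, 4, 32, 32
rgb = torch.randn(B, C, F, H, W)
position = torch.randn(B, C, F, H, W)
normal = torch.randn(B, C, F, H, W)

# The caller concatenates the three groups along dim=1 before the transformer.
hidden_states = torch.cat([rgb, position, normal], dim=1)        # [B, 3*C, F, H, W]

# Inside the model they are split back apart ...
rgb_h, pos_h, norm_h = torch.chunk(hidden_states, 3, dim=1)

# ... and each branch gets its own patch embedding; geometry branches are summed.
patch = nn.Conv3d(C, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))
pos_patch = nn.Conv3d(C, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))
norm_patch = nn.Conv3d(C, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))

rgb_tokens = patch(rgb_h)
geometry_tokens = pos_patch(pos_h) + norm_patch(norm_h)
print(rgb_tokens.shape, geometry_tokens.shape)                   # both [B, 128, F, H/2, W/2]
```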