yuanze1024 committed on
Commit
6d4bcdf
·
1 Parent(s): 1d5bb62

init space 2

.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
README.md ADDED
@@ -0,0 +1,23 @@
1
+ ---
2
+ title: SeqTex
3
+ emoji: 🗺️
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.34.2
8
+ python_version: 3.12
9
+ models:
10
+ - Wan-AI/Wan2.1-T2V-1.3B-Diffusers
11
+ - VAST-AI/SeqTex-Transformer
12
+ - black-forest-labs/FLUX.1-dev
13
+ - Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0
14
+ - madebyollin/sdxl-vae-fp16-fix
15
+ - stabilityai/stable-diffusion-xl-base-1.0
16
+ - xinsir/controlnet-union-sdxl-1.0
17
+ app_file: app.py
18
+ pinned: false
19
+ license: mit
20
+ short_description: SeqTex generates mesh textures from text prompts
21
+ ---
22
+
23
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
THIRD_PARTY_LICENSES.md ADDED
@@ -0,0 +1,127 @@
1
+
2
+ # Third Party Licenses
3
+
4
+ This project uses third-party libraries that are subject to their own licenses.
5
+
6
+ ## nvdiffrast
7
+
8
+ **Project:** https://github.com/NVlabs/nvdiffrast
9
+ **License:** NVIDIA Source Code License (1-Way Commercial)
10
+ **Usage:** Non-commercial use (research or evaluation purposes only)
11
+
12
+ ```text
13
+ Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
14
+
15
+ This work is made available under the Nvidia Source Code License (1-Way Commercial).
16
+ The Work and any derivative works thereof only may be used or intended for use
17
+ non-commercially. "Non-commercially" means for research or evaluation purposes only
18
+ and not for any direct or indirect monetary gain.
19
+
20
+ Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt
21
+ ```
22
+
23
+ **Key Points:**
24
+ - ✅ Research/Academic Use: Permitted
25
+ - ❌ Commercial Use: Requires separate licensing from NVIDIA
26
+ - 📞 Commercial Licensing: https://www.nvidia.com/en-us/research/inquiries/
27
+
28
+ ## Wan Team Libraries
29
+
30
+ **Project:** Various components in `wan/` directory
31
+ **License:** Apache License 2.0
32
+ **Copyright:** Copyright (c) 2024 Wan Team
33
+
34
+ ```text
35
+ Licensed under the Apache License, Version 2.0 (the "License");
36
+ you may not use this file except in compliance with the License.
37
+ You may obtain a copy of the License at
38
+
39
+ http://www.apache.org/licenses/LICENSE-2.0
40
+
41
+ Unless required by applicable law or agreed to in writing, software
42
+ distributed under the License is distributed on an "AS IS" BASIS,
43
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44
+ See the License for the specific language governing permissions and
45
+ limitations under the License.
46
+ ```
47
+
48
+ ## Hugging Face Diffusers
49
+
50
+ **Project:** https://github.com/huggingface/diffusers
51
+ **License:** Apache License 2.0
52
+ **Copyright:** Copyright 2024 The HuggingFace Team
53
+
54
+ ```text
55
+ Licensed under the Apache License, Version 2.0 (the "License");
56
+ you may not use this file except in compliance with the License.
57
+ You may obtain a copy of the License at
58
+
59
+ http://www.apache.org/licenses/LICENSE-2.0
60
+
61
+ Unless required by applicable law or agreed to in writing, software
62
+ distributed under the License is distributed on an "AS IS" BASIS,
63
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64
+ See the License for the specific language governing permissions and
65
+ limitations under the License.
66
+ ```
67
+
68
+ ## Hugging Face Transformers
69
+
70
+ **Project:** https://github.com/huggingface/transformers
71
+ **License:** Apache License 2.0
72
+ **Copyright:** Copyright 2024 The HuggingFace Team
73
+
74
+ ```text
75
+ Licensed under the Apache License, Version 2.0 (the "License");
76
+ you may not use this file except in compliance with the License.
77
+ You may obtain a copy of the License at
78
+
79
+ http://www.apache.org/licenses/LICENSE-2.0
80
+
81
+ Unless required by applicable law or agreed to in writing, software
82
+ distributed under the License is distributed on an "AS IS" BASIS,
83
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84
+ See the License for the specific language governing permissions and
85
+ limitations under the License.
86
+ ```
87
+
88
+ ## PEFT (Parameter-Efficient Fine-Tuning)
89
+
90
+ **Project:** https://github.com/huggingface/peft
91
+ **License:** Apache License 2.0
92
+ **Copyright:** Copyright 2024 The HuggingFace Team
93
+
94
+ ```text
95
+ Licensed under the Apache License, Version 2.0 (the "License");
96
+ you may not use this file except in compliance with the License.
97
+ You may obtain a copy of the License at
98
+
99
+ http://www.apache.org/licenses/LICENSE-2.0
100
+
101
+ Unless required by applicable law or agreed to in writing, software
102
+ distributed under the License is distributed on an "AS IS" BASIS,
103
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104
+ See the License for the specific language governing permissions and
105
+ limitations under the License.
106
+ ```
107
+
108
+ ## Other Dependencies
109
+
110
+ The following dependencies are used under their respective licenses:
111
+
112
+ - **PyTorch & TorchVision**: BSD-3-Clause License
113
+ - **Einops**: MIT License
114
+ - **OmegaConf**: BSD-3-Clause License
115
+ - **Trimesh**: MIT License
116
+ - **Gradio**: Apache License 2.0
117
+ - **OpenCV**: BSD-3-Clause License
118
+ - **NumPy**: BSD-3-Clause License
119
+ - **ImageIO**: BSD-2-Clause License
120
+
121
+ ## Individual Contributors
122
+
123
+ **Qi Xin** - Condition Transformer implementation in `utils/controlnet_union.py`
124
+ - Copyright by Qi Xin (2024/07/06)
125
+ - Condition Transformer component for fusing single or multiple conditions with the input image
126
+
127
+ For the complete list of dependencies and their licenses, please refer to the respective package repositories.
app.py CHANGED
@@ -1,231 +1,192 @@
1
- import numpy as np
2
- import torch
3
- from einops import rearrange
4
- from PIL import Image
5
- from utils.image_generation import generate_image_condition
6
  from utils.mesh_utils import Mesh
7
  from utils.render_utils import render_views
8
- from utils.texture_generation import generate_texture
9
-
10
- import gradio as gr
11
- from gradio_litmodel3d import LitModel3D
12
 
13
  EXAMPLES = [
14
  ["examples/birdhouse.glb", True, False, False, False, 42, "First View", "SDXL", False, "A rustic birdhouse featuring a snow-covered roof, wood textures, and two decorative cardinal birds. It has a circular entryway and conveys a winter-themed aesthetic."],
15
- ["examples/mario.glb", False, False, False, True, 6666, "Third View", "FLUX", True, "Mario, a cartoon character wearing a red cap and blue overalls, with brown hair and a mustache, and white gloves, in a fighting pose. The clothes he wears are not in a reflection mode."],
 
16
  ]
 
 
17
 
18
- def tensor_to_pil(tensor, mask=None, normalize: bool = True):
19
- """
20
- Convert tensor to PIL Image.
21
- :param tensor: torch.Tensor, shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
22
- :param mask: torch.Tensor, shape same as tensor, effective when C=3
23
- :return: PIL.Image
24
- """
25
- # Move to cpu
26
- tensor = tensor.detach()
27
- if tensor.is_cuda:
28
- tensor = tensor.cpu()
29
- if mask is not None and mask.is_cuda:
30
- mask = mask.cpu()
31
-
32
- # Convert to float32
33
- tensor = tensor.float()
34
- if mask is not None:
35
- mask = mask.float()
36
-
37
- if normalize:
38
- tensor = (tensor + 1.0) / 2.0
39
- tensor = torch.clamp(tensor, 0.0, 1.0)
40
- if mask is not None:
41
- if mask.shape[-1] not in [1, 3]:
42
- mask = mask.unsqueeze(-1)
43
- tensor = torch.cat([tensor, mask], dim=-1)
44
-
45
- shape = tensor.shape
46
- # 4D: (Nv, H, W, C) or (Nv, C, H, W)
47
- if len(shape) == 4:
48
- Nv = shape[0]
49
- if shape[-1] in [3, 4]: # (Nv, H, W, C)
50
- tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
51
- else: # (Nv, C, H, W)
52
- tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
53
- # 3D: (H, W, C) or (C, H, W)
54
- elif len(shape) == 3:
55
- if shape[-1] in [3, 4]: # (H, W, C)
56
- tensor = rearrange(tensor, 'h w c -> h w c')
57
- else: # (C, H, W)
58
- tensor = rearrange(tensor, 'c h w -> h w c')
59
- else:
60
- raise ValueError(f"Unsupported tensor shape: {shape}")
61
-
62
- # Convert to numpy
63
- np_img = (tensor.numpy() * 255).round().astype(np.uint8)
64
-
65
- # Create PIL Image
66
- if np_img.shape[2] == 3:
67
- return Image.fromarray(np_img, mode="RGB")
68
- elif np_img.shape[2] == 4:
69
- return Image.fromarray(np_img, mode="RGBA")
70
- else:
71
- raise ValueError("Only support 3 or 4 channel images.")
72
 
73
- if __name__ == '__main__':
74
- with gr.Blocks() as demo:
75
- gr.Markdown("# 🎨 SeqTex: Generate Mesh Textures in Video Sequence")
76
-
77
- gr.Markdown("""
78
- ## 🚀 Welcome to SeqTex!
79
- **SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes using image prompts (here we use image generator to get them from textual prompts).
80
-
81
- Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
82
- """)
83
-
84
- gr.Markdown("---")
85
-
86
- gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
87
- gr.Markdown("""
88
- **📋 How to prepare your 3D mesh:**
89
- - Upload your 3D mesh in **.obj** or **.glb** format
90
- - **💡 Pro Tip**:
91
- - For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
92
- - Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
93
- - **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
94
- """)
95
- position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix = gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State(), gr.State()
96
-
97
- # fixed_texture_map = Image.open("image.webp").convert("RGB")
98
- # Step 1
99
- with gr.Row():
100
- with gr.Column():
101
- mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
102
- # uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
103
-
104
- gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
105
- y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
106
- y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
107
- z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
108
- upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
109
 
110
- with gr.Column():
111
- step1_button = gr.Button("🔄 Process Mesh & Generate Views", variant="primary")
112
- step1_progress = gr.Textbox(label="📊 Processing Status", interactive=False)
113
- model_input = gr.Model3D(label="📐 Processed 3D Model", height=500)
114
-
115
- with gr.Row(equal_height=True):
116
- rgb_views = gr.Image(label="📷 Generated Views (Front, Back, Left, Right)", type="pil", scale=3)
117
- position_map = gr.Image(label="🗺️ Position Map", type="pil", scale=1)
118
- normal_map = gr.Image(label="🧭 Normal Map", type="pil", scale=1)
119
-
120
- step1_button.click(
121
- Mesh.process,
122
- inputs=[mesh_upload, gr.State("xAtlas"), y2z, y2x, z2x, upside_down],
123
- outputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, mesh, mvp_matrix, step1_progress]
124
- ).then(
125
- tensor_to_pil,
126
- inputs=[normal_images_tensor, mask_images_tensor],
127
- outputs=[rgb_views]
128
- ).then(
129
- tensor_to_pil,
130
- inputs=[position_map_tensor],
131
- outputs=[position_map]
132
- ).then(
133
- tensor_to_pil,
134
- inputs=[normal_map_tensor],
135
- outputs=[normal_map]
136
- ).then(
137
- Mesh.export,
138
- inputs=[mesh],
139
- outputs=[model_input]
140
- )
141
-
142
- # Step 2
143
- gr.Markdown("---")
144
- gr.Markdown("## 👁️ Step 2: Select View & Generate Image Condition")
145
- gr.Markdown("""
146
- **📋 How to generate image condition:**
147
- - Your mesh will be rendered from **four viewpoints** (front, back, left, right)
148
- - Choose **one view** as your image condition
149
- - Enter a **descriptive text prompt** for the desired texture
150
- - Select your preferred AI model:
151
- - <span style="color:#27ae60; font-weight:bold;">🎯 SDXL</span>: Fast generation with depth + normal control, better details
152
- - <span style="color:#3498db; font-weight:bold;">⚡ FLUX</span>: High-quality generation with depth control (slower due to CPU offloading). Better work with **Edge Refinement**
153
- """)
154
- with gr.Row():
155
- with gr.Column():
156
- img_condition_seed = gr.Number(label="🎲 Random Seed", minimum=0, maximum=9999, step=1, value=42, info="Change for different results")
157
- selected_view = gr.Radio(["First View", "Second View", "Third View", "Fourth View"], label="📐 Camera View", value="First View", info="Choose which viewpoint to use as reference")
158
- with gr.Row():
159
- model_choice = gr.Radio(["SDXL", "FLUX"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing")
160
- edge_refinement = gr.Checkbox(label="✨ Edge Refinement", value=True, info="Smooth boundary artifacts (recommended for cleaner results)")
161
- text_prompt = gr.Textbox(label="💬 Texture Description", placeholder="Describe the desired texture appearance (e.g., 'rustic wooden surface with weathered paint')", lines=2)
162
- step2_button = gr.Button("🎯 Generate Image Condition", variant="primary")
163
- step2_progress = gr.Textbox(label="📊 Generation Status", interactive=False)
164
-
165
- with gr.Column():
166
- condition_image = gr.Image(label="🖼️ Generated Image Condition", type="pil") # , interactive=False
167
-
168
- step2_button.click(
169
- generate_image_condition,
170
- inputs=[position_images_tensor, normal_images_tensor, mask_images_tensor, w2cs, text_prompt, selected_view, img_condition_seed, model_choice, edge_refinement],
171
- outputs=[condition_image, step2_progress],
172
- concurrency_id="gpu_intensive"
173
- )
174
-
175
- # Step 3
176
- gr.Markdown("---")
177
- gr.Markdown("## 🎨 Step 3: Generate Final Texture")
178
- gr.Markdown("""
179
- **📋 How to generate final texture:**
180
- - The **SeqTex pipeline** will create a complete texture map for your model
181
- - View the results from multiple angles and download your textured 3D model (the viewport is a little bit dark)
182
- """)
183
- texture_map_tensor, mv_out_tensor = gr.State(), gr.State()
184
- with gr.Row():
185
- with gr.Column(scale=1):
186
- step3_button = gr.Button("🎨 Generate Final Texture", variant="primary")
187
- step3_progress = gr.Textbox(label="📊 Texture Generation Status", interactive=False)
188
- texture_map = gr.Image(label="🏆 Generated Texture Map", interactive=False)
189
- with gr.Column(scale=2):
190
- rendered_imgs = gr.Image(label="🖼️ Final Rendered Views")
191
- mv_branch_imgs = gr.Image(label="🖼️ SeqTex Direct Output")
192
- with gr.Column(scale=1.5):
193
- # model_display = gr.Model3D(label="🏆 Final Textured Model", height=500)
194
- model_display = LitModel3D(label="Model with Texture",
195
- exposure=30.0,
196
- height=500)
197
-
198
- step3_button.click(
199
- generate_texture,
200
- inputs=[position_map_tensor, normal_map_tensor, position_images_tensor, normal_images_tensor, condition_image, text_prompt, selected_view],
201
- outputs=[texture_map_tensor, mv_out_tensor, step3_progress],
202
- concurrency_id="gpu_intensive"
203
- ).then(
204
- tensor_to_pil,
205
- inputs=[texture_map_tensor, gr.State(None), gr.State(False)],
206
- outputs=[texture_map]
207
- ).then(
208
- tensor_to_pil,
209
- inputs=[mv_out_tensor, gr.State(None), gr.State(False)],
210
- outputs=[mv_branch_imgs]
211
- ).then(
212
- render_views,
213
- inputs=[mesh, texture_map_tensor, mvp_matrix],
214
- outputs=[rendered_imgs]
215
- ).then(
216
- Mesh.export,
217
- inputs=[mesh, gr.State(None), texture_map],
218
- outputs=[model_display]
219
- )
220
-
221
- # Add example inputs for user convenience
222
- gr.Markdown("---")
223
- gr.Markdown("## 🚀 Try Our Examples")
224
- gr.Markdown("**Quick Start**: Click on any example below to see SeqTex in action with pre-configured settings!")
225
- gr.Examples(
226
- examples=EXAMPLES,
227
- inputs=[mesh_upload, y2z, y2x, z2x, upside_down, img_condition_seed, selected_view, model_choice, edge_refinement, text_prompt],
228
- cache_examples=False
229
- )
230
-
231
- demo.launch(server_name="0.0.0.0", server_port=52424)
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from utils import tensor_to_pil
5
+ from utils.image_generation import generate_image_condition, get_flux_pipe, get_sdxl_pipe
6
  from utils.mesh_utils import Mesh
7
  from utils.render_utils import render_views
8
+ from utils.texture_generation import generate_texture, get_seqtex_pipe
 
 
 
9
 
10
  EXAMPLES = [
11
  ["examples/birdhouse.glb", True, False, False, False, 42, "First View", "SDXL", False, "A rustic birdhouse featuring a snow-covered roof, wood textures, and two decorative cardinal birds. It has a circular entryway and conveys a winter-themed aesthetic."],
12
+ ["examples/shoe.glb", True, False, False, False, 42, "Second View", "SDXL", False, "Modern sneaker exhibiting a mesh upper and wavy rubber outsole. Features include lacing for adjustability and padded components for comfort. Normal maps emphasize geometric detail."],
13
+ # ["examples/mario.glb", False, False, False, True, 6666, "Third View", "FLUX", True, "Mario, a cartoon character wearing a red cap and blue overalls, with brown hair and a mustache, and white gloves, in a fighting pose. The clothes he wears are not in a reflection mode."],
14
  ]
15
+ LOAD_FIRST = True
16
+
17
 
18
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
19
+ gr.Markdown("# 🎨 SeqTex: Generate Mesh Textures in Video Sequence")
20
 
21
+ gr.Markdown("""
22
+ ## 🚀 Welcome to SeqTex!
23
+ **SeqTex** is a cutting-edge AI system that generates high-quality textures for 3D meshes from image prompts (here, an image generator produces those prompts from your text description).
24
 
25
+ Choose to either **try our example models** below or **upload your own 3D mesh** to create stunning textures.
26
+ """)
27
+
28
+ gr.Markdown("---")
29
+
30
+ gr.Markdown("## 🔧 Step 1: Upload & Process 3D Mesh")
31
+ gr.Markdown("""
32
+ **📋 How to prepare your 3D mesh:**
33
+ - Upload your 3D mesh in **.obj** or **.glb** format
34
+ - **💡 Pro Tip**:
35
+ - For optimal results, ensure your mesh includes only one part with <span style="color:#e74c3c; font-weight:bold;">UV parameterization</span>
36
+ - Otherwise, we'll combine all parts and generate UV parameterization using *xAtlas* (may take longer for high-poly meshes; may also fail for certain meshes)
37
+ - **⚠️ Important**: We recommend adjusting your model using *Mesh Orientation Adjustments* to be **Z-UP oriented** for best results
38
+ """)
39
+ position_map_tensor_path = gr.State()
40
+ normal_map_tensor_path = gr.State()
41
+ position_images_tensor_path = gr.State()
42
+ normal_images_tensor_path = gr.State()
43
+ mask_images_tensor_path = gr.State()
44
+ w2c_tensor_path = gr.State()
45
+ mesh = gr.State()
46
+ mvp_matrix_tensor_path = gr.State()
47
+
48
+ # fixed_texture_map = Image.open("image.webp").convert("RGB")
49
+ # Step 1
50
+ with gr.Row():
51
+ with gr.Column():
52
+ mesh_upload = gr.File(label="📁 Upload 3D Mesh", file_types=[".obj", ".glb"])
53
+ # uv_tool = gr.Radio(["xAtlas", "UVAtlas"], label="UV parameterizer", value="xAtlas")
54
+
55
+ gr.Markdown("**🔄 Mesh Orientation Adjustments** (if needed):")
56
+ y2z = gr.Checkbox(label="Y → Z Transform", value=False, info="Rotate: Y becomes Z, -Z becomes Y")
57
+ y2x = gr.Checkbox(label="Y → X Transform", value=False, info="Rotate: Y becomes X, -X becomes Y")
58
+ z2x = gr.Checkbox(label="Z → X Transform", value=False, info="Rotate: Z becomes X, -X becomes Z")
59
+ upside_down = gr.Checkbox(label="🔃 Flip Vertically", value=False, info="Fix upside-down mesh orientation")
60
+ step1_button = gr.Button("🔄 Process Mesh & Generate Views", variant="primary")
61
+ step1_progress = gr.Textbox(label="📊 Processing Status", interactive=False)
62
+
63
+ with gr.Column():
64
+ model_input = gr.Model3D(label="📐 Processed 3D Model", height=500)
65
+
66
+ with gr.Row(equal_height=True):
67
+ rgb_views = gr.Image(label="📷 Generated Views", type="pil", scale=3)
68
+ position_map = gr.Image(label="🗺️ Position Map", type="pil", scale=1)
69
+ normal_map = gr.Image(label="🧭 Normal Map", type="pil", scale=1)
70
+
71
+ step1_button.click(
72
+ Mesh.process,
73
+ inputs=[mesh_upload, gr.State("xAtlas"), y2z, y2x, z2x, upside_down],
74
+ outputs=[position_map_tensor_path, normal_map_tensor_path, position_images_tensor_path, normal_images_tensor_path, mask_images_tensor_path, w2c_tensor_path, mesh, mvp_matrix_tensor_path, step1_progress]
75
+ ).success(
76
+ tensor_to_pil,
77
+ inputs=[normal_images_tensor_path, mask_images_tensor_path],
78
+ outputs=[rgb_views]
79
+ ).success(
80
+ tensor_to_pil,
81
+ inputs=[position_map_tensor_path],
82
+ outputs=[position_map]
83
+ ).success(
84
+ tensor_to_pil,
85
+ inputs=[normal_map_tensor_path],
86
+ outputs=[normal_map]
87
+ ).success(
88
+ Mesh.export,
89
+ inputs=[mesh, gr.State(None), gr.State(None)],
90
+ outputs=[model_input]
91
+ )
92
+
93
+ # Step 2
94
+ gr.Markdown("---")
95
+ gr.Markdown("## 👁️ Step 2: Select View & Generate Image Condition")
96
+ gr.Markdown("""
97
+ **📋 How to generate image condition:**
98
+ - Your mesh will be rendered from **four viewpoints** (front, back, left, right)
99
+ - Choose **one view** as your image condition
100
+ - Enter a **descriptive text prompt** for the desired texture
101
+ - Select your preferred AI model:
102
+ - <span style="color:#27ae60; font-weight:bold;">🎯 SDXL</span>: Fast generation with depth + normal control and better details (may suffer from incorrect highlights)
103
+ - <span style="color:#3498db; font-weight:bold;">⚡ FLUX</span>: ~~High-quality generation with depth control (slower due to CPU offloading). Works better with **Edge Refinement**~~ (Not supported due to the memory limit of the HF Space; you can try it locally)
104
+ """)
105
+ with gr.Row():
106
+ with gr.Column():
107
+ img_condition_seed = gr.Number(label="🎲 Random Seed", minimum=0, maximum=9999, step=1, value=42, info="Change for different results")
108
+ selected_view = gr.Radio(["First View", "Second View", "Third View", "Fourth View"], label="📐 Camera View", value="First View", info="Choose which viewpoint to use as reference")
109
+ with gr.Row():
110
+ # model_choice = gr.Radio(["SDXL", "FLUX"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing")
111
+ model_choice = gr.Radio(["SDXL"], label="🤖 AI Model", value="SDXL", info="SDXL: Fast, depth+normal control | FLUX: High-quality, slower processing (not supported due to the memory limit of the HF Space)")
112
+ edge_refinement = gr.Checkbox(label="✨ Edge Refinement", value=True, info="Smooth boundary artifacts (recommended for removing lighting highlights at the boundary)")
113
+ text_prompt = gr.Textbox(label="💬 Texture Description", placeholder="Describe the desired texture appearance (e.g., 'rustic wooden surface with weathered paint')", lines=2)
114
+ step2_button = gr.Button("🎯 Generate Image Condition", variant="primary")
115
+ step2_progress = gr.Textbox(label="📊 Generation Status", interactive=False)
116
+
117
+ with gr.Column():
118
+ condition_image = gr.Image(label="🖼️ Generated Image Condition", type="pil") # , interactive=False
119
+
120
+ step2_button.click(
121
+ generate_image_condition,
122
+ inputs=[position_images_tensor_path, normal_images_tensor_path, mask_images_tensor_path, w2c_tensor_path, text_prompt, selected_view, img_condition_seed, model_choice, edge_refinement],
123
+ outputs=[condition_image, step2_progress],
124
+ )
125
+
126
+ # Step 3
127
+ gr.Markdown("---")
128
+ gr.Markdown("## 🎨 Step 3: Generate Final Texture")
129
+ gr.Markdown("""
130
+ **📋 How to generate final texture:**
131
+ - The **SeqTex pipeline** will create a complete texture map for your model
132
+ - View the results from multiple angles and download your textured 3D model (the viewport is a little bit dark)
133
+ """)
134
+ texture_map_tensor_path = gr.State()
135
+ with gr.Row():
136
+ with gr.Column(scale=1):
137
+ step3_button = gr.Button("🎨 Generate Final Texture", variant="primary")
138
+ step3_progress = gr.Textbox(label="📊 Texture Generation Status", interactive=False)
139
+ texture_map = gr.Image(label="🏆 Generated Texture Map", interactive=False)
140
+ with gr.Column(scale=2):
141
+ rendered_imgs = gr.Image(label="🖼️ Final Rendered Views")
142
+ mv_branch_imgs = gr.Image(label="🖼️ SeqTex Direct Output")
143
+ with gr.Column(scale=1.5):
144
+ model_display = gr.Model3D(label="🏆 Final Textured Model", height=500)
145
+ # model_display = LitModel3D(label="Model with Texture",
146
+ # exposure=30.0,
147
+ # height=500)
148
+
149
+ step3_button.click(
150
+ generate_texture,
151
+ inputs=[position_map_tensor_path, normal_map_tensor_path, position_images_tensor_path, normal_images_tensor_path, condition_image, text_prompt, selected_view],
152
+ outputs=[texture_map_tensor_path, texture_map, mv_branch_imgs, step3_progress],
153
+ ).success(
154
+ render_views,
155
+ inputs=[mesh, texture_map_tensor_path, mvp_matrix_tensor_path],
156
+ outputs=[rendered_imgs]
157
+ ).success(
158
+ Mesh.export,
159
+ inputs=[mesh, gr.State(None), texture_map],
160
+ outputs=[model_display]
161
+ )
162
+
163
+ # Add example inputs for user convenience
164
+ gr.Markdown("---")
165
+ gr.Markdown("## 🚀 Try Our Examples")
166
+ gr.Markdown("**Quick Start**: Click on any example below to see SeqTex in action with pre-configured settings!")
167
+ gr.Examples(
168
+ examples=EXAMPLES,
169
+ inputs=[mesh_upload, y2z, y2x, z2x, upside_down, img_condition_seed, selected_view, model_choice, edge_refinement, text_prompt],
170
+ cache_examples=False
171
+ )
172
+
173
+ # Acknowledgments
174
+ gr.Markdown("---")
175
+ gr.Markdown("## 🙏 Acknowledgments")
176
+ gr.Markdown("""
177
+ **Special thanks to [Toshihiro Hayashi](mailto:[email protected])** for his valuable support and assistance in fixing bugs for this demo.
178
+ """)
179
+
180
+ if LOAD_FIRST:
181
+ import gc
182
+ get_seqtex_pipe()
183
+ print("SeqTex pipeline loaded successfully.")
184
+ get_sdxl_pipe()
185
+ print("SDXL pipeline loaded successfully.")
186
+ # get_flux_pipe()
187
+ # Note: FLUX pipeline is available in code but not loaded due to GPU memory constraints on HF Space
188
+ print("Note: FLUX and other models are available for local deployment.")
189
+ gc.collect()
190
+
191
+ assert os.environ.get("OPENCV_IO_ENABLE_OPENEXR") == "1", "OpenEXR support is required for this demo; set OPENCV_IO_ENABLE_OPENEXR=1."
192
+ demo.launch(server_name="0.0.0.0")
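The rewired callbacks above keep heavy tensors out of `gr.State`: Step 1 saves them to disk (via `save_tensor_to_file`) and only file paths flow between the steps, which are chained with `.success()` so later stages are skipped when an earlier one raises. Below is a minimal, self-contained sketch of that pattern with illustrative names; it is not the Space itself.

```python
import os
import tempfile
import uuid

import gradio as gr
import torch


def _save(t: torch.Tensor) -> str:
    # Persist the tensor and hand back only its path.
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.pt")
    torch.save(t, path)
    return path


def step1_process():
    # Stand-in for Mesh.process: pretend we rendered some views.
    return _save(torch.rand(4, 512, 512, 3)), "Step 1 done."


def step2_consume(tensor_path: str) -> str:
    # Later steps reload the tensor from the path stored in gr.State.
    views = torch.load(tensor_path, map_location="cpu", weights_only=True)
    return f"Loaded views with shape {tuple(views.shape)}"


with gr.Blocks() as sketch:
    views_path = gr.State()
    status = gr.Textbox(label="Status", interactive=False)
    result = gr.Textbox(label="Result", interactive=False)
    run = gr.Button("Run")
    # .success() only fires if the previous callback finished without raising.
    run.click(step1_process, outputs=[views_path, status]).success(
        step2_consume, inputs=[views_path], outputs=[result]
    )

if __name__ == "__main__":
    sketch.launch()
```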
examples/shoe.glb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3945c77a6f98eb18aae9c97f253e4c6b06daf83194c21f91ea4b955756bace7e
3
+ size 8842904
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ninja-build
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ torchvision==0.21.0
2
+
3
+ einops
4
+ omegaconf==2.3.0
5
+ jaxtyping
6
+ typeguard
7
+ imageio
8
+ trimesh==4.6.4
9
+ peft==0.14.0
10
+ diffusers==0.33.1
11
+ bitsandbytes
12
+ transformers==4.52.4
13
+ ftfy
14
+ accelerate
15
+ sentencepiece
16
+ ipdb
17
+ clean-fid
18
+ apex==0.9.10.dev0
19
+ xatlas
20
+ gradio_litmodel3d
21
+ spaces
22
+ numpy
23
+ opencv-python
24
+ https://huggingface.co/spaces/VAST-AI/MV-Adapter-Img2Texture/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
utils/__init__.py CHANGED
@@ -0,0 +1,67 @@
1
+ import numpy as np
2
+ import torch
3
+ from einops import rearrange
4
+ from PIL import Image
5
+
6
+
7
+ def tensor_to_pil(tensor, mask=None, normalize: bool = True):
8
+ """
9
+ Convert tensor to PIL Image.
10
+ :param tensor: torch.Tensor or str (file path to tensor), shape can be (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W)
11
+ :param mask: torch.Tensor or str (file path to tensor), shape same as tensor, effective when C=3
12
+ :return: PIL.Image
13
+ """
14
+ # If input is a file path, load the tensor
15
+ if isinstance(tensor, str):
16
+ from utils.file_utils import load_tensor_from_file
17
+ tensor = load_tensor_from_file(tensor, map_location="cpu")
18
+ if mask is not None and isinstance(mask, str):
19
+ from utils.file_utils import load_tensor_from_file
20
+ mask = load_tensor_from_file(mask, map_location="cpu")
21
+ # Move to cpu
22
+ tensor = tensor.detach()
23
+ if tensor.is_cuda:
24
+ tensor = tensor.cpu()
25
+ if mask is not None and mask.is_cuda:
26
+ mask = mask.cpu()
27
+
28
+ # Convert to float32
29
+ tensor = tensor.float()
30
+ if mask is not None:
31
+ mask = mask.float()
32
+
33
+ if normalize:
34
+ tensor = (tensor + 1.0) / 2.0
35
+ tensor = torch.clamp(tensor, 0.0, 1.0)
36
+ if mask is not None:
37
+ if mask.shape[-1] not in [1, 3]:
38
+ mask = mask.unsqueeze(-1)
39
+ tensor = torch.cat([tensor, mask], dim=-1)
40
+
41
+ shape = tensor.shape
42
+ # 4D: (Nv, H, W, C) or (Nv, C, H, W)
43
+ if len(shape) == 4:
44
+ Nv = shape[0]
45
+ if shape[-1] in [3, 4]: # (Nv, H, W, C)
46
+ tensor = rearrange(tensor, 'nv h w c -> h (nv w) c')
47
+ else: # (Nv, C, H, W)
48
+ tensor = rearrange(tensor, 'nv c h w -> h (nv w) c')
49
+ # 3D: (H, W, C) or (C, H, W)
50
+ elif len(shape) == 3:
51
+ if shape[-1] in [3, 4]: # (H, W, C)
52
+ tensor = rearrange(tensor, 'h w c -> h w c')
53
+ else: # (C, H, W)
54
+ tensor = rearrange(tensor, 'c h w -> h w c')
55
+ else:
56
+ raise ValueError(f"Unsupported tensor shape: {shape}")
57
+
58
+ # Convert to numpy
59
+ np_img = (tensor.numpy() * 255).round().astype(np.uint8)
60
+
61
+ # Create PIL Image
62
+ if np_img.shape[2] == 3:
63
+ return Image.fromarray(np_img, mode="RGB")
64
+ elif np_img.shape[2] == 4:
65
+ return Image.fromarray(np_img, mode="RGBA")
66
+ else:
67
+ raise ValueError("Only support 3 or 4 channel images.")
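A hypothetical call to `tensor_to_pil` exercising the `(Nv, H, W, C)` path, where the four views are tiled horizontally and the mask becomes an alpha channel:

```python
import torch

from utils import tensor_to_pil

# Four views in (Nv, H, W, C) layout with values in [-1, 1]; normalize=True maps them to [0, 1].
views = torch.rand(4, 512, 512, 3) * 2.0 - 1.0
mask = torch.ones(4, 512, 512)            # broadcast to an alpha channel -> RGBA output
strip = tensor_to_pil(views, mask=mask)   # views are tiled horizontally: H x (Nv * W)
strip.save("views.png")
```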
utils/file_utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import os
2
+ import uuid
3
+ import torch
4
+ from gradio.utils import get_upload_folder
5
+
6
+ def save_tensor_to_file(tensor, prefix="tensor"):
7
+ upload_dir = get_upload_folder()
8
+ os.makedirs(upload_dir, exist_ok=True)
9
+ path = os.path.join(upload_dir, f"{prefix}_{uuid.uuid4().hex}.pt")
10
+ torch.save(tensor, path)
11
+ return path
12
+
13
+ def load_tensor_from_file(path, map_location=None):
14
+ # Use weights_only=True for security and to suppress FutureWarning (only tensors are loaded in this app)
15
+ return torch.load(path, map_location=map_location, weights_only=True)
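A small round-trip example for these helpers (hypothetical tensor; the file lands in Gradio's upload folder):

```python
import torch

from utils.file_utils import load_tensor_from_file, save_tensor_to_file

position_map = torch.rand(1024, 1024, 3)                         # hypothetical tensor
path = save_tensor_to_file(position_map, prefix="position_map")  # .../position_map_<uuid>.pt
restored = load_tensor_from_file(path, map_location="cpu")
assert torch.equal(position_map, restored)
```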
utils/image_generation.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import threading
2
 
3
  import cv2
 
4
  import numpy as np
5
  import spaces
6
  import torch
@@ -12,12 +14,11 @@ from einops import rearrange
12
  from PIL import Image
13
  from torchvision.transforms import ToPILImage
14
 
15
- import gradio as gr
16
-
17
  from .controlnet_union import ControlNetModel_Union
18
  from .pipeline_controlnet_union_sd_xl import \
19
  StableDiffusionXLControlNetUnionPipeline
20
  from .render_utils import get_silhouette_image
 
21
 
22
  IMG_PIPE = None
23
  IMG_PIPE_LOCK = threading.Lock()
@@ -26,8 +27,9 @@ FLUX_PIPE = None
26
  FLUX_PIPE_LOCK = threading.Lock()
27
  FLUX_SUFFIX = None
28
  FLUX_NEGATIVE = None
 
29
 
30
- def lazy_get_flux_pipe():
31
  """
32
  Lazy load the FLUX pipeline with ControlNet for image generation.
33
  """
@@ -44,16 +46,21 @@ def lazy_get_flux_pipe():
44
  controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
45
 
46
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
 
47
  FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
48
  base_model,
49
  controlnet=controlnet,
50
- torch_dtype=torch.bfloat16
 
51
  )
52
  # Use model CPU offload for better GPU utilization during inference
53
- FLUX_PIPE.enable_model_cpu_offload()
 
 
 
54
  return FLUX_PIPE
55
 
56
- def lazy_get_sdxl_pipe():
57
  """
58
  Lazy load the SDXL pipeline with ControlNet for image generation.
59
  """
@@ -74,8 +81,11 @@ def lazy_get_sdxl_pipe():
74
  torch_dtype=torch.float16,
75
  scheduler=eulera_scheduler,
76
  )
77
- # Move pipeline to CUDA device
78
- IMG_PIPE = IMG_PIPE.to("cuda")
 
 
 
79
  return IMG_PIPE
80
 
81
 
@@ -94,11 +104,11 @@ def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, e
94
  :return: Generated image condition (e.g., PIL Image).
95
  """
96
  progress(0.1, desc="Loading SDXL pipeline...")
97
- pipeline = lazy_get_sdxl_pipe()
98
  progress(0.3, desc="SDXL pipeline loaded successfully.")
99
 
100
- positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lightning, uniform color, foreground"
101
- negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
102
 
103
  img_generation_resolution = 1024 # SDXL performs better at 1024x1024
104
  image = pipeline(prompt=[positive_prompt]*1,
@@ -154,7 +164,7 @@ def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refineme
154
  :return: Generated image condition (PIL Image).
155
  """
156
  progress(0.1, desc="Loading FLUX pipeline...")
157
- pipeline = lazy_get_flux_pipe()
158
  progress(0.3, desc="FLUX pipeline loaded successfully.")
159
 
160
  # Enhanced prompt for better results
@@ -262,7 +272,7 @@ def refine_image_edges(rgb_tensor, mask_tensor):
262
 
263
  return refined_rgb_tensor
264
 
265
- @spaces.GPU(duration=120)
266
  def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
267
  """
268
  Generate the image condition based on the selected view's silhouette and text prompt.
@@ -278,6 +288,20 @@ def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_pr
278
  :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
279
  :return: Generated condition image and status message.
280
  """
281
 
282
  progress(0, desc="Handling geometry information...")
283
  silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
@@ -291,6 +315,7 @@ def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_pr
291
  return condition, "SDXL condition generated successfully."
292
  elif model == "FLUX":
293
  # FLUX only supports depth control, not normal
 
294
  condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
295
  return condition, "FLUX condition generated successfully (depth-only control)."
296
  else:
 
1
+ import os
2
  import threading
3
 
4
  import cv2
5
+ import gradio as gr
6
  import numpy as np
7
  import spaces
8
  import torch
 
14
  from PIL import Image
15
  from torchvision.transforms import ToPILImage
16
 
 
 
17
  from .controlnet_union import ControlNetModel_Union
18
  from .pipeline_controlnet_union_sd_xl import \
19
  StableDiffusionXLControlNetUnionPipeline
20
  from .render_utils import get_silhouette_image
21
+ from utils.file_utils import load_tensor_from_file
22
 
23
  IMG_PIPE = None
24
  IMG_PIPE_LOCK = threading.Lock()
 
27
  FLUX_PIPE_LOCK = threading.Lock()
28
  FLUX_SUFFIX = None
29
  FLUX_NEGATIVE = None
30
+ CPU_OFFLOAD = False
31
 
32
+ def get_flux_pipe():
33
  """
34
  Lazy load the FLUX pipeline with ControlNet for image generation.
35
  """
 
46
  controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
47
 
48
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
49
+ assert os.environ.get("SEQTEX_SPACE_TOKEN", "") != "", "Please set the SEQTEX_SPACE_TOKEN environment variable to a Hugging Face token that has access to black-forest-labs/FLUX.1-dev."
50
  FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
51
  base_model,
52
  controlnet=controlnet,
53
+ torch_dtype=torch.bfloat16,
54
+ token=os.environ["SEQTEX_SPACE_TOKEN"]
55
  )
56
  # Use model CPU offload for better GPU utilization during inference
57
+ if CPU_OFFLOAD:
58
+ FLUX_PIPE.enable_model_cpu_offload()
59
+ else:
60
+ FLUX_PIPE.to("cuda")
61
  return FLUX_PIPE
62
 
63
+ def get_sdxl_pipe():
64
  """
65
  Lazy load the SDXL pipeline with ControlNet for image generation.
66
  """
 
81
  torch_dtype=torch.float16,
82
  scheduler=eulera_scheduler,
83
  )
84
+ # Use model CPU offload for better GPU utilization during inference
85
+ if CPU_OFFLOAD:
86
+ IMG_PIPE.enable_model_cpu_offload()
87
+ else:
88
+ IMG_PIPE.to("cuda")
89
  return IMG_PIPE
90
 
91
 
 
104
  :return: Generated image condition (e.g., PIL Image).
105
  """
106
  progress(0.1, desc="Loading SDXL pipeline...")
107
+ pipeline = get_sdxl_pipe()
108
  progress(0.3, desc="SDXL pipeline loaded successfully.")
109
 
110
+ positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft diffuse lighting, uniform lighting, flat lighting, even illumination, matte surface, low contrast, uniform color, foreground"
111
+ negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, harsh lighting, high contrast, bright highlights, specular reflections, shiny surface, glossy, reflective, strong shadows, dramatic lighting, spotlight, direct sunlight, glare, bloom, lens flare'
112
 
113
  img_generation_resolution = 1024 # SDXL performs better at 1024x1024
114
  image = pipeline(prompt=[positive_prompt]*1,
 
164
  :return: Generated image condition (PIL Image).
165
  """
166
  progress(0.1, desc="Loading FLUX pipeline...")
167
+ pipeline = get_flux_pipe()
168
  progress(0.3, desc="FLUX pipeline loaded successfully.")
169
 
170
  # Enhanced prompt for better results
 
272
 
273
  return refined_rgb_tensor
274
 
275
+ @spaces.GPU()
276
  def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
277
  """
278
  Generate the image condition based on the selected view's silhouette and text prompt.
 
288
  :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
289
  :return: Generated condition image and status message.
290
  """
291
+ # If any input is a file path, load the tensor from file
292
+ if isinstance(position_imgs, str):
293
+ position_imgs = load_tensor_from_file(position_imgs, map_location="cuda")
294
+ if isinstance(normal_imgs, str):
295
+ normal_imgs = load_tensor_from_file(normal_imgs, map_location="cuda")
296
+ if isinstance(mask_imgs, str):
297
+ mask_imgs = load_tensor_from_file(mask_imgs, map_location="cuda")
298
+ if isinstance(w2c, str):
299
+ w2c = load_tensor_from_file(w2c, map_location="cuda")
300
+
301
+ position_imgs = position_imgs.to("cuda")
302
+ normal_imgs = normal_imgs.to("cuda")
303
+ mask_imgs = mask_imgs.to("cuda")
304
+ w2c = w2c.to("cuda")
305
 
306
  progress(0, desc="Handling geometry information...")
307
  silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
 
315
  return condition, "SDXL condition generated successfully."
316
  elif model == "FLUX":
317
  # FLUX only supports depth control, not normal
318
+ raise NotImplementedError("The FLUX model is not supported in the HF Space; remove this guard to run it locally.")
319
  condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
320
  return condition, "FLUX condition generated successfully (depth-only control)."
321
  else:
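The `get_sdxl_pipe` / `get_flux_pipe` loaders above follow a lazy, lock-guarded singleton pattern: the first request pays the load cost once, and concurrent requests wait on the lock instead of building duplicate pipelines. A stripped-down sketch of the same pattern, using the SDXL base model id from this repo and an illustrative `get_pipe` name (the real loader builds the ControlNet-Union pipeline):

```python
import threading

import torch
from diffusers import StableDiffusionXLPipeline

_PIPE = None
_PIPE_LOCK = threading.Lock()
CPU_OFFLOAD = False  # mirrors the flag above; offload trades speed for memory


def get_pipe():
    """Load the pipeline once; concurrent callers wait instead of reloading."""
    global _PIPE
    if _PIPE is not None:      # fast path: already built
        return _PIPE
    with _PIPE_LOCK:           # slow path: only the first caller builds it
        if _PIPE is not None:
            return _PIPE
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        )
        if CPU_OFFLOAD:
            pipe.enable_model_cpu_offload()  # weights stay on CPU until used
        else:
            pipe = pipe.to("cuda")
        _PIPE = pipe
    return _PIPE
```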
utils/mesh_utils.py CHANGED
@@ -1,16 +1,16 @@
1
- import os
2
  import tempfile
3
 
 
4
  import numpy as np
 
5
  import torch
6
  import trimesh
7
  import xatlas
8
  from PIL import Image
9
 
10
- import gradio as gr
11
-
12
  from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
13
  render_geo_views_tensor, render_views, setup_lights)
 
14
 
15
 
16
  class Mesh:
@@ -19,7 +19,7 @@ class Mesh:
19
  Initialize the Mesh object with a mesh file path.
20
  :param mesh_path: Path to the mesh file (e.g., .obj or .glb).
21
  """
22
- self.device = device
23
  if mesh_path is not None:
24
  # Initialize _parts dictionary to store all parts
25
  self._parts = {}
@@ -91,11 +91,16 @@ class Mesh:
91
  progress(0.4, f"The model has SINGLE UV parameterization, no need to reparameterize.")
92
  self._vmapping = None # No vmapping needed when not reparameterizing
93
 
 
 
 
 
94
  def to(self, device):
95
  """
96
  Move the mesh data to the specified device.
97
  :param device: The target device (e.g., 'cuda' or 'cpu').
98
  """
 
99
  self._v_pos = self._v_pos.to(device)
100
  self._t_pos_idx = self._t_pos_idx.to(device)
101
  if self._v_tex is not None:
@@ -104,6 +109,7 @@ class Mesh:
104
  if hasattr(self, '_vmapping') and self._vmapping is not None:
105
  self._vmapping = self._vmapping.to(device)
106
  self._v_normal = self._v_normal.to(device)
 
107
 
108
  @property
109
  def has_multi_parts(self):
@@ -404,7 +410,12 @@ class Mesh:
404
  texture_map = Image.fromarray(texture_map)
405
  assert type(texture_map) is Image.Image, "texture_map should be a PIL.Image"
406
  texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
407
- material = trimesh.visual.texture.SimpleMaterial(image=texture_map)
 
 
 
 
 
408
  else:
409
  default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
410
  material = trimesh.visual.texture.SimpleMaterial(image=default_texture)
@@ -445,6 +456,7 @@ class Mesh:
445
  return save_path
446
 
447
  @classmethod
 
448
  def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
449
  """
450
  Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
@@ -486,7 +498,15 @@ class Mesh:
486
  position_map, normal_map = render_geo_map(mesh)
487
 
488
  progress(1, f"Mesh processing completed.")
489
- return position_map, normal_map, position_images, normal_images, mask_images.squeeze(-1), w2c, mesh, mvp_matrix, "Mesh processing completed."
 
 
 
 
 
 
 
 
490
 
491
 
492
  if __name__ == '__main__':
 
 
1
  import tempfile
2
 
3
+ import gradio as gr
4
  import numpy as np
5
+ import spaces
6
  import torch
7
  import trimesh
8
  import xatlas
9
  from PIL import Image
10
 
 
 
11
  from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
12
  render_geo_views_tensor, render_views, setup_lights)
13
+ from utils.file_utils import save_tensor_to_file
14
 
15
 
16
  class Mesh:
 
19
  Initialize the Mesh object with a mesh file path.
20
  :param mesh_path: Path to the mesh file (e.g., .obj or .glb).
21
  """
22
+ self._device = device
23
  if mesh_path is not None:
24
  # Initialize _parts dictionary to store all parts
25
  self._parts = {}
 
91
  progress(0.4, f"The model has SINGLE UV parameterization, no need to reparameterize.")
92
  self._vmapping = None # No vmapping needed when not reparameterizing
93
 
94
+ @property
95
+ def device(self):
96
+ return self._device
97
+
98
  def to(self, device):
99
  """
100
  Move the mesh data to the specified device.
101
  :param device: The target device (e.g., 'cuda' or 'cpu').
102
  """
103
+ self._device = device
104
  self._v_pos = self._v_pos.to(device)
105
  self._t_pos_idx = self._t_pos_idx.to(device)
106
  if self._v_tex is not None:
 
109
  if hasattr(self, '_vmapping') and self._vmapping is not None:
110
  self._vmapping = self._vmapping.to(device)
111
  self._v_normal = self._v_normal.to(device)
112
+ return self
113
 
114
  @property
115
  def has_multi_parts(self):
 
410
  texture_map = Image.fromarray(texture_map)
411
  assert type(texture_map) is Image.Image, "texture_map should be a PIL.Image"
412
  texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
413
+ material = trimesh.visual.material.PBRMaterial(
414
+ baseColorTexture=texture_map,
415
+ baseColorFactor=[255, 255, 255, 255], # set to white so the base color does not tint the texture
416
+ metallicFactor=0.0,
417
+ roughnessFactor=1.0
418
+ )
419
  else:
420
  default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
421
  material = trimesh.visual.texture.SimpleMaterial(image=default_texture)
 
456
  return save_path
457
 
458
  @classmethod
459
+ @spaces.GPU(duration=30)
460
  def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
461
  """
462
  Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
 
498
  position_map, normal_map = render_geo_map(mesh)
499
 
500
  progress(1, f"Mesh processing completed.")
501
+ position_map_path = save_tensor_to_file(position_map, prefix="position_map")
502
+ normal_map_path = save_tensor_to_file(normal_map, prefix="normal_map")
503
+ position_images_path = save_tensor_to_file(position_images, prefix="position_images")
504
+ normal_images_path = save_tensor_to_file(normal_images, prefix="normal_images")
505
+ mask_images_path = save_tensor_to_file(mask_images.squeeze(-1), prefix="mask_images")
506
+ w2c_path = save_tensor_to_file(w2c, prefix="w2c")
507
+ mvp_matrix_path = save_tensor_to_file(mvp_matrix, prefix="mvp_matrix")
508
+ # Move the mesh to CPU before returning so it can be held in gr.State
509
+ return position_map_path, normal_map_path, position_images_path, normal_images_path, mask_images_path, w2c_path, mesh.to("cpu"), mvp_matrix_path, "Mesh processing completed."
510
 
511
 
512
  if __name__ == '__main__':
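`Mesh.export` now builds a `PBRMaterial` (white base color, metallic 0, roughness 1) so viewers display the texture without tinting or strong specular response. A hedged, standalone sketch of that export path with placeholder geometry, UVs, and texture (the real code uses the SeqTex mesh, its xAtlas UVs, and the generated texture map flipped vertically):

```python
import numpy as np
import trimesh
from PIL import Image

# Placeholder texture; the real pipeline uses the generated texture map and
# flips it vertically (FLIP_TOP_BOTTOM) before building the material.
texture = Image.new("RGB", (1024, 1024), (180, 140, 90))

material = trimesh.visual.material.PBRMaterial(
    baseColorTexture=texture,
    baseColorFactor=[255, 255, 255, 255],  # white, so the texture is not tinted
    metallicFactor=0.0,
    roughnessFactor=1.0,
)

mesh = trimesh.creation.box()                # placeholder geometry
uv = np.random.rand(len(mesh.vertices), 2)   # placeholder UVs (SeqTex uses xAtlas UVs)
mesh.visual = trimesh.visual.TextureVisuals(uv=uv, material=material)
mesh.export("textured.glb")
```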
utils/rasterize.py CHANGED
@@ -1,17 +1,25 @@
1
  import nvdiffrast.torch as dr
2
  import torch
3
-
4
- from torch import Tensor
5
  from jaxtyping import Float, Integer
6
- from typing import Union, Tuple
 
7
 
8
  class NVDiffRasterizerContext:
9
- def __init__(self, context_type: str, device: torch.device) -> None:
10
  self.device = device
11
  self.ctx = self.initialize_context(context_type, device)
12
 
13
  def initialize_context(
14
- self, context_type: str, device: torch.device
15
  ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
16
  if context_type == "gl":
17
  return dr.RasterizeGLContext(device=device)
 
1
+ # This file uses nvdiffrast library, which is licensed under the NVIDIA Source Code License (1-Way Commercial).
2
+ # nvdiffrast is available for non-commercial use (research or evaluation purposes only).
3
+ # For commercial use, please contact NVIDIA for licensing: https://www.nvidia.com/en-us/research/inquiries/
4
+ #
5
+ # nvdiffrast copyright: Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
6
+ # Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt
7
+
8
+ from typing import Tuple, Union
9
+
10
  import nvdiffrast.torch as dr
11
  import torch
 
 
12
  from jaxtyping import Float, Integer
13
+ from torch import Tensor
14
+
15
 
16
  class NVDiffRasterizerContext:
17
+ def __init__(self, context_type: str, device) -> None:
18
  self.device = device
19
  self.ctx = self.initialize_context(context_type, device)
20
 
21
  def initialize_context(
22
+ self, context_type: str, device
23
  ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
24
  if context_type == "gl":
25
  return dr.RasterizeGLContext(device=device)
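For reference, a hypothetical standalone construction of the context class defined above (requires nvdiffrast and a CUDA device; the "cuda" context type avoids the OpenGL/EGL display that the "gl" type needs):

```python
from utils.rasterize import NVDiffRasterizerContext

ctx = NVDiffRasterizerContext("cuda", "cuda")
print(type(ctx.ctx).__name__)  # expected: RasterizeCudaContext
```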
utils/render_utils.py CHANGED
@@ -3,6 +3,7 @@ from functools import cache
3
  from typing import Dict, Union
4
 
5
  import numpy as np
 
6
  import torch
7
  import torch.nn.functional as F
8
  from einops import rearrange
@@ -15,8 +16,22 @@ from .rasterize import (NVDiffRasterizerContext,
15
  rasterize_position_and_normal_maps,
16
  render_geo_from_mesh,
17
  render_rgb_from_texture_mesh_with_mask)
 
18
 
19
- CTX = NVDiffRasterizerContext('cuda', 'cuda')
20
 
21
  def setup_lights():
22
  """
@@ -24,6 +39,7 @@ def setup_lights():
24
  """
25
  raise NotImplementedError("setup_lights function is not implemented yet.")
26
 
 
27
  def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
28
  """
29
  Render the RGB color images of the mesh. The background will be transparent.
@@ -34,11 +50,22 @@ def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) ->
34
  :param img_size: The size of the output image, a tuple (height, width).
35
  :return: A concatenated PIL Image.
36
  """
  if texture.shape[-1] != 3:
38
  texture = texture.permute(1, 2, 0)
39
  image_height, image_width = img_size
40
  rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
41
- CTX, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device='cuda'))
42
 
43
  if mvp_matrix.shape[0] == 0:
44
  return None
@@ -65,20 +92,24 @@ def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) ->
65
 
66
  return concatenated_image
67
 
 
68
  def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor]:
69
  """
70
  render the geometry information including position and normal from views that mvp matrix implies.
71
  """
 
72
  image_height, image_width = img_size
73
- position_images, normal_images, mask_images = render_geo_from_mesh(CTX, mesh, mvp_matrix, image_height, image_width)
74
  return position_images, normal_images, mask_images
75
 
 
76
  def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
77
  """
78
  Render the geometry information including position and normal from UV parameterization.
79
  """
 
80
  map_height, map_width = map_size
81
- position_images, normal_images, mask = rasterize_position_and_normal_maps(CTX, mesh, map_height, map_width)
82
  # out_imgs = []
83
  # if mask.ndim == 4:
84
  # mask = mask[0]
@@ -318,6 +349,7 @@ def _get_depth_noraml_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda
318
 
319
  return depth_map, normal_map, mask
320
 
 
321
  def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image]:
322
  """
323
  Get the silhouette image based on geometry image.
 
3
  from typing import Dict, Union
4
 
5
  import numpy as np
6
+ import spaces
7
  import torch
8
  import torch.nn.functional as F
9
  from einops import rearrange
 
16
  rasterize_position_and_normal_maps,
17
  render_geo_from_mesh,
18
  render_rgb_from_texture_mesh_with_mask)
19
+ from utils.file_utils import load_tensor_from_file
20
 
21
+ # Global variable to store the singleton context
22
+ _CTX_INSTANCE = None
23
+
24
+ @spaces.GPU
25
+ def get_rasterizer_context():
26
+ """
27
+ Get the NVDiffRasterizer context using singleton pattern.
28
+ This ensures only one context is created and reused across the application.
29
+ """
30
+ global _CTX_INSTANCE
31
+ if _CTX_INSTANCE is None:
32
+ # Use string 'cuda' instead of torch.device to avoid early CUDA initialization
33
+ _CTX_INSTANCE = NVDiffRasterizerContext('cuda', 'cuda')
34
+ return _CTX_INSTANCE
35
 
36
  def setup_lights():
37
  """
 
39
  """
40
  raise NotImplementedError("setup_lights function is not implemented yet.")
41
 
42
+ @spaces.GPU
43
  def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
44
  """
45
  Render the RGB color images of the mesh. The background will be transparent.
 
50
  :param img_size: The size of the output image, a tuple (height, width).
51
  :return: A concatenated PIL Image.
52
  """
53
+ # If texture or mvp_matrix is a file path, load the tensor from file
54
+ if isinstance(texture, str):
55
+ texture = load_tensor_from_file(texture, map_location="cuda")
56
+ if isinstance(mvp_matrix, str):
57
+ mvp_matrix = load_tensor_from_file(mvp_matrix, map_location="cuda")
58
+ mesh = mesh.to("cuda")
59
+ texture = texture.to("cuda")
60
+ mvp_matrix = mvp_matrix.to("cuda")
61
+
62
+ print("Trying to render views...")
63
+ ctx = get_rasterizer_context()
64
  if texture.shape[-1] != 3:
65
  texture = texture.permute(1, 2, 0)
66
  image_height, image_width = img_size
67
  rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
68
+ ctx, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device=texture.device))
69
 
70
  if mvp_matrix.shape[0] == 0:
71
  return None
 
92
 
93
  return concatenated_image
94
 
95
+ @spaces.GPU
96
  def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor]:
97
  """
98
  render the geometry information including position and normal from views that mvp matrix implies.
99
  """
100
+ ctx = get_rasterizer_context()
101
  image_height, image_width = img_size
102
+ position_images, normal_images, mask_images = render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width)
103
  return position_images, normal_images, mask_images
104
 
105
+ @spaces.GPU
106
  def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
107
  """
108
  Render the geometry information including position and normal from UV parameterization.
109
  """
110
+ ctx = get_rasterizer_context()
111
  map_height, map_width = map_size
112
+ position_images, normal_images, mask = rasterize_position_and_normal_maps(ctx, mesh, map_height, map_width)
113
  # out_imgs = []
114
  # if mask.ndim == 4:
115
  # mask = mask[0]
 
349
 
350
  return depth_map, normal_map, mask
351
 
352
+ @spaces.GPU
353
  def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image]:
354
  """
355
  Get the silhouette image based on geometry image.
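The singleton accessor introduced above defers all CUDA work until the first call inside a `@spaces.GPU` function, which matters on ZeroGPU Spaces where CUDA must not be initialized at import time. A minimal illustration of that deferral pattern; the dict stands in for `NVDiffRasterizerContext('cuda', 'cuda')` and the function names are illustrative:

```python
import spaces
import torch

_CTX = None  # created lazily so importing this module never touches CUDA


def get_ctx():
    """Stand-in for get_rasterizer_context(): build the context on first use."""
    global _CTX
    if _CTX is None:
        _CTX = {"device": "cuda"}  # placeholder for NVDiffRasterizerContext('cuda', 'cuda')
    return _CTX


@spaces.GPU
def render_once(n: int = 4) -> torch.Tensor:
    ctx = get_ctx()  # CUDA work happens only inside the GPU-allocated call
    return torch.zeros(n, 3, device=ctx["device"])
```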
utils/texture_generation.py CHANGED
@@ -11,13 +11,15 @@ from diffusers.models import AutoencoderKLWan
11
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
12
  from einops import rearrange
13
  from jaxtyping import Float
14
- from peft import LoraConfig
15
  from PIL import Image
16
  from torch import Tensor
17
 
18
  from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
19
  from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
20
 
 
 
 
21
  TEX_PIPE = None
22
  VAE = None
23
  LATENTS_MEAN, LATENTS_STD = None, None
@@ -26,14 +28,8 @@ TEX_PIPE_LOCK = threading.Lock()
26
  @dataclass
27
  class Config:
28
  video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
29
- seqtex_path: str = "https://huggingface.co/VAST-AI/SeqTex/resolve/main/.gitattributes/edm2_ema_12176_clean.pth"
30
- min_noise_level_index: int = 15 # which is same as paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)
31
-
32
- use_causal_mask: bool = False
33
- addtional_qk_geometry: bool = False
34
- use_normal: bool = True
35
- use_position: bool = True
36
- randomly_init: bool = True # we load the weights from a corresponding ckpt
37
 
38
  num_views: int = 4
39
  uv_num_views: int = 1
@@ -47,45 +43,12 @@ class Config:
47
  eval_num_inference_steps: int = 30
48
  eval_seed: int = 42
49
 
50
- lora_rank: int = 128
51
- lora_alpha: int = 64
52
-
53
  cfg = Config()
54
 
55
- def load_model_weights(model_path: str, map_location="cpu"):
56
- """
57
- Load model weights from either a URL or local file path.
58
-
59
- Args:
60
- model_path (str): Path to model weights, can be URL or local file path
61
- map_location (str): Device to map the model to
62
-
63
- Returns:
64
- Dict: Loaded state dictionary
65
- """
66
- # Check if the path is a URL
67
- parsed_url = urlparse(model_path)
68
- if parsed_url.scheme in ('http', 'https'):
69
- # Load from URL using torch.hub
70
- try:
71
- state_dict = torch.hub.load_state_dict_from_url(
72
- model_path,
73
- map_location=map_location,
74
- progress=True
75
- )
76
- return state_dict
77
- except Exception as e:
78
- gr.Warning(f"Failed to load from URL: {e}")
79
- raise e
80
- else:
81
- # Load from local file path
82
- if not os.path.exists(model_path):
83
- raise FileNotFoundError(f"Local model file not found: {model_path}")
84
- return torch.load(model_path, map_location=map_location)
85
-
86
- def lazy_get_seqtex_pipe():
87
  """
88
  Lazy load the SeqTex pipeline for texture generation.
 
89
  """
90
  global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
91
  if TEX_PIPE is not None:
@@ -95,42 +58,31 @@ def lazy_get_seqtex_pipe():
95
  if TEX_PIPE is not None:
96
  return TEX_PIPE
97
 
98
- # Pipeline
99
- TEX_PIPE = WanT2TexPipeline.from_pretrained(cfg.video_base_name)
100
-
101
- # Models
102
- transformer = WanT2TexTransformer3DModel(
103
- TEX_PIPE.transformer,
104
- use_causal_mask=cfg.use_causal_mask,
105
- addtional_qk_geo=cfg.addtional_qk_geometry,
106
- use_normal=cfg.use_normal,
107
- use_position=cfg.use_position,
108
- randomly_init=cfg.randomly_init,
109
  )
110
- transformer.add_adapter(
111
- LoraConfig(
112
- r=cfg.lora_rank,
113
- lora_alpha=cfg.lora_alpha,
114
- init_lora_weights=True,
115
- target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0", "attn1.to_out.2",
116
- "ffn.net.0.proj", "ffn.net.2"],
117
- )
118
  )
119
- # load transformer
120
- state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
121
- transformer.load_state_dict(state_dict, strict=True)
122
- TEX_PIPE.transformer = transformer
123
 
124
- VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32).to("cuda").requires_grad_(False)
125
  TEX_PIPE.vae = VAE
126
 
127
- # Some useful parameters
128
  LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
129
  1, VAE.config.z_dim, 1, 1, 1
130
- ).to("cuda", dtype=torch.float32)
131
  LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
132
  1, VAE.config.z_dim, 1, 1, 1
133
- ).to("cuda", dtype=torch.float32)
134
 
135
  scheduler: FlowMatchEulerDiscreteScheduler = (
136
  FlowMatchEulerDiscreteScheduler.from_config(
@@ -141,10 +93,8 @@ def lazy_get_seqtex_pipe():
141
  setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
142
  min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
143
  setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
144
- setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
145
-
146
- TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32) # use float32 for inference
147
- return TEX_PIPE
148
 
149
  @torch.amp.autocast('cuda', dtype=torch.float32)
150
  def encode_images(
@@ -157,6 +107,11 @@ def encode_images(
157
  :param encode_as_first: Whether to encode all frames as the first frame.
158
  :return: Encoded latents with shape [B, C', F, H/8, W/8].
159
  """
 
 
 
 
 
160
  if images.min() < - 0.1:
161
  # images are in [-1, 1] range
162
  images = (images + 1.0) / 2.0 # Normalize to [0, 1] range
@@ -171,19 +126,6 @@ def encode_images(
171
 
172
  return latents
173
 
174
- # @torch.no_grad()
175
- # @torch.amp.autocast('cuda', dtype=torch.float32)
176
- # def decode_images(self, latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
177
- # if decode_as_first:
178
- # F = latents.shape[2]
179
- # latents = latents.to(self.vae.dtype)
180
- # latents = latents / self.latents_std + self.latents_mean
181
- # latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
182
- # images = self.vae.decode(latents, return_dict=False)[0]
183
- # images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
184
- # else:
185
- # raise NotImplementedError("Currently only support decode as first frame.")
186
- # return images
187
  @torch.amp.autocast('cuda', dtype=torch.float32)
188
  def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
189
  """
@@ -192,6 +134,11 @@ def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = F
192
  :param decode_as_first: Whether to decode all frames as the first frame.
193
  :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
194
  """
 
 
 
 
 
195
  if decode_as_first:
196
  F = latents.shape[2]
197
  latents = latents.to(VAE.dtype)
@@ -207,6 +154,7 @@ def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H
207
  """
208
  Convert a PIL Image to a tensor. If the image is RGBA, composite it onto a black background using the alpha channel as a mask.
209
  :param image: PIL Image to convert. [0, 255]
 
210
  :return: Tensor representation of the image. [0.0, 1.0], still [H, W, C]
211
  """
212
  # Convert to RGBA to ensure alpha channel exists
@@ -217,25 +165,33 @@ def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H
217
  # Blend with black background using alpha mask
218
  rgb = rgb * alpha
219
  rgb = rgb.astype(np.float32) / 255.0 # Normalize to [0, 1]
220
- tensor = torch.from_numpy(rgb).to(device)
 
 
221
  return tensor
222
 
223
- @spaces.GPU(duration=120)
224
- @torch.cuda.amp.autocast(dtype=torch.float32)
225
- @torch.inference_mode
226
  @torch.no_grad
227
- def generate_texture(position_map, normal_map, position_images, normal_images, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
 
228
  """
229
  Use SeqTex to generate texture for the mesh based on the image condition.
230
- :param position_images: List of position images from different views.
231
- :param normal_images: List of normal images from different views.
232
  :param condition_image: Image condition generated from the selected view.
233
  :param text_prompt: Text prompt for texture generation.
234
  :param selected_view: The view selected for generating the image condition.
235
- :return: Generated texture map, and multi-view frames in tensor.
236
  """
 
 
 
 
 
237
  progress(0, desc="Loading SeqTex pipeline...")
238
- tex_pipe = lazy_get_seqtex_pipe()
 
 
239
  progress(0.2, desc="SeqTex pipeline loaded successfully.")
240
  view_id_map = {
241
  "First View": 0,
@@ -306,4 +262,5 @@ def generate_texture(position_map, normal_map, position_images, normal_images, c
306
  uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
307
 
308
  progress(1, desc="Texture generated successfully.")
309
- return uv_map_pred.float(), mv_out.float(), "Step 3: Texture generated successfully."
 
 
11
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
12
  from einops import rearrange
13
  from jaxtyping import Float
 
14
  from PIL import Image
15
  from torch import Tensor
16
 
17
  from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
18
  from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
19
 
20
+ from . import tensor_to_pil
21
+ from utils.file_utils import save_tensor_to_file, load_tensor_from_file
22
+
23
  TEX_PIPE = None
24
  VAE = None
25
  LATENTS_MEAN, LATENTS_STD = None, None
 
28
  @dataclass
29
  class Config:
30
  video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
31
+ seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
32
+ min_noise_level_index: int = 15 # refer to paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)
 
 
 
 
 
 
33
 
34
  num_views: int = 4
35
  uv_num_views: int = 1
 
43
  eval_num_inference_steps: int = 30
44
  eval_seed: int = 42
45
 
 
 
 
46
  cfg = Config()
47
 
48
+ def get_seqtex_pipe():
49
  """
50
  Lazy load the SeqTex pipeline for texture generation.
51
+ Must be called within a @spaces.GPU context.
52
  """
53
  global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
54
  if TEX_PIPE is not None:
 
58
  if TEX_PIPE is not None:
59
  return TEX_PIPE
60
 
61
+ assert os.environ.get("SEQTEX_SPACE_TOKEN"), "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to VAST-AI/SeqTex-Transformer."
62
+ # Load transformer with auto-configured LoRA adapter first
63
+ transformer = WanT2TexTransformer3DModel.from_pretrained(
64
+ cfg.seqtex_transformer_path,
65
+ token=os.environ["SEQTEX_SPACE_TOKEN"]
66
+ )
67
68
+ # Pipeline - pass the pre-loaded transformer to avoid re-loading
69
+ TEX_PIPE = WanT2TexPipeline.from_pretrained(
70
+ cfg.video_base_name,
71
+ transformer=transformer,
72
+ torch_dtype=torch.bfloat16
 
73
  )
74
+ del transformer
 
 
 
75
 
76
+ VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
77
  TEX_PIPE.vae = VAE
78
 
79
+ # Some useful parameters - delay CUDA initialization until GPU context
80
  LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
81
  1, VAE.config.z_dim, 1, 1, 1
82
+ ).to(torch.float32)
83
  LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
84
  1, VAE.config.z_dim, 1, 1, 1
85
+ ).to(torch.float32)
86
 
87
  scheduler: FlowMatchEulerDiscreteScheduler = (
88
  FlowMatchEulerDiscreteScheduler.from_config(
 
93
  setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
94
  min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
95
  setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
96
+ setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
97
+ return TEX_PIPE.to("cuda")
 
 
98
 
99
  @torch.amp.autocast('cuda', dtype=torch.float32)
100
  def encode_images(
 
107
  :param encode_as_first: Whether to encode all frames as the first frame.
108
  :return: Encoded latents with shape [B, C', F, H/8, W/8].
109
  """
110
+ global VAE, LATENTS_MEAN, LATENTS_STD
111
+ VAE = VAE.to("cuda").requires_grad_(False)
112
+ LATENTS_MEAN = LATENTS_MEAN.to("cuda")
113
+ LATENTS_STD = LATENTS_STD.to("cuda")
114
+
115
  if images.min() < - 0.1:
116
  # images are in [-1, 1] range
117
  images = (images + 1.0) / 2.0 # Normalize to [0, 1] range
 
126
 
127
  return latents
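The Wan VAE's per-channel latent statistics are stored as two broadcastable tensors, with LATENTS_STD kept as the reciprocal of the config's std. A minimal sketch of the resulting normalize/denormalize round trip is below, assuming the usual (z - mean) / std convention that matches the denormalization in the removed decode_images comment above; the toy statistics are random placeholders.

```python
import torch

# Toy per-channel statistics shaped for broadcasting over [B, C, F, H, W]
z_dim = 4
latents_mean = torch.randn(1, z_dim, 1, 1, 1)
latents_std_inv = 1.0 / torch.rand(1, z_dim, 1, 1, 1).clamp(min=0.1)  # mirrors LATENTS_STD = 1 / std

z = torch.randn(2, z_dim, 3, 8, 8)                # raw VAE latents
z_norm = (z - latents_mean) * latents_std_inv     # what the transformer sees
z_back = z_norm / latents_std_inv + latents_mean  # inverse applied before VAE.decode
assert torch.allclose(z, z_back, atol=1e-5)
```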
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @torch.amp.autocast('cuda', dtype=torch.float32)
130
  def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
131
  """
 
134
  :param decode_as_first: Whether to decode all frames as the first frame.
135
  :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
136
  """
137
+ global VAE, LATENTS_MEAN, LATENTS_STD
138
+ VAE = VAE.to("cuda").requires_grad_(False)
139
+ LATENTS_MEAN = LATENTS_MEAN.to("cuda")
140
+ LATENTS_STD = LATENTS_STD.to("cuda")
141
+
142
  if decode_as_first:
143
  F = latents.shape[2]
144
  latents = latents.to(VAE.dtype)
 
154
  """
155
  Convert a PIL Image to a tensor. If the image is RGBA, composite it onto a black background using the alpha channel as a mask.
156
  :param image: PIL Image to convert. [0, 255]
157
+ :param device: Target device for the tensor.
158
  :return: Tensor representation of the image. [0.0, 1.0], still [H, W, C]
159
  """
160
  # Convert to RGBA to ensure alpha channel exists
 
165
  # Blend with black background using alpha mask
166
  rgb = rgb * alpha
167
  rgb = rgb.astype(np.float32) / 255.0 # Normalize to [0, 1]
168
+ tensor = torch.from_numpy(rgb)
169
+ if device != "cpu":
170
+ tensor = tensor.to(device)
171
  return tensor
172
 
173
+ @spaces.GPU(duration=90)
 
 
174
  @torch.no_grad
175
+ @torch.inference_mode
176
+ def generate_texture(position_map_path, normal_map_path, position_images_path, normal_images_path, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
177
  """
178
  Use SeqTex to generate texture for the mesh based on the image condition.
179
+ :param position_map_path: File path to the saved UV-space position map tensor.
+ :param normal_map_path: File path to the saved UV-space normal map tensor.
+ :param position_images_path: File path to the saved multi-view position images tensor.
180
+ :param normal_images_path: File path to the saved multi-view normal images tensor.
181
  :param condition_image: Image condition generated from the selected view.
182
  :param text_prompt: Text prompt for texture generation.
183
  :param selected_view: The view selected for generating the image condition.
184
+ :return: File path of the generated texture map, PIL images of the texture map and the multi-view frames, and a status message.
185
  """
186
+ position_map = load_tensor_from_file(position_map_path, map_location=device)
187
+ normal_map = load_tensor_from_file(normal_map_path, map_location=device)
188
+ position_images = load_tensor_from_file(position_images_path, map_location=device)
189
+ normal_images = load_tensor_from_file(normal_images_path, map_location=device)
190
+
191
  progress(0, desc="Loading SeqTex pipeline...")
192
+ tex_pipe = get_seqtex_pipe()
193
+ # assert tex_pipe is in gpu
194
+ assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
195
  progress(0.2, desc="SeqTex pipeline loaded successfully.")
196
  view_id_map = {
197
  "First View": 0,
 
262
  uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
263
 
264
  progress(1, desc="Texture generated successfully.")
265
+ uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
266
+ return uv_map_pred_path, tensor_to_pil(uv_map_pred, normalize=False), tensor_to_pil(mv_out, normalize=False), "Step 3: Texture generated successfully."
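generate_texture now exchanges large tensors through file paths instead of passing CUDA tensors between @spaces.GPU calls. The real save_tensor_to_file / load_tensor_from_file live in utils/file_utils.py, which is not part of this diff, so the sketch below is only one plausible shape for them (temporary .pt files written from CPU), not the Space's actual implementation.

```python
import tempfile
import torch

def save_tensor_to_file(tensor: torch.Tensor, prefix: str = "tensor") -> str:
    """Persist a tensor to a temporary file and return its path, so ZeroGPU
    calls can exchange data without keeping CUDA tensors alive between them."""
    with tempfile.NamedTemporaryFile(prefix=f"{prefix}_", suffix=".pt", delete=False) as f:
        path = f.name
    torch.save(tensor.detach().cpu(), path)
    return path

def load_tensor_from_file(path: str, map_location: str = "cpu") -> torch.Tensor:
    """Reload a tensor saved by save_tensor_to_file onto the requested device."""
    return torch.load(path, map_location=map_location)
```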
wan/pipeline_wan_t2tex_extra.py CHANGED
@@ -1,19 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import copy
2
- from typing import Any, Callable, Dict, List, Optional, Union, Tuple
3
 
4
- from einops import rearrange
5
  import regex as re
6
  import torch
7
- from diffusers.pipelines.wan.pipeline_wan import WanPipeline
8
  from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
9
- from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
10
  from diffusers.utils.torch_utils import randn_tensor
 
 
11
  from torch import Tensor
12
  from transformers import AutoTokenizer, UMT5EncoderModel
13
- from jaxtyping import Float
14
- import gradio as gr
15
 
16
  def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
 
 
 
 
 
17
  sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
18
  schedule_timesteps = scheduler.timesteps.to(device)
19
  timesteps = timesteps.to(device)
@@ -83,6 +103,8 @@ class WanT2TexPipeline(WanPipeline):
83
  negative_prompt_embeds: Optional[torch.Tensor] = None,
84
  output_type: Optional[str] = "np",
85
  return_dict: bool = True,
 
 
86
  attention_kwargs: Optional[Dict[str, Any]] = None,
87
  callback_on_step_end: Optional[
88
  Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -191,7 +213,8 @@ class WanT2TexPipeline(WanPipeline):
191
  self._current_timestep = None
192
  self._interrupt = False
193
 
194
- device = self._execution_device
 
195
 
196
  # 2. Define call parameters
197
  if prompt is not None and isinstance(prompt, str):
 
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
  import copy
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
 
18
+ import gradio as gr
19
  import regex as re
20
  import torch
21
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
  from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
23
+ from diffusers.pipelines.wan.pipeline_wan import WanPipeline
24
  from diffusers.utils.torch_utils import randn_tensor
25
+ from einops import rearrange
26
+ from jaxtyping import Float
27
  from torch import Tensor
28
  from transformers import AutoTokenizer, UMT5EncoderModel
29
+
 
30
 
31
  def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
32
+ # Ensure device is available before using it
33
+ if isinstance(device, str) and device.startswith("cuda"):
34
+ if not torch.cuda.is_available():
35
+ device = "cpu"
36
+
37
  sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
38
  schedule_timesteps = scheduler.timesteps.to(device)
39
  timesteps = timesteps.to(device)
 
103
  negative_prompt_embeds: Optional[torch.Tensor] = None,
104
  output_type: Optional[str] = "np",
105
  return_dict: bool = True,
106
+ device: Optional[str] = "cuda",
107
+
108
  attention_kwargs: Optional[Dict[str, Any]] = None,
109
  callback_on_step_end: Optional[
110
  Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
 
213
  self._current_timestep = None
214
  self._interrupt = False
215
 
216
+ device = torch.device(device) if isinstance(device, str) else device
217
+ self.to(device)
218
 
219
  # 2. Define call parameters
220
  if prompt is not None and isinstance(prompt, str):
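The added guard in get_sigmas only downgrades the device; the rest of the function is not shown in this hunk and is assumed to follow the usual diffusers pattern of looking up each timestep's index in the schedule and taking the corresponding sigma. A self-contained sketch of that lookup with the same CPU fallback, using a stub scheduler rather than a diffusers object:

```python
from types import SimpleNamespace
import torch

def get_sigmas(scheduler, timesteps, dtype=torch.float32, device="cuda"):
    # Fall back to CPU when CUDA is unavailable, as in the guard added above
    if isinstance(device, str) and device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    # For each requested timestep, find its index in the schedule and take that sigma
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    return sigmas[step_indices]

scheduler = SimpleNamespace(
    sigmas=torch.linspace(1.0, 0.0, 5),
    timesteps=torch.tensor([1000, 750, 500, 250, 0]),
)
print(get_sigmas(scheduler, torch.tensor([750, 0])))  # sigmas 0.75 and 0.0
```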
wan/wan_t2tex_transformer_3d_extra.py CHANGED
@@ -13,30 +13,23 @@
13
  # limitations under the License.
14
 
15
  import copy
16
- import math
17
- from typing import Any, Dict, Optional, Tuple, Union
18
  from functools import cache
 
19
 
20
- from einops import rearrange, repeat
21
  import torch
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
- from diffusers.configuration_utils import ConfigMixin, register_to_config
25
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
  from diffusers.models import WanTransformer3DModel
27
  from diffusers.models.attention import FeedForward
28
  from diffusers.models.attention_processor import Attention
29
- from diffusers.models.cache_utils import CacheMixin
30
- from diffusers.models.embeddings import (PixArtAlphaTextProjection,
31
- TimestepEmbedding, Timesteps,
32
- get_1d_rotary_pos_embed)
33
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
34
- from diffusers.models.modeling_utils import ModelMixin
35
  from diffusers.models.normalization import FP32LayerNorm
36
  from diffusers.models.transformers.transformer_wan import \
37
  WanTimeTextImageEmbedding
38
  from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
39
  unscale_lora_layers)
 
 
40
 
41
 
42
  class WanT2TexAttnProcessor2_0:
@@ -228,43 +221,6 @@ class WanRotaryPosEmbed(nn.Module):
228
  uv_freqs = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(1, 1, uppf * upph * uppw, -1)
229
  return torch.cat([freqs, uv_freqs], dim=-2)
230
 
231
- # def pseudo_code(freqs, mv_tokens_shape, uv_tokens_shape, dimmension):
232
- # """
233
- # Input:
234
- # freqs: [S, D/2], S is the number of tokens, D is the dimension of tokens, 2 indicates Cos and Sin in original RoPE.
235
- # mv_tokens_shape: (mv_num_frames, mv_height, mv_width)
236
- # uv_tokens_shape: (uv_num_frames, uv_height, uv_width)
237
- # dimension: the dimension of tokens
238
- # Output:
239
- # """
240
- # mpf, mph, mpw = mv_tokens_shape # mv_num_frames, mv_height, mv_width
241
- # upf, uph, upw = uv_tokens_shape # uv_num_frames, uv_height, uv_width
242
-
243
- # # 1. To evenly split the freqs into 3 parts
244
- # freqs = freqs.split_with_sizes(
245
- # [
246
- # dimmension // 2 - 2 * (dimmension // 6),
247
- # dimmension // 6,
248
- # dimmension // 6,
249
- # ],
250
- # dim=1,
251
- # )
252
-
253
- # # 2. In time dimension, the freqs for UV are subsequent to the freqs for MV
254
- # freqs_f = freqs[0][:mpf].view(mpf, 1, 1, -1).expand(mpf, mph, mpw, -1)
255
- # uv_freqs_f = freqs[0][mpf:mpf+upf].view(upf, 1, 1, -1).expand(upf, uph, upw, -1)
256
-
257
- # # 3. The freqs in height and width dimension are the same for mv and uv
258
- # freqs_h = freqs[1][:mph].view(1, mph, 1, -1).expand(mpf, mph, mpw, -1)
259
- # uv_freqs_h = freqs[1][:uph].view(1, uph, 1, -1).expand(upf, uph, upw, -1)
260
- # freqs_w = freqs[2][:mpw].view(1, 1, mpw, -1).expand(mpf, mph, mpw, -1)
261
- # uv_freqs_w = freqs[2][:upw].view(1, 1, upw, -1).expand(upf, uph, upw, -1)
262
-
263
- # # 4. rearrange three 1D RoPEs into 3D RoPE in channel dimension
264
- # mv_rope = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(mpf * mph * mpw, -1)
265
- # uv_rope = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(upf * uph * upw, -1)
266
- # return torch.cat([mv_rope, uv_rope], dim=-2)
267
-
268
  class WanT2TexTransformerBlock(nn.Module):
269
  def __init__(
270
  self,
@@ -400,71 +356,104 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
400
  """
401
  3D Transformer model for T2Tex.
402
  """
403
- def __init__(self, original_model, use_causal_mask=False, addtional_qk_geo=False, randomly_init=False, **kwargs):
404
- super(WanT2TexTransformer3DModel, self).__init__(**original_model.config)
405
- if not randomly_init:
406
- self.load_state_dict(original_model.state_dict(), strict=True)
407
- self.addtional_qk_geo = addtional_qk_geo
408
- if addtional_qk_geo:
409
- raise ValueError("addtional_qk_geo did not work")
410
- warn("addtional_qk_geo is set to True, this will drastically increase the memory usage and slow down the training, without significant performance gain.")
411
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  # 1. Patch & position embedding
413
- self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len, addtional_qk_geo=addtional_qk_geo)
414
- self.use_normal, self.use_position = kwargs.get("use_normal", True), kwargs.get("use_position", True)
415
- if self.use_normal:
416
- self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
417
- # torch.nn.init.zeros_(self.norm_patch_embedding.weight.data)
418
- # torch.nn.init.zeros_(self.norm_patch_embedding.bias.data)
419
- if self.use_position:
420
- self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
421
- # torch.nn.init.zeros_(self.pos_patch_embedding.weight.data)
422
- # torch.nn.init.zeros_(self.pos_patch_embedding.bias.data)
423
 
424
  # 2. Condition embeddings
425
- inner_dim = original_model.config.num_attention_heads * original_model.config.attention_head_dim
426
  self.condition_embedder = WanTimeTaskTextImageEmbedding(
427
  original_model=self.condition_embedder,
428
  dim=inner_dim,
429
- time_freq_dim=original_model.config.freq_dim,
430
  time_proj_dim=inner_dim * 6,
431
- text_embed_dim=original_model.config.text_dim,
432
- image_embed_dim=original_model.config.image_dim,
433
- randomly_init=randomly_init,
434
  )
435
 
436
  # 3. Transformer blocks
437
- self.use_causal_mask = use_causal_mask
438
- self.num_attention_heads = original_model.config.num_attention_heads
439
 
440
  block = WanT2TexTransformerBlock(
441
  inner_dim,
442
- original_model.config.ffn_dim,
443
- original_model.config.num_attention_heads,
444
- original_model.config.qk_norm,
445
- original_model.config.cross_attn_norm,
446
- original_model.config.eps,
447
- original_model.config.added_kv_proj_dim,
448
  )
449
  self.blocks = None
450
  self.blocks = nn.ModuleList(
451
  [
452
  copy.deepcopy(block)
453
- for _ in range(original_model.config.num_layers)
454
  ]
455
  )
456
  self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
457
- if not randomly_init:
458
- self.scale_shift_table_uv.data.copy_(self.scale_shift_table.data)
459
- self.blocks.load_state_dict(original_model.blocks.state_dict(), strict=False)
460
- for block in self.blocks:
461
- block.attnuv.load_state_dict(block.attn1.state_dict())
462
- block.scale_shift_table_uv.data.copy_(block.scale_shift_table.data)
463
- block.normuv2.load_state_dict(block.norm2.state_dict())
464
- block.ffnuv.load_state_dict(block.ffn.state_dict())
465
 
466
- # 4. Output norm & projection
467
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
  @cache
470
  def get_attention_bias(self, mv_length, uv_length):
@@ -521,23 +510,10 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
521
  rotary_emb = self.rope(mv_hidden_states, uv_hidden_states)
522
 
523
  # Patchify
524
- if self.use_normal and self.use_position:
525
- mv_rgb_hidden_states, mv_pos_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 3, dim=1)
526
- uv_rgb_hidden_states, uv_pos_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 3, dim=1)
527
- mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states) + self.norm_patch_embedding(mv_norm_hidden_states)
528
- uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
529
- elif self.use_normal:
530
- mv_rgb_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
531
- uv_rgb_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
532
- mv_geometry_embedding = self.norm_patch_embedding(mv_norm_hidden_states)
533
- uv_geometry_embedding = self.norm_patch_embedding(uv_norm_hidden_states)
534
- elif self.use_position:
535
- mv_rgb_hidden_states, mv_pos_hidden_states = torch.chunk(mv_hidden_states, 2, dim=1)
536
- uv_rgb_hidden_states, uv_pos_hidden_states = torch.chunk(uv_hidden_states, 2, dim=1)
537
- mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states)
538
- uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states)
539
- else:
540
- raise ValueError("use_normal and use_position are both False, please set at least one of them to True.")
541
 
542
  mv_hidden_states = self.patch_embedding(mv_rgb_hidden_states)
543
  uv_hidden_states = self.patch_embedding(uv_rgb_hidden_states)
@@ -564,12 +540,6 @@ class WanT2TexTransformer3DModel(WanTransformer3DModel):
564
  if encoder_hidden_states_image is not None:
565
  encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
566
 
567
- # # Get attention bias
568
- # if self.use_causal_mask:
569
- # # This may be gainless, because the patch embedding is not causal, which will leak information to MV
570
- # attn_bias = self.get_attention_bias(post_patch_num_frames * post_patch_height * post_patch_width,
571
- # post_uv_num_frames * post_uv_height * post_uv_width)
572
- # else:
573
  attn_bias = None
574
 
575
  # 4. Transformer blocks
 
13
  # limitations under the License.
14
 
15
  import copy
 
 
16
  from functools import cache
17
+ from typing import Any, Dict, Optional, Tuple, Union
18
 
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
 
 
22
  from diffusers.models import WanTransformer3DModel
23
  from diffusers.models.attention import FeedForward
24
  from diffusers.models.attention_processor import Attention
25
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
 
 
 
 
 
26
  from diffusers.models.normalization import FP32LayerNorm
27
  from diffusers.models.transformers.transformer_wan import \
28
  WanTimeTextImageEmbedding
29
  from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
30
  unscale_lora_layers)
31
+ from einops import rearrange, repeat
32
+ from peft import LoraConfig
33
 
34
 
35
  class WanT2TexAttnProcessor2_0:
 
221
  uv_freqs = torch.cat([uv_freqs_f, uv_freqs_h, uv_freqs_w], dim=-1).reshape(1, 1, uppf * upph * uppw, -1)
222
  return torch.cat([freqs, uv_freqs], dim=-2)
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  class WanT2TexTransformerBlock(nn.Module):
225
  def __init__(
226
  self,
 
356
  """
357
  3D Transformer model for T2Tex.
358
  """
359
+ def __init__(self,
360
+ patch_size: Tuple[int] = (1, 2, 2),
361
+ num_attention_heads: int = 40,
362
+ attention_head_dim: int = 128,
363
+ in_channels: int = 16,
364
+ out_channels: int = 16,
365
+ text_dim: int = 4096,
366
+ freq_dim: int = 256,
367
+ ffn_dim: int = 13824,
368
+ num_layers: int = 40,
369
+ cross_attn_norm: bool = True,
370
+ qk_norm: Optional[str] = "rms_norm_across_heads",
371
+ eps: float = 1e-6,
372
+ image_dim: Optional[int] = None,
373
+ added_kv_proj_dim: Optional[int] = None,
374
+ rope_max_seq_len: int = 1024,
375
+ **kwargs
376
+ ):
377
+ super(WanT2TexTransformer3DModel, self).__init__(
378
+ patch_size=patch_size,
379
+ num_attention_heads=num_attention_heads,
380
+ attention_head_dim=attention_head_dim,
381
+ in_channels=in_channels,
382
+ out_channels=out_channels,
383
+ text_dim=text_dim,
384
+ freq_dim=freq_dim,
385
+ ffn_dim=ffn_dim,
386
+ num_layers=num_layers,
387
+ cross_attn_norm=cross_attn_norm,
388
+ qk_norm=qk_norm,
389
+ eps=eps,
390
+ image_dim=image_dim,
391
+ added_kv_proj_dim=added_kv_proj_dim,
392
+ rope_max_seq_len=rope_max_seq_len
393
+ )
394
  # 1. Patch & position embedding
395
+ self.rope = WanRotaryPosEmbed(self.rope.attention_head_dim, self.rope.patch_size, self.rope.max_seq_len)
396
+ self.norm_patch_embedding = copy.deepcopy(self.patch_embedding)
397
+ self.pos_patch_embedding = copy.deepcopy(self.patch_embedding)
 
 
 
 
 
 
 
398
 
399
  # 2. Condition embeddings
400
+ inner_dim = num_attention_heads * attention_head_dim
401
  self.condition_embedder = WanTimeTaskTextImageEmbedding(
402
  original_model=self.condition_embedder,
403
  dim=inner_dim,
404
+ time_freq_dim=freq_dim,
405
  time_proj_dim=inner_dim * 6,
406
+ text_embed_dim=text_dim,
407
+ image_embed_dim=image_dim,
 
408
  )
409
 
410
  # 3. Transformer blocks
411
+ self.num_attention_heads = num_attention_heads
 
412
 
413
  block = WanT2TexTransformerBlock(
414
  inner_dim,
415
+ ffn_dim,
416
+ num_attention_heads,
417
+ qk_norm,
418
+ cross_attn_norm,
419
+ eps,
420
+ added_kv_proj_dim,
421
  )
422
  self.blocks = None
423
  self.blocks = nn.ModuleList(
424
  [
425
  copy.deepcopy(block)
426
+ for _ in range(num_layers)
427
  ]
428
  )
429
  self.scale_shift_table_uv = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
430
+
431
+ # 4. Auto-configure LoRA adapter for SeqTex
432
+ self.configure_lora_adapter()
 
 
 
 
 
433
 
434
+ def configure_lora_adapter(self, lora_rank: int = 128, lora_alpha: int = 64):
435
+ """
436
+ Configure the LoRA adapter for the SeqTex transformer, using the defaults or custom settings.
437
+
438
+ Args:
439
+ lora_rank (int, optional): LoRA rank. Defaults to 128.
440
+ lora_alpha (int, optional): LoRA alpha (scaling) parameter. Defaults to 64.
441
+ """
442
+ # Apply LoRA to the self-attention and FFN projections
443
+ target_modules = [
444
+ "attn1.to_q", "attn1.to_k", "attn1.to_v",
445
+ "attn1.to_out.0", "attn1.to_out.2",
446
+ "ffn.net.0.proj", "ffn.net.2"
447
+ ]
448
+
449
+ lora_config = LoraConfig(
450
+ r=lora_rank,
451
+ lora_alpha=lora_alpha,
452
+ init_lora_weights=True,
453
+ target_modules=target_modules,
454
+ )
455
+
456
+ self.add_adapter(lora_config)
457
 
458
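configure_lora_adapter relies on PEFT's suffix matching of target_modules to pick out the self-attention and FFN projections. The toy example below illustrates that matching with PEFT's low-level inject_adapter_in_model helper on a made-up module; the module names are chosen only to mirror the attn1.to_q/to_k/to_v suffixes, while the real model also targets the FFN projections and attaches the adapter via add_adapter from diffusers' PeftAdapterMixin.

```python
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

class ToyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.attn1 = ToyAttention(dim)   # names chosen so the suffixes below match
        self.proj = nn.Linear(dim, dim)  # not targeted, stays a plain nn.Linear

config = LoraConfig(
    r=8,
    lora_alpha=4,
    init_lora_weights=True,
    target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v"],  # suffix-matched, as in the SeqTex config
)
model = inject_adapter_in_model(config, ToyBlock())
lora_params = [name for name, _ in model.named_parameters() if "lora_" in name]
print(len(lora_params))  # the injected LoRA A/B matrices show up as new parameters
```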
  @cache
459
  def get_attention_bias(self, mv_length, uv_length):
 
510
  rotary_emb = self.rope(mv_hidden_states, uv_hidden_states)
511
 
512
  # Patchify
513
+ mv_rgb_hidden_states, mv_pos_hidden_states, mv_norm_hidden_states = torch.chunk(mv_hidden_states, 3, dim=1)
514
+ uv_rgb_hidden_states, uv_pos_hidden_states, uv_norm_hidden_states = torch.chunk(uv_hidden_states, 3, dim=1)
515
+ mv_geometry_embedding = self.pos_patch_embedding(mv_pos_hidden_states) + self.norm_patch_embedding(mv_norm_hidden_states)
516
+ uv_geometry_embedding = self.pos_patch_embedding(uv_pos_hidden_states) + self.norm_patch_embedding(uv_norm_hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
517
 
518
  mv_hidden_states = self.patch_embedding(mv_rgb_hidden_states)
519
  uv_hidden_states = self.patch_embedding(uv_rgb_hidden_states)
 
540
  if encoder_hidden_states_image is not None:
541
  encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
542
 
 
 
 
 
 
 
543
  attn_bias = None
544
 
545
  # 4. Transformer blocks