File size: 7,510 Bytes
f876753
9c76928
 
e0305d4
 
 
 
9b177de
9c76928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f876753
e9a4e66
9b177de
f876753
 
e9a4e66
f876753
 
e9a4e66
f876753
 
9c76928
f876753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bab9f24
9b177de
f876753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a69847
f876753
 
 
 
 
 
4a69847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f876753
 
4a69847
 
f876753
4a69847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f876753
 
 
4a69847
 
 
 
 
 
 
f876753
 
4a69847
 
 
 
 
 
 
 
 
 
 
 
f876753
 
 
4a69847
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import subprocess
import sys
try:
    import spaces
except ImportError:
    # `spaces` is only installed on Hugging Face Spaces. Provide a no-op stand-in
    # so the `@spaces.GPU` decorator applied further down still resolves when the
    # app runs anywhere else (previously this raised NameError at decoration time).
    import types

    def _gpu_passthrough(func=None, **_kwargs):
        """No-op replacement for ``spaces.GPU``, usable bare or with options."""
        return func if func is not None else (lambda f: f)

    spaces = types.SimpleNamespace(GPU=_gpu_passthrough)

os.environ["PYDANTIC_STRICT_TYPE_CHECKING"] = "0"

# First-run bootstrap: run setup.sh once, then drop a marker file in the working
# directory so subsequent starts skip it.
setup_marker = ".setup_complete"
if not os.path.exists(setup_marker):
    print("First run detected, installing dependencies...")
    try:
        subprocess.check_call(["bash", "setup.sh"])
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `bash` executable, which check_call raises
        # as FileNotFoundError rather than CalledProcessError.
        print(f"Setup failed with error: {e}", file=sys.stderr)
        sys.exit(1)
    # Written only after a successful run, so a failed setup is retried next start.
    with open(setup_marker, "w", encoding="utf-8") as f:
        f.write("Setup completed")
    print("Setup completed successfully!")

import torch
import gradio as gr
from typing import Tuple, List, Dict, Any, Optional
from collections import deque
from diffusers import StableDiffusionPipeline

from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline
from triplaneturbo_executable.utils.mesh_exporter import export_obj

# Module-level configuration and shared mutable state.
if torch.cuda.is_available():
    DEVICE: str = "cuda"
else:
    DEVICE = "cpu"

# Adapter checkpoint location (relative to the working directory).
ADAPTER_PATH: str = "pretrained/triplane_turbo_sd_v1.pth"

# Cached pipeline instance; populated lazily by initialize_pipeline().
PIPELINE = None

# Bounded FIFO (at most 100 entries) of exported OBJ paths, used to prune old outputs.
OBJ_FILE_QUEUE: deque = deque(maxlen=100)

def download_model():
    """Download the TriplaneTurbo adapter checkpoint to ``ADAPTER_PATH`` if missing.

    Invokes the ``huggingface-cli`` executable directly (no shell), placing the
    file named ``basename(ADAPTER_PATH)`` from the ``ZhiyuanthePony/TriplaneTurbo``
    repo into ``dirname(ADAPTER_PATH)``.

    Raises:
        FileNotFoundError: if ``huggingface-cli`` is not installed, or if the
            download finished but ``ADAPTER_PATH`` still does not exist.
        subprocess.CalledProcessError: if the download command exits non-zero.
    """
    if os.path.exists(ADAPTER_PATH):
        return

    print("Downloading pretrained models from huggingface")
    # Derive both pieces from ADAPTER_PATH so the downloaded file lands exactly
    # where initialize_pipeline() will look for it, regardless of the cwd.
    local_dir = os.path.dirname(ADAPTER_PATH) or "."
    filename = os.path.basename(ADAPTER_PATH)
    subprocess.run(
        [
            "huggingface-cli", "download", "--resume-download",
            "ZhiyuanthePony/TriplaneTurbo",
            "--include", filename,
            "--local-dir", local_dir,
            "--local-dir-use-symlinks", "False",
        ],
        check=True,  # previously os.system's exit status was silently ignored
    )
    if not os.path.exists(ADAPTER_PATH):
        # `--include` matching nothing is not an error for the CLI; fail loudly
        # here instead of later inside the pipeline loader.
        raise FileNotFoundError(
            f"Download completed but {ADAPTER_PATH!r} was not created"
        )

def initialize_pipeline():
    """Return the shared text-to-3D pipeline, building it on first use.

    The instance is cached in the module-level ``PIPELINE`` global, so the
    adapter weights at ``ADAPTER_PATH`` are loaded at most once per process and
    later calls simply hand back the cached object.
    """
    global PIPELINE
    if PIPELINE is not None:
        return PIPELINE

    print("Initializing pipeline...")
    PIPELINE = TriplaneTurboTextTo3DPipeline.from_pretrained(ADAPTER_PATH)
    # Return value of .to() is not used — presumably it moves in place; confirm.
    PIPELINE.to(DEVICE)
    print("Pipeline initialized!")
    return PIPELINE

def _safe_mesh_stem(prompt: str, max_len: int = 64) -> str:
    """Return a filesystem-safe, collision-free file stem derived from *prompt*.

    The prompt comes straight from the web UI, so anything other than
    alphanumerics, ``-`` and ``_`` (including path separators and ``.``) is
    replaced with ``_``, the result is truncated to *max_len* characters, and a
    random hex suffix is appended so two requests with the same prompt never
    write to (and later delete) the same file.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in prompt.strip())
    cleaned = cleaned[:max_len] or "mesh"
    return f"{cleaned}_{os.urandom(4).hex()}"


def _reorient(t: torch.Tensor) -> torch.Tensor:
    """Map each row (x, y, z) of *t* to (y, z, x).

    This is the net effect of the two original axis swaps
    (x, y, z) -> (x, z, -y) followed by (x, y, z) -> (-z, y, x). The combined
    map is a proper rotation (a cyclic axis permutation with determinant +1), so
    the same function is valid for both positions and normals.
    """
    return t[:, [1, 2, 0]]


def _track_output_file(path: str) -> None:
    """Append *path* to OBJ_FILE_QUEUE, deleting the oldest tracked file when full.

    Eviction happens before the append, so exactly ``maxlen`` most recent
    outputs are kept on disk. A file is never deleted while another queue entry
    (or the file just written) still refers to it.
    """
    limit = OBJ_FILE_QUEUE.maxlen
    if limit and len(OBJ_FILE_QUEUE) >= limit:
        evicted = OBJ_FILE_QUEUE.popleft()
        if evicted != path and evicted not in OBJ_FILE_QUEUE:
            try:
                os.remove(evicted)
            except FileNotFoundError:
                pass  # already gone; nothing to clean up
            except OSError as e:
                print(f"Error deleting file {evicted}: {e}")
    OBJ_FILE_QUEUE.append(path)


@spaces.GPU
def generate_3d_mesh(prompt: str) -> Tuple[Optional[str], Optional[str]]:
    """Generate a 3D mesh from a text prompt and export it as an OBJ file.

    Args:
        prompt: Free-form text description entered in the UI (untrusted; it is
            sanitized before being used in a file name).

    Returns:
        ``(path, path)`` to the exported OBJ — once for the 3D preview and once
        for the download widget — or ``(None, None)`` if the pipeline produced
        no meshes.
    """
    pipeline = initialize_pipeline()

    # Fixed seed: the same prompt always yields the same mesh.
    seed = 42
    output = pipeline(
        prompt=prompt,
        num_results_per_prompt=1,
        generator=torch.Generator(device=DEVICE).manual_seed(seed),
    )

    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)

    mesh_path = None
    for mesh in output["mesh"]:
        # Read normals BEFORE replacing v_pos. If v_nrm is derived lazily from
        # v_pos (NOTE(review): confirm against the Mesh class), reading it after
        # the position update would rotate already-rotated normals a second time.
        normals = mesh.v_nrm
        mesh.v_pos = _reorient(mesh.v_pos)
        if normals is not None:
            mesh._v_nrm = _reorient(normals)

        out_file = os.path.join(output_dir, f"{_safe_mesh_stem(prompt)}.obj")
        save_paths = export_obj(mesh, out_file)
        # NOTE(review): export_obj returns a list; only the first path is tracked
        # for cleanup, so any companion files it writes are never pruned.
        mesh_path = save_paths[0]
        _track_output_file(mesh_path)

    return mesh_path, mesh_path  # Return the path twice - once for 3D preview, once for download

# Build the Gradio UI. Components created inside this `with` block are attached to
# `demo` in the order they appear, which is also their on-page layout order.
# This runs at import time, so the model download and pipeline load below happen
# before the server starts.
with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 512px !important}") as demo:
    # Download model if needed (no-op when ADAPTER_PATH already exists)
    download_model()
    
    # Initialize pipeline at startup; initialize_pipeline() caches the instance in
    # PIPELINE, so the per-request call in generate_3d_mesh reuses it.
    initialize_pipeline()
    
    # Header / instructions text.
    gr.Markdown(
        """
        # 🌟 Text to 3D Mesh Generation with TriplaneTurbo
        
        Demo of the paper "Progressive Rendering Distillation: Adapting Stable Diffusion for Instant Text-to-Mesh Generation beyond 3D Training Data" [CVPR 2025]
        
        [GitHub Repository](https://github.com/theEricMa/TriplaneTurbo)
        
        ## Instructions
        1. Enter a text prompt describing what 3D object you want to generate
        2. Click "Generate" and wait for the model to create your 3D mesh
        3. View the result in the 3D viewer or download the OBJ file
        """
    )
    
    # Two equal-width columns: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter your text description...",
                value="Armor dress style of outsiderzone fantasy helmet",
                lines=2
            )
            
            generate_btn = gr.Button("Generate", variant="primary")
            
            # Clickable sample prompts that fill the `prompt` textbox.
            # The `examples` binding itself is not referenced again in this file.
            examples = gr.Examples(
                examples=[
                    ["Armor dress style of outsiderzone fantasy helmet"],
                    ["Gandalf the grey riding a camel in a rock concert, victorian newspaper article, hyperrealistic"],
                    ["A DSLR photo of a bald eagle"],
                    ["A goblin riding a lawnmower in a hospital, victorian newspaper article, 4k hd"],
                    ["An imperial stormtrooper, highly detailed"],
                ],
                inputs=[prompt],
                label="Example Prompts"
            )
            
        with gr.Column(scale=1):
            # NOTE(review): camera_position is presumably (alpha, beta, radius)
            # per Gradio's Model3D API, and clear_color an RGBA background — confirm.
            output_model = gr.Model3D(
                label="Generated 3D Mesh",
                camera_position=(90, 90, 3),
                clear_color=(0.5, 0.5, 0.5, 1),
            )
            output_file = gr.File(label="Download OBJ file")
    
    # generate_3d_mesh returns the same OBJ path twice; the two outputs map
    # positionally to the 3D preview and the download widget.
    generate_btn.click(
        fn=generate_3d_mesh,
        inputs=[prompt],
        outputs=[output_model, output_file]
    )
    
    # Footer: model description and known limitations.
    gr.Markdown(
        """
        ## About
        
        This demo uses TriplaneTurbo, which adapts Stable Diffusion for instant text-to-mesh generation.
        The model can generate high-quality 3D meshes from text descriptions without requiring 3D training data.
        
        ### Limitations
        - Generation is deterministic with a fixed seed
        - Complex prompts may produce unpredictable results
        - Generated meshes may require clean-up for professional use
        """
    )

# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()