ZhiyuanthePony commited on
Commit
7655b4c
·
1 Parent(s): 0d33a1d
Files changed (1) hide show
  1. example.py +85 -60
example.py CHANGED
@@ -1,7 +1,10 @@
1
  try:
2
  import spaces
3
- except:
4
- pass
 
 
 
5
 
6
  import os
7
  import torch
@@ -39,66 +42,88 @@ if not os.path.exists(adapter_name_or_path):
39
  triplane_turbo_pipeline = TriplaneTurboTextTo3DPipeline.from_pretrained(adapter_name_or_path)
40
  triplane_turbo_pipeline.to(device)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Run the pipeline
43
- output = triplane_turbo_pipeline(
44
  prompt=prompt,
45
  num_results_per_prompt=num_results_per_prompt,
46
- generator=torch.Generator(device=device).manual_seed(seed),
47
- device=device,
48
  )
49
 
50
- # Initialize a deque with maximum length of 100 to store obj file paths
51
- obj_file_queue = deque(maxlen=max_obj_files)
52
-
53
- # Save mesh
54
- os.makedirs(output_dir, exist_ok=True)
55
- for i, mesh in enumerate(output["mesh"]):
56
- vertices = mesh.v_pos
57
-
58
- # 1. First rotate -90 degrees around X-axis to make the model face up
59
- vertices = torch.stack([
60
- vertices[:, 0], # x remains unchanged
61
- vertices[:, 2], # y = z
62
- -vertices[:, 1] # z = -y
63
- ], dim=1)
64
-
65
- # 2. Then rotate 90 degrees around Y-axis to make the model face the observer
66
- vertices = torch.stack([
67
- -vertices[:, 2], # x = -z
68
- vertices[:, 1], # y remains unchanged
69
- vertices[:, 0] # z = x
70
- ], dim=1)
71
-
72
- mesh.v_pos = vertices
73
-
74
- # If mesh has normals, they need to be rotated in the same way
75
- if mesh.v_nrm is not None:
76
- normals = mesh.v_nrm
77
- # 1. Rotate -90 degrees around X-axis
78
- normals = torch.stack([
79
- normals[:, 0],
80
- normals[:, 2],
81
- -normals[:, 1]
82
- ], dim=1)
83
- # 2. Rotate 90 degrees around Y-axis
84
- normals = torch.stack([
85
- -normals[:, 2],
86
- normals[:, 1],
87
- normals[:, 0]
88
- ], dim=1)
89
- mesh._v_nrm = normals
90
-
91
- # Save obj file and add its path to the queue
92
- name = f"{prompt.replace(' ', '_')}_{seed}_{i}"
93
- save_paths = export_obj(mesh, f"{output_dir}/{name}.obj")
94
- obj_file_queue.append(save_paths[0])
95
-
96
- # If an old file needs to be removed (queue is at max length)
97
- # and the file exists, delete it
98
- if len(obj_file_queue) == max_obj_files and os.path.exists(obj_file_queue[0]):
99
- old_file = obj_file_queue[0]
100
- try:
101
- os.remove(old_file)
102
- except OSError as e:
103
- print(f"Error deleting file {old_file}: {e}")
104
-
 
1
  try:
2
  import spaces
3
+ except ImportError:
4
+ # Define a dummy decorator if spaces is not available
5
+ def GPU(func):
6
+ return func
7
+ spaces = type('spaces', (), {'GPU': GPU})
8
 
9
  import os
10
  import torch
 
42
  triplane_turbo_pipeline = TriplaneTurboTextTo3DPipeline.from_pretrained(adapter_name_or_path)
43
  triplane_turbo_pipeline.to(device)
44
 
45
+ @spaces.GPU
46
+ def generate_3d_model(prompt, num_results_per_prompt=1, seed=42, device="cuda"):
47
+ """
48
+ Generate 3D models using TriplaneTurbo pipeline.
49
+
50
+ Args:
51
+ prompt (str): Text prompt for the 3D model
52
+ num_results_per_prompt (int): Number of results to generate
53
+ seed (int): Random seed for generation
54
+ device (str): Device to use for computation
55
+
56
+ Returns:
57
+ dict: Output from the pipeline
58
+ """
59
+ output = triplane_turbo_pipeline(
60
+ prompt=prompt,
61
+ num_results_per_prompt=num_results_per_prompt,
62
+ generator=torch.Generator(device=device).manual_seed(seed),
63
+ device=device,
64
+ )
65
+ # Initialize a deque with maximum length of 100 to store obj file paths
66
+ obj_file_queue = deque(maxlen=max_obj_files)
67
+
68
+ # Save mesh
69
+ os.makedirs(output_dir, exist_ok=True)
70
+ for i, mesh in enumerate(output["mesh"]):
71
+ vertices = mesh.v_pos
72
+
73
+ # 1. First rotate -90 degrees around X-axis to make the model face up
74
+ vertices = torch.stack([
75
+ vertices[:, 0], # x remains unchanged
76
+ vertices[:, 2], # y = z
77
+ -vertices[:, 1] # z = -y
78
+ ], dim=1)
79
+
80
+ # 2. Then rotate 90 degrees around Y-axis to make the model face the observer
81
+ vertices = torch.stack([
82
+ -vertices[:, 2], # x = -z
83
+ vertices[:, 1], # y remains unchanged
84
+ vertices[:, 0] # z = x
85
+ ], dim=1)
86
+
87
+ mesh.v_pos = vertices
88
+
89
+ # If mesh has normals, they need to be rotated in the same way
90
+ if mesh.v_nrm is not None:
91
+ normals = mesh.v_nrm
92
+ # 1. Rotate -90 degrees around X-axis
93
+ normals = torch.stack([
94
+ normals[:, 0],
95
+ normals[:, 2],
96
+ -normals[:, 1]
97
+ ], dim=1)
98
+ # 2. Rotate 90 degrees around Y-axis
99
+ normals = torch.stack([
100
+ -normals[:, 2],
101
+ normals[:, 1],
102
+ normals[:, 0]
103
+ ], dim=1)
104
+ mesh._v_nrm = normals
105
+
106
+ # Save obj file and add its path to the queue
107
+ name = f"{prompt.replace(' ', '_')}_{seed}_{i}"
108
+ save_paths = export_obj(mesh, f"{output_dir}/{name}.obj")
109
+ obj_file_queue.append(save_paths[0])
110
+
111
+ # If an old file needs to be removed (queue is at max length)
112
+ # and the file exists, delete it
113
+ if len(obj_file_queue) == max_obj_files and os.path.exists(obj_file_queue[0]):
114
+ old_file = obj_file_queue[0]
115
+ try:
116
+ os.remove(old_file)
117
+ except OSError as e:
118
+ print(f"Error deleting file {old_file}: {e}")
119
+
120
+
121
+
122
  # Run the pipeline
123
+ output = generate_3d_model(
124
  prompt=prompt,
125
  num_results_per_prompt=num_results_per_prompt,
126
+ seed=seed,
127
+ device=device
128
  )
129