import os
import random
import sys
import tempfile
import traceback

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download

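# Disable TensorFlow's oneDNN custom ops (silences the startup notice and avoids
# small numerical differences) in case TensorFlow is imported by a dependency.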
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

if not os.path.exists("diffusion-point-cloud"):
    print("Cloning diffusion-point-cloud repository...")
    os.system("git clone https://github.com/luost26/diffusion-point-cloud")
else:
    print("diffusion-point-cloud repository already exists.")
sys.path.append("diffusion-point-cloud")

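# The VAE definitions live in the cloned repository (models/vae_gaussian.py and
# models/vae_flow.py); importing them only works once the repo is on sys.path.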
try:
    from models.vae_gaussian import GaussianVAE
    from models.vae_flow import FlowVAE
except ImportError as e:
    print(f"CRITICAL Error importing models: {e}")
    print("Please ensure the 'diffusion-point-cloud' directory is on sys.path and contains the model definitions.")
    sys.exit(1)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE.upper()}")

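# Each entry maps a display name to a checkpoint locator plus the hyperparameters
# ("default_args") used to rebuild the model if the checkpoint's own 'args' are missing.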
MODEL_CONFIGS = {
    "Airplane": {
        "path_function": lambda: hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"),
        "expected_model_type": "gaussian",
        "default_args": {
            'model': "gaussian",
            'latent_dim': 128,
            'hyper': None,
            'residual': True,
            'num_points': 2048,
        }
    },
    "Chair": {
        "path_function": lambda: "./GEN_chair.pt",
        "expected_model_type": "gaussian",
        "default_args": {
            'model': "gaussian",
            'latent_dim': 128,
            'hyper': None,
            'residual': True,
            'num_points': 2048,
        }
    }
}

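# Resolve and load every configured checkpoint once at startup; models that fail to
# load are recorded as None and simply left out of the UI choices.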
LOADED_CHECKPOINTS = {}
for model_name, config in MODEL_CONFIGS.items():
    model_path = ""
    try:
        model_path = config["path_function"]()
        if model_name == "Chair" and not os.path.exists(model_path):
            print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
            LOADED_CHECKPOINTS[model_name] = None
            continue
        print(f"Loading checkpoint for {model_name} from '{model_path}'...")
        LOADED_CHECKPOINTS[model_name] = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=False)
        print(f"Successfully loaded {model_name}.")
    except Exception as e:
        print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
        LOADED_CHECKPOINTS[model_name] = None

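# Seed every RNG used below (PyTorch, NumPy, stdlib random) so that a given seed
# reproduces the same generated shape.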
def seed_all(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

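# Normalize each point cloud in the batch: 'shape_unit' centers on the mean and divides
# by the global std; 'shape_bbox' centers on the bounding-box midpoint and scales its
# longest side to 2; any other mode leaves the points unchanged.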
def normalize_point_clouds(pcs, mode):
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)
            pc_min, _ = pc.min(dim=0, keepdim=True)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
            scale = torch.ones_like(pc.flatten().std().reshape(1, 1))

        # Guard against degenerate (collapsed) clouds with near-zero scale.
        if scale.abs().item() < 1e-8:
            scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)

        pcs[i] = (pc - shift) / scale
    return pcs

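# Rebuild the generator described by the checkpoint and sample a single point cloud.
# Usage sketch (assuming the Airplane checkpoint loaded successfully):
#   pts = predict(777, "Airplane", 0.0)   # -> CPU FloatTensor of shape (num_points, 3)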
def predict(seed_val, selected_model_name, flexibility_val):
    seed_all(int(seed_val))

    ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
    if ckpt is None:
        raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")

    model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})

    # Prefer the training-time 'args' stored in the checkpoint; fall back to the
    # per-model defaults for anything that is missing.
    actual_args = None
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        actual_args = ckpt['args']
        print(f"Using 'args' found in checkpoint for {selected_model_name}.")
        for key, default_value in model_specific_defaults.items():
            if not hasattr(actual_args, key):
                print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
                setattr(actual_args, key, default_value)
    else:
        print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing args from defaults.")
        actual_args_dict = {}
        for key, default_value in model_specific_defaults.items():
            actual_args_dict[key] = ckpt.get(key, default_value)
        actual_args = type('Args', (), actual_args_dict)()

    if not hasattr(actual_args, 'model'):
        raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
    if not hasattr(actual_args, 'latent_dim'):
        setattr(actual_args, 'latent_dim', 128)

    if actual_args.model == 'gaussian':
        if not hasattr(actual_args, 'residual'):
            print("Setting default 'residual=True' for GaussianVAE.")
            setattr(actual_args, 'residual', True)
    elif actual_args.model == 'flow':
        if not hasattr(actual_args, 'flow_depth'):
            setattr(actual_args, 'flow_depth', 10)
        if not hasattr(actual_args, 'flow_hidden_dim'):
            setattr(actual_args, 'flow_hidden_dim', 256)

    if not hasattr(actual_args, 'num_points'):
        print("Setting default 'num_points=2048' for sampling.")
        setattr(actual_args, 'num_points', 2048)

    setattr(actual_args, 'flexibility', flexibility_val)
    print(f"Using flexibility: {actual_args.flexibility} for sampling.")

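    # Instantiate the matching VAE variant, restore its weights, and sample one cloud
    # of 'num_points' points from a single latent code.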
    model = None
    if actual_args.model == 'gaussian':
        model = GaussianVAE(actual_args).to(DEVICE)
    elif actual_args.model == 'flow':
        model = FlowVAE(actual_args).to(DEVICE)
    else:
        raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")

    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    gen_pcs = []
    with torch.no_grad():
        z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
        x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
        gen_pcs.append(x.detach().cpu())

    gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
    gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")

    return gen_pcs_normalized[0]

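# Gradio callback: generate a cloud, render it as an interactive Plotly scatter,
# and write an .xyz file for download. Errors are reported as text instead of raising.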
def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
    error_message = ""
    figure_plot = None
    download_file_path = None

    try:
        if seed is None:
            seed = random.randint(0, 2**16 - 1)
        seed = int(seed)

        if not model_choice:
            error_message = "Please choose a model type."
            return go.Figure(), None, error_message

        print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")

        points = predict(seed, model_choice, flexibility)

        figure_plot = go.Figure(
            data=[
                go.Scatter3d(
                    x=points[:, 0], y=points[:, 1], z=points[:, 2],
                    mode='markers',
                    marker=dict(size=marker_size, color=point_color_hex)
                )
            ],
            layout=dict(
                title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
                scene=dict(
                    xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    aspectmode='data'
                ),
                margin=dict(l=0, r=0, b=0, t=40)
            )
        )

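        # Save a plain ASCII .xyz file: one "x y z" triple per line, which most
        # point-cloud viewers (e.g. MeshLab, CloudCompare) can open directly.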
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
            for point in points:
                tmp_file.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
            download_file_path = tmp_file.name
        print(f"Point cloud saved for download at: {download_file_path}")

    except ValueError as ve:
        error_message = f"Configuration Error: {str(ve)}"
        print(error_message)
    except AttributeError as ae:
        error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint might be missing expected parameters, or they may be incompatible."
        print(error_message)
    except Exception as e:
        error_message = f"An unexpected error occurred: {str(e)}"
        print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")

    if figure_plot is None:
        figure_plot = go.Figure()
    return figure_plot, download_file_path, error_message

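# Only models whose checkpoints actually loaded are offered in the dropdown.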
available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
if not available_models:
    print("CRITICAL: No models were loaded successfully. The application may not function as expected.")

markdown_description = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)

This demo generates 3D point clouds with pre-trained models.
- Adjust the **Seed** for different random initializations.
- Choose a **Model Type** (e.g., Airplane, Chair).
- Control **Sampling Flexibility**: lower values tend toward the mean shape; higher values increase diversity.
- Customize **Point Color** and **Marker Size**.

Running on: **{DEVICE.upper()}**
'''
if "Chair" in MODEL_CONFIGS and "Chair" not in available_models:
    markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it is in the root directory if you intend to use it."

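# UI layout: controls (model, seed, flexibility, styling) in the left column;
# plot, download link, and error text in the right column.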
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(markdown_description)

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(choices=available_models, label="Choose Model Type", value=available_models[0] if available_models else None)
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
            flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)

            with gr.Row():
                color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B")
                marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)

            generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Generated Point Cloud")
            file_download_output = gr.File(label="Download Point Cloud (.xyz)")
            error_display = gr.Markdown("")

    generate_btn.click(
        fn=generate_gradio,
        inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
        outputs=[plot_output, file_download_output, error_display]
    )

    if available_models:
        example_list = [
            [777, available_models[0], 0.0, "#EE4B2B", 2],
            [1234, available_models[0], 0.5, "#1E90FF", 3],
        ]
        if len(available_models) > 1:
            example_list.append([100, available_models[1], 0.2, "#32CD32", 3])

        gr.Examples(
            examples=example_list,
            inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
            outputs=[plot_output, file_download_output, error_display],
            fn=generate_gradio,
            cache_examples=False,
        )

if __name__ == "__main__":
    if not available_models:
        print("No models available to run the Gradio demo. You might want to check checkpoint paths and errors above.")

    print("Launching Gradio demo...")
    demo.launch()