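"""Gradio demo for "Diffusion Probabilistic Models for 3D Point Cloud Generation" (CVPR 2021).

Loads pre-trained GaussianVAE / FlowVAE checkpoints, samples 3D point clouds,
renders them with Plotly, and offers the result as a downloadable .xyz file.
Upstream code: https://github.com/luost26/diffusion-point-cloud
"""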
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
import tempfile # For creating temporary files for download
import traceback # For detailed error logging
# --- Environment Setup ---
# Suppress TensorFlow oneDNN optimization messages if TensorFlow is inadvertently imported by a dependency
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# Clone the repository only if the directory doesn't exist
if not os.path.exists("diffusion-point-cloud"):
    print("Cloning diffusion-point-cloud repository...")
    os.system("git clone https://github.com/luost26/diffusion-point-cloud")
else:
    print("diffusion-point-cloud repository already exists.")
sys.path.append("diffusion-point-cloud")
# --- Model Imports ---
try:
    from models.vae_gaussian import GaussianVAE
    from models.vae_flow import FlowVAE
except ImportError as e:
    print(f"CRITICAL Error importing models: {e}")
    print("Please ensure the 'diffusion-point-cloud' directory is on sys.path and contains the model definitions.")
    sys.exit(1)
# --- Model Checkpoint Paths and Loading ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE.upper()}")
MODEL_CONFIGS = {
"Airplane": {
"path_function": lambda: hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"),
"expected_model_type": "gaussian",
"default_args": {
'model': "gaussian", # Should match expected_model_type
'latent_dim': 128,
'hyper': None,
'residual': True,
'num_points': 2048, # For sampling
# 'flexibility' will be taken from UI
}
},
"Chair": {
"path_function": lambda: "./GEN_chair.pt",
"expected_model_type": "gaussian", # Assuming Gaussian for chair as well
"default_args": {
'model': "gaussian",
'latent_dim': 128,
'hyper': None,
'residual': True,
'num_points': 2048,
}
}
# To add more models:
# "YourModelName": {
# "path_function": lambda: "path/to/your/model.pt",
# "expected_model_type": "gaussian", # or "flow"
# "default_args": { ... } # Model-specific defaults
# }
}
# Load checkpoints
LOADED_CHECKPOINTS = {}
for model_name, config in MODEL_CONFIGS.items():
model_path = "" # Initialize for error message
try:
model_path = config["path_function"]()
if model_name == "Chair" and not os.path.exists(model_path): # Specific check for local file
print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
LOADED_CHECKPOINTS[model_name] = None
continue
print(f"Loading checkpoint for {model_name} from '{model_path}'...")
LOADED_CHECKPOINTS[model_name] = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=False)
print(f"Successfully loaded {model_name}.")
except Exception as e:
print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
LOADED_CHECKPOINTS[model_name] = None
# --- Helper Functions ---
def seed_all(seed):
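    """Seed torch, numpy, and random so generation is reproducible."""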
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
def normalize_point_clouds(pcs, mode):
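    """Normalize each point cloud in a (B, N, 3) tensor, in place.

    mode='shape_unit' shifts each cloud to zero mean and scales by its std;
    mode='shape_bbox' centers the bounding box and scales its longest side to 2.
    """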
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)
            pc_min, _ = pc.min(dim=0, keepdim=True)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:  # Fallback: identity transform
            shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
            scale = torch.ones_like(pc.flatten().std().reshape(1, 1))
        if scale.abs().item() < 1e-8:  # Prevent division by zero or a very small scale
            scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)
        pcs[i] = (pc - shift) / scale
    return pcs
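# Minimal usage sketch (illustrative only): each cloud in a batch is normalized
# independently, e.g.
#   pcs = normalize_point_clouds(torch.randn(2, 2048, 3), mode="shape_bbox")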
# --- Core Prediction Logic ---
def predict(seed_val, selected_model_name, flexibility_val):
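    """Sample a single point cloud from the selected pre-trained model.

    Returns a (num_points, 3) CPU tensor, normalized with mode='shape_bbox'.
    """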
    seed_all(int(seed_val))
    ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
    if ckpt is None:
        raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")
    model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})
    # --- Argument Handling for Model Instantiation and Sampling ---
    actual_args = None
    # Prefer the args stored in the checkpoint when they look valid
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        actual_args = ckpt['args']
        print(f"Using 'args' found in checkpoint for {selected_model_name}.")
        # Augment with model-specific defaults if attributes are missing from ckpt['args']
        for key, default_value in model_specific_defaults.items():
            if not hasattr(actual_args, key):
                print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
                setattr(actual_args, key, default_value)
    else:
        print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing mock args from defaults.")
        # Fallback: construct args from model_specific_defaults, preferring top-level checkpoint values
        actual_args_dict = {}
        for key, default_value in model_specific_defaults.items():
            # Try the checkpoint top level first, then fall back to the model-specific default
            actual_args_dict[key] = ckpt.get(key, default_value)
        actual_args = type('Args', (), actual_args_dict)()
    # Ensure the attributes essential for model construction and sampling are present.
    # Most were set by the defaults above, but double-check and enforce the critical ones.
    if not hasattr(actual_args, 'model'):  # Critical
        raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
    if not hasattr(actual_args, 'latent_dim'):
        setattr(actual_args, 'latent_dim', 128)  # A common default
    if actual_args.model == 'gaussian':
        if not hasattr(actual_args, 'residual'):
            print("Setting default 'residual=True' for GaussianVAE.")
            setattr(actual_args, 'residual', True)
    elif actual_args.model == 'flow':  # Parameters for FlowVAE
        if not hasattr(actual_args, 'flow_depth'):
            setattr(actual_args, 'flow_depth', 10)
        if not hasattr(actual_args, 'flow_hidden_dim'):
            setattr(actual_args, 'flow_hidden_dim', 256)
    # Sampling parameters
    if not hasattr(actual_args, 'num_points'):
        print("Setting default 'num_points=2048' for sampling.")
        setattr(actual_args, 'num_points', 2048)
    # The flexibility from the UI slider overrides any 'flexibility' stored in args
    setattr(actual_args, 'flexibility', flexibility_val)
    print(f"Using flexibility: {actual_args.flexibility} for sampling.")
    # --- Model Instantiation ---
    model = None
    if actual_args.model == 'gaussian':
        model = GaussianVAE(actual_args).to(DEVICE)
    elif actual_args.model == 'flow':
        model = FlowVAE(actual_args).to(DEVICE)
    else:
        raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    # --- Point Cloud Generation ---
    gen_pcs = []
    with torch.no_grad():
        z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
        x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
        gen_pcs.append(x.detach().cpu())
    gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
    gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
    return gen_pcs_normalized[0]
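# Illustrative call (assuming the Airplane checkpoint loaded successfully):
#   points = predict(777, "Airplane", 0.0)  # -> (2048, 3) CPU tensor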
# --- Gradio Interface Function ---
def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
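    """Gradio callback: returns (Plotly figure, path to a temporary .xyz file or None, error message)."""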
error_message = ""
figure_plot = None
download_file_path = None
try:
if seed is None:
seed = random.randint(0, 2**16 - 1)
seed = int(seed)
if not model_choice:
error_message = "Please choose a model type."
# Return empty plot and no file if model not chosen
return go.Figure(), None, error_message
print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")
points = predict(seed, model_choice, flexibility)
        # Create the Plotly figure
        figure_plot = go.Figure(
            data=[
                go.Scatter3d(
                    x=points[:, 0], y=points[:, 1], z=points[:, 2],
                    mode='markers',
                    marker=dict(size=marker_size, color=point_color_hex)  # Use the hex color directly
                )
            ],
            layout=dict(
                title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
                scene=dict(
                    xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
                    aspectmode='data'
                ),
                margin=dict(l=0, r=0, b=0, t=40)
            )
        )
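        # The .xyz file written below is plain text: one "x y z" line per point.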
        # Prepare the file for download
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
            for point in points:
                tmp_file.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
            download_file_path = tmp_file.name
        print(f"Point cloud saved for download at: {download_file_path}")
    except ValueError as ve:
        error_message = f"Configuration Error: {str(ve)}"
        print(error_message)
    except AttributeError as ae:
        error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint may be missing expected parameters, or they may be incompatible."
        print(error_message)
    except Exception as e:
        error_message = f"An unexpected error occurred: {str(e)}"
        print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")
    # Ensure we always return three values, even on error
    if figure_plot is None:
        figure_plot = go.Figure()  # Empty plot on error
    return figure_plot, download_file_path, error_message
# --- Gradio UI Definition ---
available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
if not available_models:
    print("CRITICAL: No models were loaded successfully. The application may not function as expected.")
markdown_description = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)

This demo allows you to generate 3D point clouds using pre-trained models.

- Adjust the **Seed** for different random initializations.
- Choose a **Model Type** (e.g., Airplane, Chair).
- Control **Sampling Flexibility**: lower values tend towards the mean shape, higher values increase diversity.
- Customize **Point Color** and **Marker Size**.

Running on: **{DEVICE.upper()}**
'''
if "Chair" in MODEL_CONFIGS and "Chair" not in available_models: # Check if Chair was intended but failed to load
markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it's in the root directory if you intend to use it."
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(markdown_description)
    with gr.Row():
        with gr.Column(scale=1):  # Controls column
            model_dropdown = gr.Dropdown(choices=available_models, label="Choose Model Type", value=available_models[0] if available_models else None)
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
            flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)
            with gr.Row():
                color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B")  # Default: red-orange
                marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)
            generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")
        with gr.Column(scale=2):  # Output column
            plot_output = gr.Plot(label="Generated Point Cloud")
            file_download_output = gr.File(label="Download Point Cloud (.xyz)")
            error_display = gr.Markdown("")  # For displaying error messages
    generate_btn.click(
        fn=generate_gradio,
        inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
        outputs=[plot_output, file_download_output, error_display]
    )
    if available_models:
        example_list = [
            [777, available_models[0], 0.0, "#EE4B2B", 2],
            [1234, available_models[0], 0.5, "#1E90FF", 3],  # DodgerBlue
        ]
        if len(available_models) > 1:  # If Chair (or another model) is available
            example_list.append([100, available_models[1], 0.2, "#32CD32", 3])  # LimeGreen; integral size to match the slider's step of 1
        gr.Examples(
            examples=example_list,
            inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
            outputs=[plot_output, file_download_output, error_display],
            fn=generate_gradio,
            cache_examples=False,  # Generation is fast enough; no need to cache potentially large plots
        )
# --- Application Launch ---
if __name__ == "__main__":
    if not available_models:
        print("No models available to run the Gradio demo. Check the checkpoint paths and the errors above.")
        # Optionally, a limited UI that only shows an error could be launched here.
        # For now, just print and let Gradio launch a (possibly empty) UI.
    print("Launching Gradio demo...")
    demo.launch()  # Add share=True for a public link when running locally