diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..df3e3cc04fceaa5632ea1452f4b986a2466c426b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/chair_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/chair_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/chair_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_5.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_6.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_7.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_8.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_5.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_6.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_7.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_8.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/puppet_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/puppet_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/puppet_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/robot_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/robot_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/toolcar_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/toolcar_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/toolcar_3.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index fb1c6a7c37564ac8e83a0fb29545ee5567957507..54d013082a917e21118c23e1f4ae370463b26ed1 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,16 @@ --- title: ReconViaGen -emoji: 💻 -colorFrom: green -colorTo: purple +emoji: 🖥️ +colorFrom: indigo +colorTo: blue sdk: gradio -sdk_version: 5.44.1 +sdk_version: 5.34.2 app_file: app.py pinned: false -license: apache-2.0 +license: mit +short_description: High-fidelity 3D Geometry 
Generation from single view image --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +Project Page: https://jiahao620.github.io/reconviagen/ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fa0aec5caf0bcd8e573d326d6fe5ebdd09cc9f --- /dev/null +++ b/app.py @@ -0,0 +1,392 @@ +import gradio as gr +import spaces +from gradio_litmodel3d import LitModel3D + +import os +import shutil +os.environ['SPCONV_ALGO'] = 'native' +from typing import * +import torch +import numpy as np +import imageio +from easydict import EasyDict as edict +from PIL import Image +from trellis.pipelines import TrellisVGGTTo3DPipeline +from trellis.representations import Gaussian, MeshExtractResult +from trellis.utils import render_utils, postprocessing_utils + + + +MAX_SEED = np.iinfo(np.int32).max +# TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +TMP_DIR = "tmp/Trellis-demo" +os.environ['GRADIO_TEMP_DIR'] = 'tmp' +os.makedirs(TMP_DIR, exist_ok=True) + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shutil.rmtree(user_dir) +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image for 3D generation. + + This function is called when a user uploads an image or selects an example. + It applies background removal and other preprocessing steps necessary for + optimal 3D model generation. + + Args: + image (Image.Image): The input image from the user + + Returns: + Image.Image: The preprocessed image ready for 3D generation + """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]: + """ + Preprocess a list of input images for multi-image 3D generation. + + This function is called when users upload multiple images in the gallery. + It processes each image to prepare them for the multi-image 3D generation pipeline. 
+ + Args: + images (List[Tuple[Image.Image, str]]): The input images from the gallery + + Returns: + List[Image.Image]: The preprocessed images ready for 3D generation + """ + images = [image[0] for image in images] + processed_images = [pipeline.preprocess_image(image) for image in images] + return processed_images + + +def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: + return { + 'gaussian': { + **gs.init_params, + '_xyz': gs._xyz.cpu().numpy(), + '_features_dc': gs._features_dc.cpu().numpy(), + '_scaling': gs._scaling.cpu().numpy(), + '_rotation': gs._rotation.cpu().numpy(), + '_opacity': gs._opacity.cpu().numpy(), + }, + 'mesh': { + 'vertices': mesh.vertices.cpu().numpy(), + 'faces': mesh.faces.cpu().numpy(), + }, + } + + +def unpack_state(state: dict) -> Tuple[Gaussian, edict]: + gs = Gaussian( + aabb=state['gaussian']['aabb'], + sh_degree=state['gaussian']['sh_degree'], + mininum_kernel_size=state['gaussian']['mininum_kernel_size'], + scaling_bias=state['gaussian']['scaling_bias'], + opacity_bias=state['gaussian']['opacity_bias'], + scaling_activation=state['gaussian']['scaling_activation'], + ) + gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda') + gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda') + gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda') + gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda') + gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda') + + mesh = edict( + vertices=torch.tensor(state['mesh']['vertices'], device='cuda'), + faces=torch.tensor(state['mesh']['faces'], device='cuda'), + ) + + return gs, mesh + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed for generation. + + This function is called by the generate button to determine whether to use + a random seed or the user-specified seed value. + + Args: + randomize_seed (bool): Whether to generate a random seed + seed (int): The user-specified seed value + + Returns: + int: The seed to use for generation + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +@spaces.GPU(duration=120) +def generate_and_extract_glb( + multiimages: List[Tuple[Image.Image, str]], + seed: int, + ss_guidance_strength: float, + ss_sampling_steps: int, + slat_guidance_strength: float, + slat_sampling_steps: int, + multiimage_algo: Literal["multidiffusion", "stochastic"], + mesh_simplify: float, + texture_size: int, + req: gr.Request, +) -> Tuple[dict, str, str, str]: + """ + Convert the input images to a 3D model and extract a GLB file. + + Args: + multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode. + seed (int): The random seed. + ss_guidance_strength (float): The guidance strength for sparse structure generation. + ss_sampling_steps (int): The number of sampling steps for sparse structure generation. + slat_guidance_strength (float): The guidance strength for structured latent generation. + slat_sampling_steps (int): The number of sampling steps for structured latent generation. + multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation. + mesh_simplify (float): The mesh simplification factor. + texture_size (int): The texture resolution. + req (gr.Request): Gradio request object for session management. + + Returns: + dict: The information of the generated 3D model. + str: The path to the video of the 3D model. + str: The path to the extracted GLB file.
+ str: The path to the extracted GLB file (for download). + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + image_files = [image[0] for image in multiimages] + + # Generate 3D model + outputs = pipeline.run( + image=image_files, + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + mode=multiimage_algo, + ) + + # Render video + video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color'] + video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal'] + video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] + video_path = os.path.join(user_dir, 'sample.mp4') + imageio.mimsave(video_path, video, fps=15) + + # Extract GLB + gs = outputs['gaussian'][0] + mesh = outputs['mesh'][0] + glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) + glb_path = os.path.join(user_dir, 'sample.glb') + glb.export(glb_path) + + # Pack state for optional Gaussian extraction + state = pack_state(gs, mesh) + + torch.cuda.empty_cache() + return state, video_path, glb_path, glb_path + + +@spaces.GPU +def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]: + """ + Extract a Gaussian splatting file from the generated 3D model. + + This function is called when the user clicks "Extract Gaussian" button. + It converts the 3D model state into a .ply file format containing + Gaussian splatting data for advanced 3D applications. + + Args: + state (dict): The state of the generated 3D model containing Gaussian data + req (gr.Request): Gradio request object for session management + + Returns: + Tuple[str, str]: Paths to the extracted Gaussian file (for display and download) + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, _ = unpack_state(state) + gaussian_path = os.path.join(user_dir, 'sample.ply') + gs.save_ply(gaussian_path) + torch.cuda.empty_cache() + return gaussian_path, gaussian_path + + +def prepare_multi_example() -> List[Image.Image]: + multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")])) + images = [] + for case in multi_case: + _images = [] + for i in range(1, 9): + if os.path.exists(f'assets/example_multi_image/{case}_{i}.png'): + img = Image.open(f'assets/example_multi_image/{case}_{i}.png') + W, H = img.size + img = img.resize((int(W / H * 512), 512)) + _images.append(np.array(img)) + if len(_images) > 0: + images.append(Image.fromarray(np.concatenate(_images, axis=1))) + return images + + +def split_image(image: Image.Image) -> List[Image.Image]: + """ + Split a multi-view image into separate view images. + + This function is called when users select multi-image examples that contain + multiple views in a single concatenated image. It automatically splits them + based on alpha channel boundaries and preprocesses each view. 
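The splitting described here (implemented just below) reduces to one NumPy idiom: collapse the alpha channel to a per-column occupancy mask, then take False-to-True transitions as view starts and True-to-False transitions as view ends. A toy, self-contained illustration of that idiom on a 1-D mask (not part of the app, for clarity only):

```python
import numpy as np

# Toy per-column occupancy mask: True where any pixel in the column is opaque.
alpha = np.array([False, True, True, False, False, True, True, True, False])

# A view starts where the mask goes False -> True and ends where it goes True -> False.
starts = np.where(~alpha[:-1] & alpha[1:])[0].tolist()  # indices just before each rise
ends = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()    # indices of the last opaque column

print(list(zip(starts, ends)))  # [(0, 2), (4, 7)] -> column ranges of the two "views"
```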
+ + Args: + image (Image.Image): A concatenated image containing multiple views + + Returns: + List[Image.Image]: List of individual preprocessed view images + """ + image = np.array(image) + alpha = image[..., 3] + alpha = np.any(alpha>0, axis=0) + start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist() + end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist() + images = [] + for s, e in zip(start_pos, end_pos): + images.append(Image.fromarray(image[:, s:e+1])) + return [preprocess_image(image) for image in images] + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Multi-view images to 3D Asset with [ReconViaGen](https://jiahao620.github.io/reconviagen/) + * Upload an image and click "Generate & Extract GLB" to create a 3D asset and automatically extract the GLB file. + * If you want the Gaussian file as well, click "Extract Gaussian" after generation. + * If the image has alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background. + + ✨This demo is partial. We will release the whole model later. Stay tuned!✨ + """) + + with gr.Row(): + with gr.Column(): + with gr.Tabs() as input_tabs: + with gr.Tab(label="Multiple Images", id=0) as multiimage_input_tab: + image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300) + multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3) + gr.Markdown(""" + Input different views of the object in separate images. + + *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.* + """) + + with gr.Accordion(label="Generation Settings", open=False): + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=False) + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1) + gr.Markdown("Stage 2: Structured Latent Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1) + slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion") + + with gr.Accordion(label="GLB Extraction Settings", open=False): + mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01) + texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512) + + generate_btn = gr.Button("Generate & Extract GLB", variant="primary") + extract_gs_btn = gr.Button("Extract Gaussian", interactive=False) + gr.Markdown(""" + *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.* + """) + + with gr.Column(): + video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) + model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300) + + with gr.Row(): + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False) + + output_buf = gr.State() + + # Example images at the bottom of the page + with gr.Row() as multiimage_example: + 
examples_multi = gr.Examples( + examples=prepare_multi_example(), + inputs=[image_prompt], + fn=split_image, + outputs=[multiimage_prompt], + run_on_click=True, + examples_per_page=8, + ) + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + + multiimage_prompt.upload( + preprocess_images, + inputs=[multiimage_prompt], + outputs=[multiimage_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + lambda: [None, None, None, None], # clear video_output first + inputs=[], + outputs=[video_output, model_output, download_glb, download_gs], + ).then( + generate_and_extract_glb, + inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size], + outputs=[output_buf, video_output, model_output, download_glb], + ).then( + lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), + outputs=[extract_gs_btn, download_glb], + ) + + video_output.clear( + lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]), + outputs=[extract_gs_btn, download_glb, download_gs], + ) + + extract_gs_btn.click( + extract_gaussian, + inputs=[output_buf], + outputs=[model_output, download_gs], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_gs], + ) + + model_output.clear( + lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]), + outputs=[download_glb, download_gs], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = TrellisVGGTTo3DPipeline.from_pretrained("weights/trellis-vggt-v0-1") + # pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-1") + pipeline.cuda() + pipeline.VGGT_model.cuda() + try: + pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg + except Exception: + pass + demo.launch() \ No newline at end of file diff --git a/assets/example_multi_image/SpongeBob_1.png b/assets/example_multi_image/SpongeBob_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8ed1ce6de4d21413e83b55f6e54901020299d7cf --- /dev/null +++ b/assets/example_multi_image/SpongeBob_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a022951d5eb80145eb7a523786c217e680b40b91e2aae1b8369aff35d849da55 +size 274212 diff --git a/assets/example_multi_image/SpongeBob_2.png b/assets/example_multi_image/SpongeBob_2.png new file mode 100644 index 0000000000000000000000000000000000000000..20dd4b4b8cca0e46dc6faab9a27491c63ae048d5 --- /dev/null +++ b/assets/example_multi_image/SpongeBob_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ea5243c1decd64fca9db076de0857a90eb173c4deb6ff5feb7af0412f99d0c4 +size 238527 diff --git a/assets/example_multi_image/SpongeBob_3.png b/assets/example_multi_image/SpongeBob_3.png new file mode 100644 index 0000000000000000000000000000000000000000..27ef26203f222703218a4b5851ab57ac877b5a51 --- /dev/null +++ b/assets/example_multi_image/SpongeBob_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bdb9dc4a6215ba93b3a489ac05c4bd0c713578690e31f9c6c7bf9a801e35160 +size 149425 diff --git a/assets/example_multi_image/SpongeBob_4.png b/assets/example_multi_image/SpongeBob_4.png new file mode 100644 index 0000000000000000000000000000000000000000..e9c809ed3fa1a7fd60842f3cd902e384e41b35b0 --- /dev/null +++ b/assets/example_multi_image/SpongeBob_4.png @@ -0,0 +1,3 @@ +version
https://git-lfs.github.com/spec/v1 +oid sha256:9274122325c00f390cf7ed91fee5051774791a91718c932b87e8bcf4518d262a +size 183839 diff --git a/assets/example_multi_image/chair_1.png b/assets/example_multi_image/chair_1.png new file mode 100644 index 0000000000000000000000000000000000000000..df43157c1321c911b9ec6440743c45a8ba620075 --- /dev/null +++ b/assets/example_multi_image/chair_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e60f01e62be25418ce96581b4ed2268a011b32f7c6d5409697a3f297f95fea4c +size 170588 diff --git a/assets/example_multi_image/chair_2.png b/assets/example_multi_image/chair_2.png new file mode 100644 index 0000000000000000000000000000000000000000..4249d4ea34703d4bd8c73461e2eb980f1cbf23f9 --- /dev/null +++ b/assets/example_multi_image/chair_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ac39f31bb7f0173fc545796f280bd449f6a67a017966d701067b8faf26060aa +size 154897 diff --git a/assets/example_multi_image/chair_3.png b/assets/example_multi_image/chair_3.png new file mode 100644 index 0000000000000000000000000000000000000000..a5e8653b57aaa63cdff2cb517118701e25a1b1fa --- /dev/null +++ b/assets/example_multi_image/chair_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b3cc4debfba605ba1eaf540d9e7b3d77248e42e84c4301da6685ce9248db1ee +size 147707 diff --git a/assets/example_multi_image/flower_1.png b/assets/example_multi_image/flower_1.png new file mode 100644 index 0000000000000000000000000000000000000000..856bd9a6910829654401fd273ec6ae2c32c8fb12 --- /dev/null +++ b/assets/example_multi_image/flower_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3888a244c5e11b7d31db48ecef13835436b55f7f89a4a335bd3c92d411e19dd7 +size 159334 diff --git a/assets/example_multi_image/flower_2.png b/assets/example_multi_image/flower_2.png new file mode 100644 index 0000000000000000000000000000000000000000..952cfa41df43b487c0d4547bd659d73642c924df --- /dev/null +++ b/assets/example_multi_image/flower_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0e20b8cba9027f725b6d7d8893a9d8b6bb2ee1dc6d7613992cafb4642b3fe34 +size 160592 diff --git a/assets/example_multi_image/flower_3.png b/assets/example_multi_image/flower_3.png new file mode 100644 index 0000000000000000000000000000000000000000..bc1b89402e583127cb9c32c2f5624e0b2718ca3e --- /dev/null +++ b/assets/example_multi_image/flower_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b93c7ff6093cfc5a41a37b9956211c50406652dd754178f99fd3b76b9eb2b5f0 +size 156555 diff --git a/assets/example_multi_image/flower_4.png b/assets/example_multi_image/flower_4.png new file mode 100644 index 0000000000000000000000000000000000000000..823ecede6dbb126e28589c19f2d782777e4e5d77 --- /dev/null +++ b/assets/example_multi_image/flower_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f8230cabcdc9d189307d22b8b999ba3258fecf32f2406ec0f2d1fe22d527071 +size 157859 diff --git a/assets/example_multi_image/flower_5.png b/assets/example_multi_image/flower_5.png new file mode 100644 index 0000000000000000000000000000000000000000..79d6c67a9460280b0f4d9fd12d09351356c2b086 --- /dev/null +++ b/assets/example_multi_image/flower_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:288268846343ba9b14fa5645c8f27d0f6a481645437355e26ae05f58a3f86826 +size 156809 diff --git a/assets/example_multi_image/flower_6.png b/assets/example_multi_image/flower_6.png new file mode 100644 
index 0000000000000000000000000000000000000000..cc137b74cd0f079de9abd061494bfe1b78d7f01d --- /dev/null +++ b/assets/example_multi_image/flower_6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fca1cd7c68af1b35cb70ac775f37fd8dd65883e32a5f1cf4b2d3dbbe1b68fcb4 +size 155963 diff --git a/assets/example_multi_image/flower_7.png b/assets/example_multi_image/flower_7.png new file mode 100644 index 0000000000000000000000000000000000000000..f6832226e9eef72e473fea68c11ffaf50104fcd6 --- /dev/null +++ b/assets/example_multi_image/flower_7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52e025b07dc2ee6e55c154d3c1cee598d19fc47d66d6f1954230f0622813d1d1 +size 158364 diff --git a/assets/example_multi_image/flower_8.png b/assets/example_multi_image/flower_8.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a6c9e9249383e9a37ab19a9480fc8a9147a4fe --- /dev/null +++ b/assets/example_multi_image/flower_8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3c2c05502cc77cb59704cbeaf63d30da5492f08ce24bbc9b1a527cb057a9841 +size 158819 diff --git a/assets/example_multi_image/monkey_1.png b/assets/example_multi_image/monkey_1.png new file mode 100644 index 0000000000000000000000000000000000000000..949bbb3d53892780385a2166a1992237b3929fc4 --- /dev/null +++ b/assets/example_multi_image/monkey_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b3e610685ddaa1375af40b96f97dbfa4f4f53be6756c9529e91dab5ae7d7292 +size 123008 diff --git a/assets/example_multi_image/monkey_2.png b/assets/example_multi_image/monkey_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d448ac18d76a4a944057ee2c91dc9e7a19618a20 --- /dev/null +++ b/assets/example_multi_image/monkey_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0386797982ec0da142abc85aa66a3f6b65f779540a07b27d25e7ffc64ebb9c7 +size 128975 diff --git a/assets/example_multi_image/monkey_3.png b/assets/example_multi_image/monkey_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6b6fdc13b2200db42fc9018be0b66ff7486a3aec --- /dev/null +++ b/assets/example_multi_image/monkey_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95aa11f4e5616c2e0d6eca7331e0ddc6c6d9a1fa9c18ea186abb73e66b78d8a8 +size 136403 diff --git a/assets/example_multi_image/monkey_4.png b/assets/example_multi_image/monkey_4.png new file mode 100644 index 0000000000000000000000000000000000000000..2d81654a572eb16b50b3903d0e62b9fc7ee7848e --- /dev/null +++ b/assets/example_multi_image/monkey_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a13a9537461baf5ff842bac59f5b3a27a8d77340b2be2f527dc8302c502fedcb +size 114632 diff --git a/assets/example_multi_image/paopao_1.png b/assets/example_multi_image/paopao_1.png new file mode 100644 index 0000000000000000000000000000000000000000..13f9c324e59f7a5b4d41d63ea6dcf520092ac7e0 --- /dev/null +++ b/assets/example_multi_image/paopao_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6765b458154ebd0e84af5e3b281b07418afc617b64b02fa6f9d041a4f1810630 +size 127093 diff --git a/assets/example_multi_image/paopao_2.png b/assets/example_multi_image/paopao_2.png new file mode 100644 index 0000000000000000000000000000000000000000..0bb69be1e099c852adfa9535ed8677b435f47c68 --- /dev/null +++ b/assets/example_multi_image/paopao_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:c1076a73d7f83225e24fccb4f33805012f1e4229afb4cece06f5eff8a88d8986 +size 119538 diff --git a/assets/example_multi_image/paopao_3.png b/assets/example_multi_image/paopao_3.png new file mode 100644 index 0000000000000000000000000000000000000000..14bef0412f58a3f40160a3ddf0d524bd690a3cd0 --- /dev/null +++ b/assets/example_multi_image/paopao_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a2b7ee735c94f1cc6102db02eb08bf6516ba5d7e30cd4b0f0a5437a247f6343 +size 127161 diff --git a/assets/example_multi_image/paopao_4.png b/assets/example_multi_image/paopao_4.png new file mode 100644 index 0000000000000000000000000000000000000000..25e92cee83d84b6800278d860c8828413d307ae4 --- /dev/null +++ b/assets/example_multi_image/paopao_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:359d7b28b03f309b2b602c2f2003f57b71e239a90fbc8971c99b52c93f5ea05e +size 121718 diff --git a/assets/example_multi_image/paopao_5.png b/assets/example_multi_image/paopao_5.png new file mode 100644 index 0000000000000000000000000000000000000000..e200082edbb6f3297eee1d4ecabb9c79ba8a898b --- /dev/null +++ b/assets/example_multi_image/paopao_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28d360a528c1e9661245eea3c14c2c3111dcb26b8dc13f7efaed8ecedc451ab3 +size 118950 diff --git a/assets/example_multi_image/paopao_6.png b/assets/example_multi_image/paopao_6.png new file mode 100644 index 0000000000000000000000000000000000000000..6241bf4e62fec629f3c96bccea5b5af7a0100b5b --- /dev/null +++ b/assets/example_multi_image/paopao_6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65d356f07f6e57b44604768e6ffae771af98ea7e2b553839500f37d8119f2655 +size 123623 diff --git a/assets/example_multi_image/paopao_7.png b/assets/example_multi_image/paopao_7.png new file mode 100644 index 0000000000000000000000000000000000000000..7b1b219f5791792b84285b621efa3d7c7a418129 --- /dev/null +++ b/assets/example_multi_image/paopao_7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6859331313770c2174b361291d8a79523a11a9acb9c456617855a8b1f2ef036b +size 119381 diff --git a/assets/example_multi_image/paopao_8.png b/assets/example_multi_image/paopao_8.png new file mode 100644 index 0000000000000000000000000000000000000000..a84325d1383df1f0e3a4fb45e829214d08994722 --- /dev/null +++ b/assets/example_multi_image/paopao_8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df1a323c8f879df98594d3c47a2eb5991c2f0415f73c4d083b80261d090f4d8d +size 123132 diff --git a/assets/example_multi_image/puppet_1.png b/assets/example_multi_image/puppet_1.png new file mode 100644 index 0000000000000000000000000000000000000000..40d93f60fe61967f05f8e27500aae5c6015a77c0 --- /dev/null +++ b/assets/example_multi_image/puppet_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f8c870d6abb33eb6de29e4adba0afb06415d487ea17ee92d761733cb4dddeed +size 181740 diff --git a/assets/example_multi_image/puppet_2.png b/assets/example_multi_image/puppet_2.png new file mode 100644 index 0000000000000000000000000000000000000000..ab5b6d59f2cd19c526329068e2c9dbc1331ec302 --- /dev/null +++ b/assets/example_multi_image/puppet_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29ab76e25445690226148f6b71f1710d217dbfe7cdb5b11d60b4193dba45158d +size 188178 diff --git a/assets/example_multi_image/puppet_3.png b/assets/example_multi_image/puppet_3.png new file mode 100644 index 
0000000000000000000000000000000000000000..b3505a275e45298dfcd092378e9f07a56ef943ab --- /dev/null +++ b/assets/example_multi_image/puppet_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b0aa58ac7dadcd4fb15d3ddc80d8c07ab0572018e6af12ffb417b308c868319 +size 239733 diff --git a/assets/example_multi_image/robot_1.png b/assets/example_multi_image/robot_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c2f09936271b03ac693b03eeb64ac6f6537d6f0a --- /dev/null +++ b/assets/example_multi_image/robot_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0da58dacc1fa327e2a084661bcb00328dba5febc4c8a5c8592d63a7c59f1925c +size 169163 diff --git a/assets/example_multi_image/robot_2.png b/assets/example_multi_image/robot_2.png new file mode 100644 index 0000000000000000000000000000000000000000..9a9f79c8f46f257ad750b4689344054dc35d37c7 --- /dev/null +++ b/assets/example_multi_image/robot_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98becf95105550816b2bd95ada56219655895bf2aa60b373068f3df164fe6a15 +size 225490 diff --git a/assets/example_multi_image/toolcar_1.png b/assets/example_multi_image/toolcar_1.png new file mode 100644 index 0000000000000000000000000000000000000000..6a6889fbae69b6a3b34fa66479e5648fdfb72d04 --- /dev/null +++ b/assets/example_multi_image/toolcar_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b48779b49b026b33742e1a84c51c0ef1c299ed0dd3bb215736fa5d4380109cda +size 167157 diff --git a/assets/example_multi_image/toolcar_2.png b/assets/example_multi_image/toolcar_2.png new file mode 100644 index 0000000000000000000000000000000000000000..57aae25e27bd76eedee4e7ea84afc6222fd79c1a --- /dev/null +++ b/assets/example_multi_image/toolcar_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e04e2715577bf790b131491c3b70b9d658ae10ba3980419c6cee356996cd5db +size 150647 diff --git a/assets/example_multi_image/toolcar_3.png b/assets/example_multi_image/toolcar_3.png new file mode 100644 index 0000000000000000000000000000000000000000..7fa61302c70922ca1f10e6a7ab45848f48fdf62d --- /dev/null +++ b/assets/example_multi_image/toolcar_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8eb50ffe75dca646a1578d00bfc3003e7d5bdf506fcb4f7840f2436c4e08d6d1 +size 146798 diff --git a/extensions/nvdiffrast/LICENSE.txt b/extensions/nvdiffrast/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..26a070a431ce5bb4e926e1289f508f003a4ec730 --- /dev/null +++ b/extensions/nvdiffrast/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2020, NVIDIA Corporation. All rights reserved. + + +Nvidia Source Code License (1-Way Commercial) + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. 
+ +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. The Work or + derivative works thereof may be used or intended for use by Nvidia + or its affiliates commercially or non-commercially. As used herein, + "non-commercially" means for research or evaluation purposes only + and not for any direct or indirect monetary gain. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor's or its affiliates' names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. 
+ +======================================================================= diff --git a/extensions/nvdiffrast/README.md b/extensions/nvdiffrast/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3eeb4115c839a7703c5cac22fe6e89828ad29f2c --- /dev/null +++ b/extensions/nvdiffrast/README.md @@ -0,0 +1,42 @@ +## Nvdiffrast – Modular Primitives for High-Performance Differentiable Rendering + +![Teaser image](./docs/img/teaser.png) + +**Modular Primitives for High-Performance Differentiable Rendering**
+Samuli Laine, Janne Hellsten, Tero Karras, Yeongho Seol, Jaakko Lehtinen, Timo Aila<br/>
+[http://arxiv.org/abs/2011.03277](http://arxiv.org/abs/2011.03277) + +Nvdiffrast is a PyTorch/TensorFlow library that provides high-performance primitive operations for rasterization-based differentiable rendering. +Please refer to ☞☞ [nvdiffrast documentation](https://nvlabs.github.io/nvdiffrast) ☜☜ for more information. + +## Licenses + +Copyright © 2020–2024, NVIDIA Corporation. All rights reserved. + +This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt). + +For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/) + +We do not currently accept outside code contributions in the form of pull requests. + +Environment map stored as part of `samples/data/envphong.npz` is derived from a Wave Engine +[sample material](https://github.com/WaveEngine/Samples-2.5/tree/master/Materials/EnvironmentMap/Content/Assets/CubeMap.cubemap) +originally shared under +[MIT License](https://github.com/WaveEngine/Samples-2.5/blob/master/LICENSE.md). +Mesh and texture stored as part of `samples/data/earth.npz` are derived from +[3D Earth Photorealistic 2K](https://www.turbosquid.com/3d-models/3d-realistic-earth-photorealistic-2k-1279125) +model originally made available under +[TurboSquid 3D Model License](https://blog.turbosquid.com/turbosquid-3d-model-license/#3d-model-license). + +## Citation + +``` +@article{Laine2020diffrast, + title = {Modular Primitives for High-Performance Differentiable Rendering}, + author = {Samuli Laine and Janne Hellsten and Tero Karras and Yeongho Seol and Jaakko Lehtinen and Timo Aila}, + journal = {ACM Transactions on Graphics}, + year = {2020}, + volume = {39}, + number = {6} +} +``` diff --git a/extensions/nvdiffrast/nvdiffrast/__init__.py b/extensions/nvdiffrast/nvdiffrast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd28a0879ef844ef791dca19abdc8416c2468e58 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +__version__ = '0.3.3' diff --git a/extensions/nvdiffrast/nvdiffrast/common/antialias.cu b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu new file mode 100644 index 0000000000000000000000000000000000000000..95cc3bab582661a7deb6064daa616adf7121ea36 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu @@ -0,0 +1,558 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "antialias.h" + +//------------------------------------------------------------------------ +// Helpers. 
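The same_sign / rational_gt helpers defined just below compare two ratios n0/d0 and n1/d1 without performing a division: rational_gt tests n0*d1 > n1*d0 and then corrects for the sign of the denominators, and max_idx3 builds on it to pick the largest of three ratios. A quick Python illustration of the trick, as a hedged sketch only (copysign stands in for the CUDA bit-level sign test; ties and zero denominators are ignored):

```python
import math

def same_sign(a: float, b: float) -> bool:
    # Same sign bit, including signed zeros (approximation of the CUDA bit test).
    return math.copysign(1.0, a) == math.copysign(1.0, b)

def rational_gt(n0: float, n1: float, d0: float, d1: float) -> bool:
    # True iff n0/d0 > n1/d1, evaluated without dividing.
    # Multiplying by d0*d1 flips the inequality when the denominators differ in sign.
    return (n0 * d1 > n1 * d0) == same_sign(d0, d1)

print(rational_gt(1.0, 1.0, 2.0, 3.0), 1.0 / 2.0 > 1.0 / 3.0)     # True True
print(rational_gt(1.0, 1.0, -2.0, 3.0), 1.0 / -2.0 > 1.0 / 3.0)   # False False
```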
+ +#define F32_MAX (3.402823466e+38f) +static __forceinline__ __device__ bool same_sign(float a, float b) { return (__float_as_int(a) ^ __float_as_int(b)) >= 0; } +static __forceinline__ __device__ bool rational_gt(float n0, float n1, float d0, float d1) { return (n0*d1 > n1*d0) == same_sign(d0, d1); } +static __forceinline__ __device__ int max_idx3(float n0, float n1, float n2, float d0, float d1, float d2) +{ + bool g10 = rational_gt(n1, n0, d1, d0); + bool g20 = rational_gt(n2, n0, d2, d0); + bool g21 = rational_gt(n2, n1, d2, d1); + if (g20 && g21) return 2; + if (g10) return 1; + return 0; +} + +//------------------------------------------------------------------------ +// Format of antialiasing work items stored in work buffer. Usually accessed directly as int4. + +struct AAWorkItem +{ + enum + { + EDGE_MASK = 3, // Edge index in lowest bits. + FLAG_DOWN_BIT = 2, // Down instead of right. + FLAG_TRI1_BIT = 3, // Edge is from other pixel's triangle. + }; + + int px, py; // Pixel x, y. + unsigned int pz_flags; // High 16 bits = pixel z, low 16 bits = edge index and flags. + float alpha; // Antialiasing alpha value. Zero if no AA. +}; + +//------------------------------------------------------------------------ +// Hash functions. Adapted from public-domain code at http://www.burtleburtle.net/bob/hash/doobs.html + +#define JENKINS_MAGIC (0x9e3779b9u) +static __device__ __forceinline__ void jenkins_mix(unsigned int& a, unsigned int& b, unsigned int& c) +{ + a -= b; a -= c; a ^= (c>>13); + b -= c; b -= a; b ^= (a<<8); + c -= a; c -= b; c ^= (b>>13); + a -= b; a -= c; a ^= (c>>12); + b -= c; b -= a; b ^= (a<<16); + c -= a; c -= b; c ^= (b>>5); + a -= b; a -= c; a ^= (c>>3); + b -= c; b -= a; b ^= (a<<10); + c -= a; c -= b; c ^= (b>>15); +} + +// Helper class for hash index iteration. Implements simple odd-skip linear probing with a key-dependent skip. +class HashIndex +{ +public: + __device__ __forceinline__ HashIndex(const AntialiasKernelParams& p, uint64_t key) + { + m_mask = (p.allocTriangles << AA_LOG_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles)) - 1; // This should work until triangle count exceeds 1073741824. 
+ m_idx = (uint32_t)(key & 0xffffffffu); + m_skip = (uint32_t)(key >> 32); + uint32_t dummy = JENKINS_MAGIC; + jenkins_mix(m_idx, m_skip, dummy); + m_idx &= m_mask; + m_skip &= m_mask; + m_skip |= 1; + } + __device__ __forceinline__ int get(void) const { return m_idx; } + __device__ __forceinline__ void next(void) { m_idx = (m_idx + m_skip) & m_mask; } +private: + uint32_t m_idx, m_skip, m_mask; +}; + +static __device__ __forceinline__ void hash_insert(const AntialiasKernelParams& p, uint64_t key, int v) +{ + HashIndex idx(p, key); + while(1) + { + uint64_t prev = atomicCAS((unsigned long long*)&p.evHash[idx.get()], 0, (unsigned long long)key); + if (prev == 0 || prev == key) + break; + idx.next(); + } + int* q = (int*)&p.evHash[idx.get()]; + int a = atomicCAS(q+2, 0, v); + if (a != 0 && a != v) + atomicCAS(q+3, 0, v); +} + +static __device__ __forceinline__ int2 hash_find(const AntialiasKernelParams& p, uint64_t key) +{ + HashIndex idx(p, key); + while(1) + { + uint4 entry = p.evHash[idx.get()]; + uint64_t k = ((uint64_t)entry.x) | (((uint64_t)entry.y) << 32); + if (k == key || k == 0) + return make_int2((int)entry.z, (int)entry.w); + idx.next(); + } +} + +static __device__ __forceinline__ void evhash_insert_vertex(const AntialiasKernelParams& p, int va, int vb, int vn) +{ + if (va == vb) + return; + + uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order + uint64_t v1 = (uint32_t)max(va, vb) + 1; + uint64_t vk = v0 | (v1 << 32); // hash key + hash_insert(p, vk, vn + 1); +} + +static __forceinline__ __device__ int evhash_find_vertex(const AntialiasKernelParams& p, int va, int vb, int vr) +{ + if (va == vb) + return -1; + + uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order + uint64_t v1 = (uint32_t)max(va, vb) + 1; + uint64_t vk = v0 | (v1 << 32); // hash key + int2 vn = hash_find(p, vk) - 1; + if (vn.x == vr) return vn.y; + if (vn.y == vr) return vn.x; + return -1; +} + +//------------------------------------------------------------------------ +// Mesh analysis kernel. + +__global__ void AntialiasFwdMeshKernel(const AntialiasKernelParams p) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= p.numTriangles) + return; + + int v0 = p.tri[idx * 3 + 0]; + int v1 = p.tri[idx * 3 + 1]; + int v2 = p.tri[idx * 3 + 2]; + + if (v0 < 0 || v0 >= p.numVertices || + v1 < 0 || v1 >= p.numVertices || + v2 < 0 || v2 >= p.numVertices) + return; + + if (v0 == v1 || v1 == v2 || v2 == v0) + return; + + evhash_insert_vertex(p, v1, v2, v0); + evhash_insert_vertex(p, v2, v0, v1); + evhash_insert_vertex(p, v0, v1, v2); +} + +//------------------------------------------------------------------------ +// Discontinuity finder kernel. + +__global__ void AntialiasFwdDiscontinuityKernel(const AntialiasKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH + threadIdx.x; + int py = blockIdx.y * AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.n) + return; + + // Pointer to our TriIdx and fetch. + int pidx0 = ((px + p.width * (py + p.height * pz)) << 2) + 3; + float tri0 = p.rasterOut[pidx0]; // These can stay as float, as we only compare them against each other. + + // Look right, clamp at edge. + int pidx1 = pidx0; + if (px < p.width - 1) + pidx1 += 4; + float tri1 = p.rasterOut[pidx1]; + + // Look down, clamp at edge. 
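As an aside on the edge-vertex hash machinery above: hash_insert/hash_find store, for every undirected edge, up to two candidate "opposite" vertices, and evhash_find_vertex returns the one contributed by the other triangle so the analysis kernel can detect silhouette edges. A rough Python sketch of that bookkeeping, with an ordinary dict standing in for the GPU open-addressing table (a simplification, not the CUDA data structure):

```python
# Sketch of the evhash_insert_vertex / evhash_find_vertex semantics.
ev = {}  # canonical edge key -> up to two "opposite" vertex indices

def edge_key(va: int, vb: int) -> int:
    lo, hi = min(va, vb) + 1, max(va, vb) + 1  # +1 bias so 0 can mean "empty", as in the CUDA hash
    return lo | (hi << 32)

def insert_vertex(va: int, vb: int, vn: int) -> None:
    if va == vb:
        return
    slots = ev.setdefault(edge_key(va, vb), [])
    if vn not in slots and len(slots) < 2:
        slots.append(vn)

def find_vertex(va: int, vb: int, vr: int) -> int:
    """Return the vertex facing edge (va, vb) from the other triangle than vr's, or -1."""
    slots = ev.get(edge_key(va, vb), []) if va != vb else []
    if vr not in slots:
        return -1
    others = [v for v in slots if v != vr]
    return others[0] if others else -1  # -1: no neighboring triangle -> silhouette edge

# Two triangles sharing edge (1, 2): (0, 1, 2) and (2, 1, 3).
for a, b, n in [(1, 2, 0), (2, 0, 1), (0, 1, 2), (1, 3, 2), (3, 2, 1), (2, 1, 3)]:
    insert_vertex(a, b, n)
print(find_vertex(1, 2, 0))  # 3  -> edge (1, 2) has a triangle on the other side
print(find_vertex(2, 0, 1))  # -1 -> edge (0, 2) is used by only one triangle
```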
+ int pidx2 = pidx0; + if (py < p.height - 1) + pidx2 += p.width << 2; + float tri2 = p.rasterOut[pidx2]; + + // Determine amount of work. + int count = 0; + if (tri1 != tri0) count = 1; + if (tri2 != tri0) count += 1; + if (!count) + return; // Exit warp. + + // Coalesce work counter update to once per CTA. + __shared__ int s_temp; + s_temp = 0; + __syncthreads(); + int idx = atomicAdd(&s_temp, count); + __syncthreads(); + if (idx == 0) + { + int base = atomicAdd(&p.workBuffer[0].x, s_temp); + s_temp = base + 1; // don't clobber the counters in first slot. + } + __syncthreads(); + idx += s_temp; + + // Write to memory. + if (tri1 != tri0) p.workBuffer[idx++] = make_int4(px, py, (pz << 16), 0); + if (tri2 != tri0) p.workBuffer[idx] = make_int4(px, py, (pz << 16) + (1 << AAWorkItem::FLAG_DOWN_BIT), 0); +} + +//------------------------------------------------------------------------ +// Forward analysis kernel. + +__global__ void AntialiasFwdAnalysisKernel(const AntialiasKernelParams p) +{ + __shared__ int s_base; + int workCount = p.workBuffer[0].x; + for(;;) + { + // Persistent threads work fetcher. + __syncthreads(); + if (threadIdx.x == 0) + s_base = atomicAdd(&p.workBuffer[0].y, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK); + __syncthreads(); + int thread_idx = s_base + threadIdx.x; + if (thread_idx >= workCount) + return; + + int4* pItem = p.workBuffer + thread_idx + 1; + int4 item = *pItem; + int px = item.x; + int py = item.y; + int pz = (int)(((unsigned int)item.z) >> 16); + int d = (item.z >> AAWorkItem::FLAG_DOWN_BIT) & 1; + + int pixel0 = px + p.width * (py + p.height * pz); + int pixel1 = pixel0 + (d ? p.width : 1); + float2 zt0 = ((float2*)p.rasterOut)[(pixel0 << 1) + 1]; + float2 zt1 = ((float2*)p.rasterOut)[(pixel1 << 1) + 1]; + int tri0 = float_to_triidx(zt0.y) - 1; + int tri1 = float_to_triidx(zt1.y) - 1; + + // Select triangle based on background / depth. + int tri = (tri0 >= 0) ? tri0 : tri1; + if (tri0 >= 0 && tri1 >= 0) + tri = (zt0.x < zt1.x) ? tri0 : tri1; + if (tri == tri1) + { + // Calculate with respect to neighbor pixel if chose that triangle. + px += 1 - d; + py += d; + } + + // Bail out if triangle index is corrupt. + if (tri < 0 || tri >= p.numTriangles) + continue; + + // Fetch vertex indices. + int vi0 = p.tri[tri * 3 + 0]; + int vi1 = p.tri[tri * 3 + 1]; + int vi2 = p.tri[tri * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + continue; + + // Fetch opposite vertex indices. Use vertex itself (always silhouette) if no opposite vertex exists. + int op0 = evhash_find_vertex(p, vi2, vi1, vi0); + int op1 = evhash_find_vertex(p, vi0, vi2, vi1); + int op2 = evhash_find_vertex(p, vi1, vi0, vi2); + + // Instance mode: Adjust vertex indices based on minibatch index. + if (p.instance_mode) + { + int vbase = pz * p.numVertices; + vi0 += vbase; + vi1 += vbase; + vi2 += vbase; + if (op0 >= 0) op0 += vbase; + if (op1 >= 0) op1 += vbase; + if (op2 >= 0) op2 += vbase; + } + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + float4 o0 = (op0 < 0) ? p0 : ((float4*)p.pos)[op0]; + float4 o1 = (op1 < 0) ? p1 : ((float4*)p.pos)[op1]; + float4 o2 = (op2 < 0) ? p2 : ((float4*)p.pos)[op2]; + + // Project vertices to pixel space. 
+ float w0 = 1.f / p0.w; + float w1 = 1.f / p1.w; + float w2 = 1.f / p2.w; + float ow0 = 1.f / o0.w; + float ow1 = 1.f / o1.w; + float ow2 = 1.f / o2.w; + float fx = (float)px + .5f - p.xh; + float fy = (float)py + .5f - p.yh; + float x0 = p0.x * w0 * p.xh - fx; + float y0 = p0.y * w0 * p.yh - fy; + float x1 = p1.x * w1 * p.xh - fx; + float y1 = p1.y * w1 * p.yh - fy; + float x2 = p2.x * w2 * p.xh - fx; + float y2 = p2.y * w2 * p.yh - fy; + float ox0 = o0.x * ow0 * p.xh - fx; + float oy0 = o0.y * ow0 * p.yh - fy; + float ox1 = o1.x * ow1 * p.xh - fx; + float oy1 = o1.y * ow1 * p.yh - fy; + float ox2 = o2.x * ow2 * p.xh - fx; + float oy2 = o2.y * ow2 * p.yh - fy; + + // Signs to kill non-silhouette edges. + float bb = (x1-x0)*(y2-y0) - (x2-x0)*(y1-y0); // Triangle itself. + float a0 = (x1-ox0)*(y2-oy0) - (x2-ox0)*(y1-oy0); // Wings. + float a1 = (x2-ox1)*(y0-oy1) - (x0-ox1)*(y2-oy1); + float a2 = (x0-ox2)*(y1-oy2) - (x1-ox2)*(y0-oy2); + + // If no matching signs anywhere, skip the rest. + if (same_sign(a0, bb) || same_sign(a1, bb) || same_sign(a2, bb)) + { + // XY flip for horizontal edges. + if (d) + { + swap(x0, y0); + swap(x1, y1); + swap(x2, y2); + } + + float dx0 = x2 - x1; + float dx1 = x0 - x2; + float dx2 = x1 - x0; + float dy0 = y2 - y1; + float dy1 = y0 - y2; + float dy2 = y1 - y0; + + // Check if an edge crosses between us and the neighbor pixel. + float dc = -F32_MAX; + float ds = (tri == tri0) ? 1.f : -1.f; + float d0 = ds * (x1*dy0 - y1*dx0); + float d1 = ds * (x2*dy1 - y2*dx1); + float d2 = ds * (x0*dy2 - y0*dx2); + + if (same_sign(y1, y2)) d0 = -F32_MAX, dy0 = 1.f; + if (same_sign(y2, y0)) d1 = -F32_MAX, dy1 = 1.f; + if (same_sign(y0, y1)) d2 = -F32_MAX, dy2 = 1.f; + + int di = max_idx3(d0, d1, d2, dy0, dy1, dy2); + if (di == 0 && same_sign(a0, bb) && fabsf(dy0) >= fabsf(dx0)) dc = d0 / dy0; + if (di == 1 && same_sign(a1, bb) && fabsf(dy1) >= fabsf(dx1)) dc = d1 / dy1; + if (di == 2 && same_sign(a2, bb) && fabsf(dy2) >= fabsf(dx2)) dc = d2 / dy2; + float eps = .0625f; // Expect no more than 1/16 pixel inaccuracy. + + // Adjust output image if a suitable edge was found. + if (dc > -eps && dc < 1.f + eps) + { + dc = fminf(fmaxf(dc, 0.f), 1.f); + float alpha = ds * (.5f - dc); + const float* pColor0 = p.color + pixel0 * p.channels; + const float* pColor1 = p.color + pixel1 * p.channels; + float* pOutput = p.output + (alpha > 0.f ? pixel0 : pixel1) * p.channels; + for (int i=0; i < p.channels; i++) + atomicAdd(&pOutput[i], alpha * (pColor1[i] - pColor0[i])); + + // Rewrite the work item's flags and alpha. Keep original px, py. + unsigned int flags = pz << 16; + flags |= di; + flags |= d << AAWorkItem::FLAG_DOWN_BIT; + flags |= (__float_as_uint(ds) >> 31) << AAWorkItem::FLAG_TRI1_BIT; + ((int2*)pItem)[1] = make_int2(flags, __float_as_int(alpha)); + } + } + } +} + +//------------------------------------------------------------------------ +// Gradient kernel. + +__global__ void AntialiasGradKernel(const AntialiasKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(AA_GRAD_KERNEL_THREADS_PER_BLOCK); + __shared__ int s_base; // Work counter communication across entire CTA. + + int workCount = p.workBuffer[0].x; + + for(;;) + { + // Persistent threads work fetcher. + __syncthreads(); + if (threadIdx.x == 0) + s_base = atomicAdd(&p.workBuffer[0].y, AA_GRAD_KERNEL_THREADS_PER_BLOCK); + __syncthreads(); + int thread_idx = s_base + threadIdx.x; + if (thread_idx >= workCount) + return; + + // Read work item filled out by forward kernel. 
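Before the gradient kernel unpacks it, it may help to spell out the work-item encoding that AAWorkItem documents and the forward analysis kernel writes: pixel z in the high 16 bits, the edge index in the two lowest bits, plus the down and tri1 flags. A tiny Python sketch of that packing, purely illustrative (the CUDA kernels do this inline with the same shifts):

```python
# Bit layout of the pz_flags word shared between the forward and gradient kernels.
EDGE_MASK = 3       # edge index lives in the two lowest bits
FLAG_DOWN_BIT = 2   # 1 = the neighbor below was compared, 0 = the neighbor to the right
FLAG_TRI1_BIT = 3   # 1 = the chosen silhouette edge belongs to the neighbor pixel's triangle

def pack(pz: int, edge_idx: int, down: int, tri1: int) -> int:
    return (pz << 16) | (edge_idx & EDGE_MASK) | (down << FLAG_DOWN_BIT) | (tri1 << FLAG_TRI1_BIT)

def unpack(word: int):
    return word >> 16, word & EDGE_MASK, (word >> FLAG_DOWN_BIT) & 1, (word >> FLAG_TRI1_BIT) & 1

print(unpack(pack(pz=5, edge_idx=2, down=1, tri1=0)))  # (5, 2, 1, 0)
```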
+ int4 item = p.workBuffer[thread_idx + 1]; + unsigned int amask = __ballot_sync(0xffffffffu, item.w); + if (item.w == 0) + continue; // No effect. + + // Unpack work item and replicate setup from forward analysis kernel. + int px = item.x; + int py = item.y; + int pz = (int)(((unsigned int)item.z) >> 16); + int d = (item.z >> AAWorkItem::FLAG_DOWN_BIT) & 1; + float alpha = __int_as_float(item.w); + int tri1 = (item.z >> AAWorkItem::FLAG_TRI1_BIT) & 1; + int di = item.z & AAWorkItem::EDGE_MASK; + float ds = __int_as_float(__float_as_int(1.0) | (tri1 << 31)); + int pixel0 = px + p.width * (py + p.height * pz); + int pixel1 = pixel0 + (d ? p.width : 1); + int tri = float_to_triidx(p.rasterOut[((tri1 ? pixel1 : pixel0) << 2) + 3]) - 1; + if (tri1) + { + px += 1 - d; + py += d; + } + + // Bail out if triangle index is corrupt. + bool triFail = (tri < 0 || tri >= p.numTriangles); + amask = __ballot_sync(amask, !triFail); + if (triFail) + continue; + + // Outgoing color gradients. + float* pGrad0 = p.gradColor + pixel0 * p.channels; + float* pGrad1 = p.gradColor + pixel1 * p.channels; + + // Incoming color gradients. + const float* pDy = p.dy + (alpha > 0.f ? pixel0 : pixel1) * p.channels; + + // Position gradient weight based on colors and incoming gradients. + float dd = 0.f; + const float* pColor0 = p.color + pixel0 * p.channels; + const float* pColor1 = p.color + pixel1 * p.channels; + + // Loop over channels and accumulate. + for (int i=0; i < p.channels; i++) + { + float dy = pDy[i]; + if (dy != 0.f) + { + // Update position gradient weight. + dd += dy * (pColor1[i] - pColor0[i]); + + // Update color gradients. No coalescing because all have different targets. + float v = alpha * dy; + atomicAdd(&pGrad0[i], -v); + atomicAdd(&pGrad1[i], v); + } + } + + // If position weight is zero, skip the rest. + bool noGrad = (dd == 0.f); + amask = __ballot_sync(amask, !noGrad); + if (noGrad) + continue; + + // Fetch vertex indices of the active edge and their positions. + int i1 = (di < 2) ? (di + 1) : 0; + int i2 = (i1 < 2) ? (i1 + 1) : 0; + int vi1 = p.tri[3 * tri + i1]; + int vi2 = p.tri[3 * tri + i2]; + + // Bail out if vertex indices are corrupt. + bool vtxFail = (vi1 < 0 || vi1 >= p.numVertices || vi2 < 0 || vi2 >= p.numVertices); + amask = __ballot_sync(amask, !vtxFail); + if (vtxFail) + continue; + + // Instance mode: Adjust vertex indices based on minibatch index. + if (p.instance_mode) + { + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Fetch vertex positions. + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Project vertices to pixel space. + float pxh = p.xh; + float pyh = p.yh; + float fx = (float)px + .5f - pxh; + float fy = (float)py + .5f - pyh; + + // XY flip for horizontal edges. + if (d) + { + swap(p1.x, p1.y); + swap(p2.x, p2.y); + swap(pxh, pyh); + swap(fx, fy); + } + + // Gradient calculation setup. + float w1 = 1.f / p1.w; + float w2 = 1.f / p2.w; + float x1 = p1.x * w1 * pxh - fx; + float y1 = p1.y * w1 * pyh - fy; + float x2 = p2.x * w2 * pxh - fx; + float y2 = p2.y * w2 * pyh - fy; + float dx = x2 - x1; + float dy = y2 - y1; + float db = x1*dy - y1*dx; + + // Compute inverse delta-y with epsilon. + float ep = copysignf(1e-3f, dy); // ~1/1000 pixel. + float iy = 1.f / (dy + ep); + + // Compute position gradients. 
+ float dby = db * iy; + float iw1 = -w1 * iy * dd; + float iw2 = w2 * iy * dd; + float gp1x = iw1 * pxh * y2; + float gp2x = iw2 * pxh * y1; + float gp1y = iw1 * pyh * (dby - x2); + float gp2y = iw2 * pyh * (dby - x1); + float gp1w = -(p1.x * gp1x + p1.y * gp1y) * w1; + float gp2w = -(p2.x * gp2x + p2.y * gp2y) * w2; + + // XY flip the gradients. + if (d) + { + swap(gp1x, gp1y); + swap(gp2x, gp2y); + } + + // Kill position gradients if alpha was saturated. + if (fabsf(alpha) >= 0.5f) + { + gp1x = gp1y = gp1w = 0.f; + gp2x = gp2y = gp2w = 0.f; + } + + // Initialize coalesced atomics. Match both triangle ID and edge index. + // Also note that some threads may be inactive. + CA_SET_GROUP_MASK(tri ^ (di << 30), amask); + + // Accumulate gradients. + caAtomicAdd3_xyw(p.gradPos + 4 * vi1, gp1x, gp1y, gp1w); + caAtomicAdd3_xyw(p.gradPos + 4 * vi2, gp2x, gp2y, gp2w); + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/antialias.h b/extensions/nvdiffrast/nvdiffrast/common/antialias.h new file mode 100644 index 0000000000000000000000000000000000000000..a324f2f2efc9e45ff6cb9dc125ce6a56dda47698 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/antialias.h @@ -0,0 +1,50 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "common.h" + +//------------------------------------------------------------------------ +// Constants and helpers. + +#define AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH 32 +#define AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT 8 +#define AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK 256 +#define AA_MESH_KERNEL_THREADS_PER_BLOCK 256 +#define AA_HASH_ELEMENTS_PER_TRIANGLE(alloc) ((alloc) >= (2 << 25) ? 4 : 8) // With more than 16777216 triangles (alloc >= 33554432) use smallest possible value of 4 to conserve memory, otherwise use 8 for fewer collisions. +#define AA_LOG_HASH_ELEMENTS_PER_TRIANGLE(alloc) ((alloc) >= (2 << 25) ? 2 : 3) +#define AA_GRAD_KERNEL_THREADS_PER_BLOCK 256 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct AntialiasKernelParams +{ + const float* color; // Incoming color buffer. + const float* rasterOut; // Incoming rasterizer output buffer. + const int* tri; // Incoming triangle buffer. + const float* pos; // Incoming position buffer. + float* output; // Output buffer of forward kernel. + const float* dy; // Incoming gradients. + float* gradColor; // Output buffer, color gradient. + float* gradPos; // Output buffer, position gradient. + int4* workBuffer; // Buffer for storing intermediate work items. First item reserved for counters. + uint4* evHash; // Edge-vertex hash. + int allocTriangles; // Number of triangles accommodated by evHash. Always power of two. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width; // Input width. + int height; // Input height. + int n; // Minibatch size. + int channels; // Channel count in color input. + float xh, yh; // Transfer to pixel space. + int instance_mode; // 0=normal, 1=instance mode. 
+ int tri_const; // 1 if triangle array is known to be constant. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/common.cpp b/extensions/nvdiffrast/nvdiffrast/common/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e566c035bdef66e9b75265a58fb8602b0fa530ca --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/common.cpp @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, int width, int height) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (width * height) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (width < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= width) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > height) + bh = height; + } + else if (height < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > height) + { + bh >>= 1; + if (bw < width) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +dim3 getLaunchGridSize(dim3 blockSize, int width, int height, int depth) +{ + dim3 gridSize; + gridSize.x = (width - 1) / blockSize.x + 1; + gridSize.y = (height - 1) / blockSize.y + 1; + gridSize.z = (depth - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/common.h b/extensions/nvdiffrast/nvdiffrast/common/common.h new file mode 100644 index 0000000000000000000000000000000000000000..01ecf9fc009081eaaa86c32c7959b599e360cfc7 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/common.h @@ -0,0 +1,263 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include +#include + +//------------------------------------------------------------------------ +// C++ helper function prototypes. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, int width, int height); +dim3 getLaunchGridSize(dim3 blockSize, int width, int height, int depth); + +//------------------------------------------------------------------------ +// The rest is CUDA device code specific stuff. + +#ifdef __CUDACC__ + +//------------------------------------------------------------------------ +// Helpers for CUDA vector types. 
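//------------------------------------------------------------------------
// Usage sketch, not from the upstream sources: one plausible way the
// getLaunchBlockSize()/getLaunchGridSize() helpers defined in common.cpp
// above are used to configure a per-pixel kernel launch. The kernel name
// examplePixelKernel and the 32x8 block cap are hypothetical (the cap
// mirrors the AA_DISCONTINUITY_KERNEL_BLOCK_* constants in antialias.h).

__global__ void examplePixelKernel(float* out, int width, int height)
{
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= width || py >= height)
        return;                                     // Grid is rounded up, so guard the edges.
    out[px + width * (py + height * pz)] = 0.f;     // Stand-in for real per-pixel work.
}

static void launchExamplePixelKernel(float* out, int width, int height, int depth)
{
    // Shrink the block to fit narrow or short buffers, then cover the image
    // with a grid rounded up in every dimension.
    dim3 blockSize = getLaunchBlockSize(32, 8, width, height);
    dim3 gridSize  = getLaunchGridSize(blockSize, width, height, depth);
    examplePixelKernel<<<gridSize, blockSize>>>(out, width, height);
}
//------------------------------------------------------------------------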
+ +static __device__ __forceinline__ float2& operator*= (float2& a, const float2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ float2& operator+= (float2& a, const float2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ float2& operator-= (float2& a, const float2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ float2& operator*= (float2& a, float b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ float2& operator+= (float2& a, float b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ float2& operator-= (float2& a, float b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ float2 operator* (const float2& a, const float2& b) { return make_float2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ float2 operator+ (const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ float2 operator- (const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ float2 operator* (const float2& a, float b) { return make_float2(a.x * b, a.y * b); } +static __device__ __forceinline__ float2 operator+ (const float2& a, float b) { return make_float2(a.x + b, a.y + b); } +static __device__ __forceinline__ float2 operator- (const float2& a, float b) { return make_float2(a.x - b, a.y - b); } +static __device__ __forceinline__ float2 operator* (float a, const float2& b) { return make_float2(a * b.x, a * b.y); } +static __device__ __forceinline__ float2 operator+ (float a, const float2& b) { return make_float2(a + b.x, a + b.y); } +static __device__ __forceinline__ float2 operator- (float a, const float2& b) { return make_float2(a - b.x, a - b.y); } +static __device__ __forceinline__ float2 operator- (const float2& a) { return make_float2(-a.x, -a.y); } +static __device__ __forceinline__ float3& operator*= (float3& a, const float3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ float3& operator+= (float3& a, const float3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; } +static __device__ __forceinline__ float3& operator-= (float3& a, const float3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ float3& operator*= (float3& a, float b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ float3& operator+= (float3& a, float b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ float3& operator-= (float3& a, float b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ float3 operator* (const float3& a, const float3& b) { return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ float3 operator+ (const float3& a, const float3& b) { return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ float3 operator- (const float3& a, const float3& b) { return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ float3 operator* (const float3& a, float b) { return make_float3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ float3 operator+ (const float3& a, float b) { return make_float3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ float3 operator- (const float3& a, float b) { return make_float3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ float3 operator* 
(float a, const float3& b) { return make_float3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ float3 operator+ (float a, const float3& b) { return make_float3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ float3 operator- (float a, const float3& b) { return make_float3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ float3 operator- (const float3& a) { return make_float3(-a.x, -a.y, -a.z); } +static __device__ __forceinline__ float4& operator*= (float4& a, const float4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ float4& operator+= (float4& a, const float4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ float4& operator-= (float4& a, const float4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ float4& operator*= (float4& a, float b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ float4& operator+= (float4& a, float b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; } +static __device__ __forceinline__ float4& operator-= (float4& a, float b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ float4 operator* (const float4& a, const float4& b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ float4 operator+ (const float4& a, const float4& b) { return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ float4 operator- (const float4& a, const float4& b) { return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ float4 operator* (const float4& a, float b) { return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ float4 operator+ (const float4& a, float b) { return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ float4 operator- (const float4& a, float b) { return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ float4 operator* (float a, const float4& b) { return make_float4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ float4 operator+ (float a, const float4& b) { return make_float4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ float4 operator- (float a, const float4& b) { return make_float4(a - b.x, a - b.y, a - b.z, a - b.w); } +static __device__ __forceinline__ float4 operator- (const float4& a) { return make_float4(-a.x, -a.y, -a.z, -a.w); } +static __device__ __forceinline__ int2& operator*= (int2& a, const int2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ int2& operator+= (int2& a, const int2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ int2& operator-= (int2& a, const int2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ int2& operator*= (int2& a, int b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ int2& operator+= (int2& a, int b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ int2& operator-= (int2& a, int b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ int2 operator* (const int2& a, const int2& b) { return make_int2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ int2 operator+ (const int2& a, const int2& b) { 
return make_int2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ int2 operator- (const int2& a, const int2& b) { return make_int2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ int2 operator* (const int2& a, int b) { return make_int2(a.x * b, a.y * b); } +static __device__ __forceinline__ int2 operator+ (const int2& a, int b) { return make_int2(a.x + b, a.y + b); } +static __device__ __forceinline__ int2 operator- (const int2& a, int b) { return make_int2(a.x - b, a.y - b); } +static __device__ __forceinline__ int2 operator* (int a, const int2& b) { return make_int2(a * b.x, a * b.y); } +static __device__ __forceinline__ int2 operator+ (int a, const int2& b) { return make_int2(a + b.x, a + b.y); } +static __device__ __forceinline__ int2 operator- (int a, const int2& b) { return make_int2(a - b.x, a - b.y); } +static __device__ __forceinline__ int2 operator- (const int2& a) { return make_int2(-a.x, -a.y); } +static __device__ __forceinline__ int3& operator*= (int3& a, const int3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ int3& operator+= (int3& a, const int3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; } +static __device__ __forceinline__ int3& operator-= (int3& a, const int3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ int3& operator*= (int3& a, int b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ int3& operator+= (int3& a, int b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ int3& operator-= (int3& a, int b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ int3 operator* (const int3& a, const int3& b) { return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ int3 operator+ (const int3& a, const int3& b) { return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ int3 operator- (const int3& a, const int3& b) { return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ int3 operator* (const int3& a, int b) { return make_int3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ int3 operator+ (const int3& a, int b) { return make_int3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ int3 operator- (const int3& a, int b) { return make_int3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ int3 operator* (int a, const int3& b) { return make_int3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ int3 operator+ (int a, const int3& b) { return make_int3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ int3 operator- (int a, const int3& b) { return make_int3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ int3 operator- (const int3& a) { return make_int3(-a.x, -a.y, -a.z); } +static __device__ __forceinline__ int4& operator*= (int4& a, const int4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ int4& operator+= (int4& a, const int4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ int4& operator-= (int4& a, const int4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ int4& operator*= (int4& a, int b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ int4& operator+= (int4& a, int b) { a.x += b; a.y += b; a.z += b; 
a.w += b; return a; } +static __device__ __forceinline__ int4& operator-= (int4& a, int b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ int4 operator* (const int4& a, const int4& b) { return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ int4 operator+ (const int4& a, const int4& b) { return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ int4 operator- (const int4& a, const int4& b) { return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ int4 operator* (const int4& a, int b) { return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ int4 operator+ (const int4& a, int b) { return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ int4 operator- (const int4& a, int b) { return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ int4 operator* (int a, const int4& b) { return make_int4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ int4 operator+ (int a, const int4& b) { return make_int4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ int4 operator- (int a, const int4& b) { return make_int4(a - b.x, a - b.y, a - b.z, a - b.w); } +static __device__ __forceinline__ int4 operator- (const int4& a) { return make_int4(-a.x, -a.y, -a.z, -a.w); } +static __device__ __forceinline__ uint2& operator*= (uint2& a, const uint2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ uint2& operator+= (uint2& a, const uint2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ uint2& operator-= (uint2& a, const uint2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ uint2& operator*= (uint2& a, unsigned int b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ uint2& operator+= (uint2& a, unsigned int b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ uint2& operator-= (uint2& a, unsigned int b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ uint2 operator* (const uint2& a, const uint2& b) { return make_uint2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ uint2 operator+ (const uint2& a, const uint2& b) { return make_uint2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ uint2 operator- (const uint2& a, const uint2& b) { return make_uint2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ uint2 operator* (const uint2& a, unsigned int b) { return make_uint2(a.x * b, a.y * b); } +static __device__ __forceinline__ uint2 operator+ (const uint2& a, unsigned int b) { return make_uint2(a.x + b, a.y + b); } +static __device__ __forceinline__ uint2 operator- (const uint2& a, unsigned int b) { return make_uint2(a.x - b, a.y - b); } +static __device__ __forceinline__ uint2 operator* (unsigned int a, const uint2& b) { return make_uint2(a * b.x, a * b.y); } +static __device__ __forceinline__ uint2 operator+ (unsigned int a, const uint2& b) { return make_uint2(a + b.x, a + b.y); } +static __device__ __forceinline__ uint2 operator- (unsigned int a, const uint2& b) { return make_uint2(a - b.x, a - b.y); } +static __device__ __forceinline__ uint3& operator*= (uint3& a, const uint3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ uint3& operator+= (uint3& a, const uint3& b) { a.x += b.x; a.y += b.y; 
a.z += b.z; return a; } +static __device__ __forceinline__ uint3& operator-= (uint3& a, const uint3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ uint3& operator*= (uint3& a, unsigned int b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ uint3& operator+= (uint3& a, unsigned int b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ uint3& operator-= (uint3& a, unsigned int b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ uint3 operator* (const uint3& a, const uint3& b) { return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ uint3 operator+ (const uint3& a, const uint3& b) { return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ uint3 operator- (const uint3& a, const uint3& b) { return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ uint3 operator* (const uint3& a, unsigned int b) { return make_uint3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ uint3 operator+ (const uint3& a, unsigned int b) { return make_uint3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ uint3 operator- (const uint3& a, unsigned int b) { return make_uint3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ uint3 operator* (unsigned int a, const uint3& b) { return make_uint3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ uint3 operator+ (unsigned int a, const uint3& b) { return make_uint3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ uint3 operator- (unsigned int a, const uint3& b) { return make_uint3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ uint4& operator*= (uint4& a, const uint4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ uint4& operator+= (uint4& a, const uint4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ uint4& operator-= (uint4& a, const uint4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ uint4& operator*= (uint4& a, unsigned int b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ uint4& operator+= (uint4& a, unsigned int b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; } +static __device__ __forceinline__ uint4& operator-= (uint4& a, unsigned int b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ uint4 operator* (const uint4& a, const uint4& b) { return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ uint4 operator+ (const uint4& a, const uint4& b) { return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ uint4 operator- (const uint4& a, const uint4& b) { return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ uint4 operator* (const uint4& a, unsigned int b) { return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ uint4 operator+ (const uint4& a, unsigned int b) { return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ uint4 operator- (const uint4& a, unsigned int b) { return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ uint4 operator* (unsigned int a, const uint4& b) { return 
make_uint4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ uint4 operator+ (unsigned int a, const uint4& b) { return make_uint4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ uint4 operator- (unsigned int a, const uint4& b) { return make_uint4(a - b.x, a - b.y, a - b.z, a - b.w); } + +template static __device__ __forceinline__ T zero_value(void); +template<> __device__ __forceinline__ float zero_value (void) { return 0.f; } +template<> __device__ __forceinline__ float2 zero_value(void) { return make_float2(0.f, 0.f); } +template<> __device__ __forceinline__ float4 zero_value(void) { return make_float4(0.f, 0.f, 0.f, 0.f); } +static __device__ __forceinline__ float3 make_float3(const float2& a, float b) { return make_float3(a.x, a.y, b); } +static __device__ __forceinline__ float4 make_float4(const float3& a, float b) { return make_float4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ float4 make_float4(const float2& a, const float2& b) { return make_float4(a.x, a.y, b.x, b.y); } +static __device__ __forceinline__ int3 make_int3(const int2& a, int b) { return make_int3(a.x, a.y, b); } +static __device__ __forceinline__ int4 make_int4(const int3& a, int b) { return make_int4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ int4 make_int4(const int2& a, const int2& b) { return make_int4(a.x, a.y, b.x, b.y); } +static __device__ __forceinline__ uint3 make_uint3(const uint2& a, unsigned int b) { return make_uint3(a.x, a.y, b); } +static __device__ __forceinline__ uint4 make_uint4(const uint3& a, unsigned int b) { return make_uint4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ uint4 make_uint4(const uint2& a, const uint2& b) { return make_uint4(a.x, a.y, b.x, b.y); } + +template static __device__ __forceinline__ void swap(T& a, T& b) { T temp = a; a = b; b = temp; } + +//------------------------------------------------------------------------ +// Triangle ID <-> float32 conversion functions to support very large triangle IDs. +// +// Values up to and including 16777216 (also, negative values) are converted trivially and retain +// compatibility with previous versions. Larger values are mapped to unique float32 that are not equal to +// the ID. The largest value that converts to float32 and back without generating inf or nan is 889192447. + +static __device__ __forceinline__ int float_to_triidx(float x) { if (x <= 16777216.f) return (int)x; return __float_as_int(x) - 0x4a800000; } +static __device__ __forceinline__ float triidx_to_float(int x) { if (x <= 0x01000000) return (float)x; return __int_as_float(0x4a800000 + x); } + +//------------------------------------------------------------------------ +// Coalesced atomics. These are all done via macros. 
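//------------------------------------------------------------------------
// Standalone sketch, not from the upstream sources: a host-side mirror of
// float_to_triidx()/triidx_to_float() above, using memcpy in place of the
// device-only __float_as_int()/__int_as_float() intrinsics, plus a small
// round-trip check. The host* function names are hypothetical.

#include <cassert>
#include <cstring>

static int hostFloatToTriidx(float x)
{
    if (x <= 16777216.f) return (int)x;             // Trivial range (<= 2^24).
    int bits; std::memcpy(&bits, &x, sizeof(bits));
    return bits - 0x4a800000;
}

static float hostTriidxToFloat(int x)
{
    if (x <= 0x01000000) return (float)x;           // 0x01000000 == 16777216.
    int bits = 0x4a800000 + x;
    float f; std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static void checkTriidxRoundTrip(void)
{
    const int ids[] = { 0, 1, 16777215, 16777216, 16777217, 889192447 };
    for (int id : ids)
        assert(hostFloatToTriidx(hostTriidxToFloat(id)) == id);  // IDs survive the float32 detour.
}
//------------------------------------------------------------------------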
+ +#if __CUDA_ARCH__ >= 700 // Warp match instruction __match_any_sync() is only available on compute capability 7.x and higher + +#define CA_TEMP _ca_temp +#define CA_TEMP_PARAM float* CA_TEMP +#define CA_DECLARE_TEMP(threads_per_block) \ + __shared__ float CA_TEMP[(threads_per_block)] + +#define CA_SET_GROUP_MASK(group, thread_mask) \ + bool _ca_leader; \ + float* _ca_ptr; \ + do { \ + int tidx = threadIdx.x + blockDim.x * threadIdx.y; \ + int lane = tidx & 31; \ + int warp = tidx >> 5; \ + int tmask = __match_any_sync((thread_mask), (group)); \ + int leader = __ffs(tmask) - 1; \ + _ca_leader = (leader == lane); \ + _ca_ptr = &_ca_temp[((warp << 5) + leader)]; \ + } while(0) + +#define CA_SET_GROUP(group) \ + CA_SET_GROUP_MASK((group), 0xffffffffu) + +#define caAtomicAdd(ptr, value) \ + do { \ + if (_ca_leader) \ + *_ca_ptr = 0.f; \ + atomicAdd(_ca_ptr, (value)); \ + if (_ca_leader) \ + atomicAdd((ptr), *_ca_ptr); \ + } while(0) + +#define caAtomicAdd3_xyw(ptr, x, y, w) \ + do { \ + caAtomicAdd((ptr), (x)); \ + caAtomicAdd((ptr)+1, (y)); \ + caAtomicAdd((ptr)+3, (w)); \ + } while(0) + +#define caAtomicAddTexture(ptr, level, idx, value) \ + do { \ + CA_SET_GROUP((idx) ^ ((level) << 27)); \ + caAtomicAdd((ptr)+(idx), (value)); \ + } while(0) + +//------------------------------------------------------------------------ +// Disable atomic coalescing for compute capability lower than 7.x + +#else // __CUDA_ARCH__ >= 700 +#define CA_TEMP _ca_temp +#define CA_TEMP_PARAM float CA_TEMP +#define CA_DECLARE_TEMP(threads_per_block) CA_TEMP_PARAM +#define CA_SET_GROUP_MASK(group, thread_mask) +#define CA_SET_GROUP(group) +#define caAtomicAdd(ptr, value) atomicAdd((ptr), (value)) +#define caAtomicAdd3_xyw(ptr, x, y, w) \ + do { \ + atomicAdd((ptr), (x)); \ + atomicAdd((ptr)+1, (y)); \ + atomicAdd((ptr)+3, (w)); \ + } while(0) +#define caAtomicAddTexture(ptr, level, idx, value) atomicAdd((ptr)+(idx), (value)) +#endif // __CUDA_ARCH__ >= 700 + +//------------------------------------------------------------------------ +#endif // __CUDACC__ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c1c3a7fd137618d6d20217b5ee4d9b964d3f9b8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp @@ -0,0 +1,63 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// This is a slimmed-down and modernized version of the original +// CudaRaster codebase that accompanied the HPG 2011 paper +// "High-Performance Software Rasterization on GPUs" by Laine and Karras. +// Modifications have been made to accommodate post-Volta execution model +// with warp divergence. Support for shading, blending, quad rendering, +// and supersampling have been removed as unnecessary for nvdiffrast. 
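//------------------------------------------------------------------------
// Standalone sketch, not from the upstream sources: the grouping idea
// behind the CA_SET_GROUP()/caAtomicAdd() macros defined in common.h above.
// Threads whose atomics target the same address are grouped with
// __match_any_sync() (compute capability 7.0+), combine their contributions
// in shared memory, and only the group leader issues a global atomicAdd().
// The kernel name is hypothetical and assumes a 1-D block of 256 threads.

__global__ void coalescedScatterAdd(float* out, const int* index, const float* value, int n)
{
    __shared__ float s_partial[256];                // One slot per thread; the leader's slot holds the group sum.
    int t = blockIdx.x * blockDim.x + threadIdx.x;
    bool active = (t < n);
    unsigned int activeMask = __ballot_sync(0xffffffffu, active);
    if (!active)
        return;

    int idx = index[t];                             // Scatter target; many threads may share it.
    unsigned int peers = __match_any_sync(activeMask, idx);
    int lane = threadIdx.x & 31;
    int leader = __ffs(peers) - 1;                  // Lowest lane with the same target becomes the leader.
    int slot = (threadIdx.x & ~31) + leader;        // Leader's slot within this warp's 32 entries.

    if (lane == leader)
        s_partial[slot] = 0.f;
    __syncwarp(peers);
    atomicAdd(&s_partial[slot], value[t]);          // Contention stays in shared memory.
    __syncwarp(peers);
    if (lane == leader)
        atomicAdd(&out[idx], s_partial[slot]);      // One global atomic per group instead of one per thread.
}
//------------------------------------------------------------------------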
+//------------------------------------------------------------------------ + +namespace CR +{ + +class RasterImpl; + +//------------------------------------------------------------------------ +// Interface class to isolate user from implementation details. +//------------------------------------------------------------------------ + +class CudaRaster +{ +public: + enum + { + RenderModeFlag_EnableBackfaceCulling = 1 << 0, // Enable backface culling. + RenderModeFlag_EnableDepthPeeling = 1 << 1, // Enable depth peeling. Must have a peel buffer set. + }; + +public: + CudaRaster (void); + ~CudaRaster (void); + + void setBufferSize (int width, int height, int numImages); // Width and height are internally rounded up to multiples of tile size (8x8) for buffer sizes. + void setViewport (int width, int height, int offsetX, int offsetY); // Tiled rendering viewport setup. + void setRenderModeFlags (unsigned int renderModeFlags); // Affects all subsequent calls to drawTriangles(). Defaults to zero. + void deferredClear (unsigned int clearColor); // Clears color and depth buffers during next call to drawTriangles(). + void setVertexBuffer (void* vertices, int numVertices); // GPU pointer managed by caller. Vertex positions in clip space as float4 (x, y, z, w). + void setIndexBuffer (void* indices, int numTriangles); // GPU pointer managed by caller. Triangle index+color quadruplets as uint4 (idx0, idx1, idx2, color). + bool drawTriangles (const int* ranges, bool peel, cudaStream_t stream); // Ranges (offsets and counts) as #triangles entries, not as bytes. If NULL, draw all triangles. Returns false in case of internal overflow. + void* getColorBuffer (void); // GPU pointer managed by CudaRaster. + void* getDepthBuffer (void); // GPU pointer managed by CudaRaster. + void swapDepthAndPeel (void); // Swap depth and peeling buffers. + +private: + CudaRaster (const CudaRaster&); // forbidden + CudaRaster& operator= (const CudaRaster&); // forbidden + +private: + RasterImpl* m_impl; // Opaque pointer to implementation. +}; + +//------------------------------------------------------------------------ +} // namespace CR + diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl new file mode 100644 index 0000000000000000000000000000000000000000..deae9d2c16d780f6cb223fa6a44aa8082003b5ee --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl @@ -0,0 +1,423 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
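//------------------------------------------------------------------------
// Usage sketch, not from the upstream sources, for the CudaRaster interface
// declared in CudaRaster.hpp above. The device pointers d_clipVerts (float4
// clip-space positions) and d_triIndices (uint4 index+color quadruplets) are
// assumed to be allocated and filled by the caller; all names here are
// hypothetical.

#include <cuda_runtime.h>
#include "CudaRaster.hpp"

bool rasterizeOnce(CR::CudaRaster& cr,
                   void* d_clipVerts, int numVertices,
                   void* d_triIndices, int numTriangles,
                   int width, int height, int numImages,
                   cudaStream_t stream)
{
    cr.setBufferSize(width, height, numImages);     // Rounded up internally to 8x8 tiles.
    cr.setViewport(width, height, 0, 0);            // Full-frame viewport, no tiling offset.
    cr.setRenderModeFlags(0);                       // No backface culling, no depth peeling.
    cr.deferredClear(0u);                           // Clear color/depth on the next draw.
    cr.setVertexBuffer(d_clipVerts, numVertices);
    cr.setIndexBuffer(d_triIndices, numTriangles);

    // NULL ranges => draw all triangles; false => no depth-peeling pass.
    // A false return signals internal buffer overflow.
    return cr.drawTriangles(NULL, false, stream);
}
//------------------------------------------------------------------------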
+ +//------------------------------------------------------------------------ + +__device__ __inline__ void binRasterImpl(const CRParams p) +{ + __shared__ volatile U32 s_broadcast [CR_BIN_WARPS + 16]; + __shared__ volatile S32 s_outOfs [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_outTotal [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_overIndex [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_outMask [CR_BIN_WARPS][CR_MAXBINS_SQR + 1]; // +1 to avoid bank collisions + __shared__ volatile S32 s_outCount [CR_BIN_WARPS][CR_MAXBINS_SQR + 1]; // +1 to avoid bank collisions + __shared__ volatile S32 s_triBuf [CR_BIN_WARPS*32*4]; // triangle ring buffer + __shared__ volatile U32 s_batchPos; + __shared__ volatile U32 s_bufCount; + __shared__ volatile U32 s_overTotal; + __shared__ volatile U32 s_allocBase; + + const CRImageParams& ip = getImageParams(p, blockIdx.z); + CRAtomics& atomics = p.atomics[blockIdx.z]; + const U8* triSubtris = (const U8*)p.triSubtris + p.maxSubtris * blockIdx.z; + const CRTriangleHeader* triHeader = (const CRTriangleHeader*)p.triHeader + p.maxSubtris * blockIdx.z; + + S32* binFirstSeg = (S32*)p.binFirstSeg + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + S32* binTotal = (S32*)p.binTotal + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + S32* binSegData = (S32*)p.binSegData + p.maxBinSegs * CR_BIN_SEG_SIZE * blockIdx.z; + S32* binSegNext = (S32*)p.binSegNext + p.maxBinSegs * blockIdx.z; + S32* binSegCount = (S32*)p.binSegCount + p.maxBinSegs * blockIdx.z; + + if (atomics.numSubtris > p.maxSubtris) + return; + + // per-thread state + int thrInBlock = threadIdx.x + threadIdx.y * 32; + int batchPos = 0; + + // first 16 elements of s_broadcast are always zero + if (thrInBlock < 16) + s_broadcast[thrInBlock] = 0; + + // initialize output linked lists and offsets + if (thrInBlock < p.numBins) + { + binFirstSeg[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = -1; + s_outOfs[thrInBlock] = -CR_BIN_SEG_SIZE; + s_outTotal[thrInBlock] = 0; + } + + // repeat until done + for(;;) + { + // get batch + if (thrInBlock == 0) + s_batchPos = atomicAdd(&atomics.binCounter, ip.binBatchSize); + __syncthreads(); + batchPos = s_batchPos; + + // all batches done? + if (batchPos >= ip.triCount) + break; + + // per-thread state + int bufIndex = 0; + int bufCount = 0; + int batchEnd = min(batchPos + ip.binBatchSize, ip.triCount); + + // loop over batch as long as we have triangles in it + do + { + // read more triangles + while (bufCount < CR_BIN_WARPS*32 && batchPos < batchEnd) + { + // get subtriangle count + + int triIdx = batchPos + thrInBlock; + int num = 0; + if (triIdx < batchEnd) + num = triSubtris[triIdx]; + + // cumulative sum of subtriangles within each warp + U32 myIdx = __popc(__ballot_sync(~0u, num & 1) & getLaneMaskLt()); + if (__any_sync(~0u, num > 1)) + { + myIdx += __popc(__ballot_sync(~0u, num & 2) & getLaneMaskLt()) * 2; + myIdx += __popc(__ballot_sync(~0u, num & 4) & getLaneMaskLt()) * 4; + } + if (threadIdx.x == 31) // Do not assume that last thread in warp wins the write. + s_broadcast[threadIdx.y + 16] = myIdx + num; + __syncthreads(); + + // cumulative sum of per-warp subtriangle counts + // Note: cannot have more than 32 warps or this needs to sync between each step. 
+ bool act = (thrInBlock < CR_BIN_WARPS); + U32 actMask = __ballot_sync(~0u, act); + if (threadIdx.y == 0 && act) + { + volatile U32* ptr = &s_broadcast[thrInBlock + 16]; + U32 val = *ptr; + #if (CR_BIN_WARPS > 1) + val += ptr[-1]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 2) + val += ptr[-2]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 4) + val += ptr[-4]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 8) + val += ptr[-8]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 16) + val += ptr[-16]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + + // initially assume that we consume everything + // only last active thread does the writes + if (threadIdx.x == CR_BIN_WARPS - 1) + { + s_batchPos = batchPos + CR_BIN_WARPS * 32; + s_bufCount = bufCount + val; + } + } + __syncthreads(); + + // skip if no subtriangles + if (num) + { + // calculate write position for first subtriangle + U32 pos = bufCount + myIdx + s_broadcast[threadIdx.y + 16 - 1]; + + // only write if entire triangle fits + if (pos + num <= CR_ARRAY_SIZE(s_triBuf)) + { + pos += bufIndex; // adjust for current start position + pos &= CR_ARRAY_SIZE(s_triBuf)-1; + if (num == 1) + s_triBuf[pos] = triIdx * 8 + 7; // single triangle + else + { + for (int i=0; i < num; i++) + { + s_triBuf[pos] = triIdx * 8 + i; + pos++; + pos &= CR_ARRAY_SIZE(s_triBuf)-1; + } + } + } else if (pos <= CR_ARRAY_SIZE(s_triBuf)) + { + // this triangle is the first that failed, overwrite total count and triangle count + s_batchPos = batchPos + thrInBlock; + s_bufCount = pos; + } + } + + // update triangle counts + __syncthreads(); + batchPos = s_batchPos; + bufCount = s_bufCount; + } + + // make every warp clear its output buffers + for (int i=threadIdx.x; i < p.numBins; i += 32) + s_outMask[threadIdx.y][i] = 0; + __syncwarp(); + + // choose our triangle + uint4 triData = make_uint4(0, 0, 0, 0); + if (thrInBlock < bufCount) + { + U32 triPos = bufIndex + thrInBlock; + triPos &= CR_ARRAY_SIZE(s_triBuf)-1; + + // find triangle + int triIdx = s_triBuf[triPos]; + int dataIdx = triIdx >> 3; + int subtriIdx = triIdx & 7; + if (subtriIdx != 7) + dataIdx = triHeader[dataIdx].misc + subtriIdx; + + // read triangle + + triData = *(((const uint4*)triHeader) + dataIdx); + } + + // setup bounding box and edge functions, and rasterize + S32 lox, loy, hix, hiy; + bool hasTri = (thrInBlock < bufCount); + U32 hasTriMask = __ballot_sync(~0u, hasTri); + if (hasTri) + { + S32 v0x = add_s16lo_s16lo(triData.x, p.widthPixelsVp * (CR_SUBPIXEL_SIZE >> 1)); + S32 v0y = add_s16hi_s16lo(triData.x, p.heightPixelsVp * (CR_SUBPIXEL_SIZE >> 1)); + S32 d01x = sub_s16lo_s16lo(triData.y, triData.x); + S32 d01y = sub_s16hi_s16hi(triData.y, triData.x); + S32 d02x = sub_s16lo_s16lo(triData.z, triData.x); + S32 d02y = sub_s16hi_s16hi(triData.z, triData.x); + int binLog = CR_BIN_LOG2 + CR_TILE_LOG2 + CR_SUBPIXEL_LOG2; + lox = add_clamp_0_x((v0x + min_min(d01x, 0, d02x)) >> binLog, 0, p.widthBins - 1); + loy = add_clamp_0_x((v0y + min_min(d01y, 0, d02y)) >> binLog, 0, p.heightBins - 1); + hix = add_clamp_0_x((v0x + max_max(d01x, 0, d02x)) >> binLog, 0, p.widthBins - 1); + hiy = add_clamp_0_x((v0y + max_max(d01y, 0, d02y)) >> binLog, 0, p.heightBins - 1); + + U32 bit = 1 << threadIdx.x; +#if __CUDA_ARCH__ >= 700 + bool multi = (hix != lox || hiy != loy); + if (!__any_sync(hasTriMask, multi)) + { + int binIdx = lox + 
p.widthBins * loy; + U32 mask = __match_any_sync(hasTriMask, binIdx); + s_outMask[threadIdx.y][binIdx] = mask; + __syncwarp(hasTriMask); + } else +#endif + { + bool complex = (hix > lox+1 || hiy > loy+1); + if (!__any_sync(hasTriMask, complex)) + { + int binIdx = lox + p.widthBins * loy; + atomicOr((U32*)&s_outMask[threadIdx.y][binIdx], bit); + if (hix > lox) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + 1], bit); + if (hiy > loy) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + p.widthBins], bit); + if (hix > lox && hiy > loy) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + p.widthBins + 1], bit); + } else + { + S32 d12x = d02x - d01x, d12y = d02y - d01y; + v0x -= lox << binLog, v0y -= loy << binLog; + + S32 t01 = v0x * d01y - v0y * d01x; + S32 t02 = v0y * d02x - v0x * d02y; + S32 t12 = d01x * d12y - d01y * d12x - t01 - t02; + S32 b01 = add_sub(t01 >> binLog, max(d01x, 0), min(d01y, 0)); + S32 b02 = add_sub(t02 >> binLog, max(d02y, 0), min(d02x, 0)); + S32 b12 = add_sub(t12 >> binLog, max(d12x, 0), min(d12y, 0)); + + int width = hix - lox + 1; + d01x += width * d01y; + d02x += width * d02y; + d12x += width * d12y; + + U8* currPtr = (U8*)&s_outMask[threadIdx.y][lox + loy * p.widthBins]; + U8* skipPtr = (U8*)&s_outMask[threadIdx.y][(hix + 1) + loy * p.widthBins]; + U8* endPtr = (U8*)&s_outMask[threadIdx.y][lox + (hiy + 1) * p.widthBins]; + int stride = p.widthBins * 4; + int ptrYInc = stride - width * 4; + + do + { + if (b01 >= 0 && b02 >= 0 && b12 >= 0) + atomicOr((U32*)currPtr, bit); + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + if (currPtr == skipPtr) + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x, skipPtr += stride; + } + while (currPtr != endPtr); + } + } + } + + // count per-bin contributions + if (thrInBlock == 0) + s_overTotal = 0; // overflow counter + + // ensure that out masks are done + __syncthreads(); + + int overIndex = -1; + bool act = (thrInBlock < p.numBins); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + U8* srcPtr = (U8*)&s_outMask[0][thrInBlock]; + U8* dstPtr = (U8*)&s_outCount[0][thrInBlock]; + int total = 0; + for (int i = 0; i < CR_BIN_WARPS; i++) + { + total += __popc(*(U32*)srcPtr); + *(U32*)dstPtr = total; + srcPtr += (CR_MAXBINS_SQR + 1) * 4; + dstPtr += (CR_MAXBINS_SQR + 1) * 4; + } + + // overflow => request a new segment + int ofs = s_outOfs[thrInBlock]; + bool ovr = (((ofs - 1) >> CR_BIN_SEG_LOG2) != (((ofs - 1) + total) >> CR_BIN_SEG_LOG2)); + U32 ovrMask = __ballot_sync(actMask, ovr); + if (ovr) + { + overIndex = __popc(ovrMask & getLaneMaskLt()); + if (overIndex == 0) + s_broadcast[threadIdx.y + 16] = atomicAdd((U32*)&s_overTotal, __popc(ovrMask)); + __syncwarp(ovrMask); + overIndex += s_broadcast[threadIdx.y + 16]; + s_overIndex[thrInBlock] = overIndex; + } + } + + // sync after overTotal is ready + __syncthreads(); + + // at least one segment overflowed => allocate segments + U32 overTotal = s_overTotal; + U32 allocBase = 0; + if (overTotal > 0) + { + // allocate memory + if (thrInBlock == 0) + { + U32 allocBase = atomicAdd(&atomics.numBinSegs, overTotal); + s_allocBase = (allocBase + overTotal <= p.maxBinSegs) ? allocBase : 0; + } + __syncthreads(); + allocBase = s_allocBase; + + // did my bin overflow? 
+ if (overIndex != -1) + { + // calculate new segment index + int segIdx = allocBase + overIndex; + + // add to linked list + if (s_outOfs[thrInBlock] < 0) + binFirstSeg[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = segIdx; + else + binSegNext[(s_outOfs[thrInBlock] - 1) >> CR_BIN_SEG_LOG2] = segIdx; + + // defaults + binSegNext [segIdx] = -1; + binSegCount[segIdx] = CR_BIN_SEG_SIZE; + } + } + + // concurrent emission -- each warp handles its own triangle + if (thrInBlock < bufCount) + { + int triPos = (bufIndex + thrInBlock) & (CR_ARRAY_SIZE(s_triBuf) - 1); + int currBin = lox + loy * p.widthBins; + int skipBin = (hix + 1) + loy * p.widthBins; + int endBin = lox + (hiy + 1) * p.widthBins; + int binYInc = p.widthBins - (hix - lox + 1); + + // loop over triangle's bins + do + { + U32 outMask = s_outMask[threadIdx.y][currBin]; + if (outMask & (1< 0) + idx += s_outCount[threadIdx.y-1][currBin]; + + int base = s_outOfs[currBin]; + int free = (-base) & (CR_BIN_SEG_SIZE - 1); + if (idx >= free) + idx += ((allocBase + s_overIndex[currBin]) << CR_BIN_SEG_LOG2) - free; + else + idx += base; + + binSegData[idx] = s_triBuf[triPos]; + } + + currBin++; + if (currBin == skipBin) + currBin += binYInc, skipBin += p.widthBins; + } + while (currBin != endBin); + } + + // wait all triangles to finish, then replace overflown segment offsets + __syncthreads(); + if (thrInBlock < p.numBins) + { + U32 total = s_outCount[CR_BIN_WARPS - 1][thrInBlock]; + U32 oldOfs = s_outOfs[thrInBlock]; + if (overIndex == -1) + s_outOfs[thrInBlock] = oldOfs + total; + else + { + int addr = oldOfs + total; + addr = ((addr - 1) & (CR_BIN_SEG_SIZE - 1)) + 1; + addr += (allocBase + overIndex) << CR_BIN_SEG_LOG2; + s_outOfs[thrInBlock] = addr; + } + s_outTotal[thrInBlock] += total; + } + + // these triangles are now done + int count = ::min(bufCount, CR_BIN_WARPS * 32); + bufCount -= count; + bufIndex += count; + bufIndex &= CR_ARRAY_SIZE(s_triBuf)-1; + } + while (bufCount > 0 || batchPos < batchEnd); + + // flush all bins + if (thrInBlock < p.numBins) + { + int ofs = s_outOfs[thrInBlock]; + if (ofs & (CR_BIN_SEG_SIZE-1)) + { + int seg = ofs >> CR_BIN_SEG_LOG2; + binSegCount[seg] = ofs & (CR_BIN_SEG_SIZE-1); + s_outOfs[thrInBlock] = (ofs + CR_BIN_SEG_SIZE - 1) & -CR_BIN_SEG_SIZE; + } + } + } + + // output totals + if (thrInBlock < p.numBins) + binTotal[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = s_outTotal[thrInBlock]; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b2cd7b92ba90964d4d8f66b6a3554d75b1737885 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "../../framework.h" +#include "Buffer.hpp" + +using namespace CR; + +//------------------------------------------------------------------------ +// GPU buffer. 
+//------------------------------------------------------------------------ + +Buffer::Buffer(void) +: m_gpuPtr(NULL), + m_bytes (0) +{ + // empty +} + +Buffer::~Buffer(void) +{ + if (m_gpuPtr) + cudaFree(m_gpuPtr); // Don't throw an exception. +} + +void Buffer::reset(size_t bytes) +{ + if (bytes == m_bytes) + return; + + if (m_gpuPtr) + { + NVDR_CHECK_CUDA_ERROR(cudaFree(m_gpuPtr)); + m_gpuPtr = NULL; + } + + if (bytes > 0) + NVDR_CHECK_CUDA_ERROR(cudaMalloc(&m_gpuPtr, bytes)); + + m_bytes = bytes; +} + +void Buffer::grow(size_t bytes) +{ + if (bytes > m_bytes) + reset(bytes); +} + +//------------------------------------------------------------------------ +// Host buffer with page-locked memory. +//------------------------------------------------------------------------ + +HostBuffer::HostBuffer(void) +: m_hostPtr(NULL), + m_bytes (0) +{ + // empty +} + +HostBuffer::~HostBuffer(void) +{ + if (m_hostPtr) + cudaFreeHost(m_hostPtr); // Don't throw an exception. +} + +void HostBuffer::reset(size_t bytes) +{ + if (bytes == m_bytes) + return; + + if (m_hostPtr) + { + NVDR_CHECK_CUDA_ERROR(cudaFreeHost(m_hostPtr)); + m_hostPtr = NULL; + } + + if (bytes > 0) + NVDR_CHECK_CUDA_ERROR(cudaMallocHost(&m_hostPtr, bytes)); + + m_bytes = bytes; +} + +void HostBuffer::grow(size_t bytes) +{ + if (bytes > m_bytes) + reset(bytes); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a4b38fdbedf668366c94c0263a61815e62a6a3a --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp @@ -0,0 +1,55 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
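//------------------------------------------------------------------------
// Standalone sketch, not from the upstream sources: the ballot-based
// exclusive prefix sum that binRasterImpl above uses for its per-warp
// subtriangle counts. Each thread holds a small count (0..7); summing one
// bit plane per __ballot_sync() avoids a shared-memory scan. Function and
// kernel names are hypothetical; the demo assumes a single 32-thread warp.

__device__ __forceinline__ unsigned int warpExclusiveSum3bit(unsigned int count)
{
    int lane = threadIdx.x & 31;
    unsigned int ltMask = (1u << lane) - 1u;        // Lanes below this one.
    unsigned int sum = __popc(__ballot_sync(0xffffffffu, count & 1) & ltMask);
    sum += __popc(__ballot_sync(0xffffffffu, count & 2) & ltMask) * 2;
    sum += __popc(__ballot_sync(0xffffffffu, count & 4) & ltMask) * 4;
    return sum;                                     // Sum of the counts held by lower lanes.
}

__global__ void exclusiveSumDemo(const unsigned int* counts, unsigned int* offsets)
{
    unsigned int c = counts[threadIdx.x];           // Launch with exactly 32 threads.
    offsets[threadIdx.x] = warpExclusiveSum3bit(c); // Per-lane exclusive offsets.
}
//------------------------------------------------------------------------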
+ +#pragma once +#include "Defs.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +class Buffer +{ +public: + Buffer (void); + ~Buffer (void); + + void reset (size_t bytes); + void grow (size_t bytes); + void* getPtr (size_t offset = 0) { return (void*)(((uintptr_t)m_gpuPtr) + offset); } + size_t getSize (void) const { return m_bytes; } + + void setPtr (void* ptr) { m_gpuPtr = ptr; } + +private: + void* m_gpuPtr; + size_t m_bytes; +}; + +//------------------------------------------------------------------------ + +class HostBuffer +{ +public: + HostBuffer (void); + ~HostBuffer (void); + + void reset (size_t bytes); + void grow (size_t bytes); + void* getPtr (void) { return m_hostPtr; } + size_t getSize (void) const { return m_bytes; } + + void setPtr (void* ptr) { m_hostPtr = ptr; } + +private: + void* m_hostPtr; + size_t m_bytes; +}; + +//------------------------------------------------------------------------ +} diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl new file mode 100644 index 0000000000000000000000000000000000000000..a7081c7e3dee992bbb0223e9008b17a3c69e6387 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl @@ -0,0 +1,730 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ + +__device__ __inline__ int globalTileIdx(int tileInBin, int widthTiles) +{ + int tileX = tileInBin & (CR_BIN_SIZE - 1); + int tileY = tileInBin >> CR_BIN_LOG2; + return tileX + tileY * widthTiles; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void coarseRasterImpl(const CRParams p) +{ + // Common. + + __shared__ volatile U32 s_workCounter; + __shared__ volatile U32 s_scanTemp [CR_COARSE_WARPS][48]; // 3KB + + // Input. + + __shared__ volatile U32 s_binOrder [CR_MAXBINS_SQR]; // 1KB + __shared__ volatile S32 s_binStreamCurrSeg [CR_BIN_STREAMS_SIZE]; // 0KB + __shared__ volatile S32 s_binStreamFirstTri [CR_BIN_STREAMS_SIZE]; // 0KB + __shared__ volatile S32 s_triQueue [CR_COARSE_QUEUE_SIZE]; // 4KB + __shared__ volatile S32 s_triQueueWritePos; + __shared__ volatile U32 s_binStreamSelectedOfs; + __shared__ volatile U32 s_binStreamSelectedSize; + + // Output. + + __shared__ volatile U32 s_warpEmitMask [CR_COARSE_WARPS][CR_BIN_SQR + 1]; // 16KB, +1 to avoid bank collisions + __shared__ volatile U32 s_warpEmitPrefixSum [CR_COARSE_WARPS][CR_BIN_SQR + 1]; // 16KB, +1 to avoid bank collisions + __shared__ volatile U32 s_tileEmitPrefixSum [CR_BIN_SQR + 1]; // 1KB, zero at the beginning + __shared__ volatile U32 s_tileAllocPrefixSum[CR_BIN_SQR + 1]; // 1KB, zero at the beginning + __shared__ volatile S32 s_tileStreamCurrOfs [CR_BIN_SQR]; // 1KB + __shared__ volatile U32 s_firstAllocSeg; + __shared__ volatile U32 s_firstActiveIdx; + + // Pointers and constants. 
+ + CRAtomics& atomics = p.atomics[blockIdx.z]; + const CRTriangleHeader* triHeader = (const CRTriangleHeader*)p.triHeader + p.maxSubtris * blockIdx.z; + const S32* binFirstSeg = (const S32*)p.binFirstSeg + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + const S32* binTotal = (const S32*)p.binTotal + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + const S32* binSegData = (const S32*)p.binSegData + p.maxBinSegs * CR_BIN_SEG_SIZE * blockIdx.z; + const S32* binSegNext = (const S32*)p.binSegNext + p.maxBinSegs * blockIdx.z; + const S32* binSegCount = (const S32*)p.binSegCount + p.maxBinSegs * blockIdx.z; + S32* activeTiles = (S32*)p.activeTiles + CR_MAXTILES_SQR * blockIdx.z; + S32* tileFirstSeg = (S32*)p.tileFirstSeg + CR_MAXTILES_SQR * blockIdx.z; + S32* tileSegData = (S32*)p.tileSegData + p.maxTileSegs * CR_TILE_SEG_SIZE * blockIdx.z; + S32* tileSegNext = (S32*)p.tileSegNext + p.maxTileSegs * blockIdx.z; + S32* tileSegCount = (S32*)p.tileSegCount + p.maxTileSegs * blockIdx.z; + + int tileLog = CR_TILE_LOG2 + CR_SUBPIXEL_LOG2; + int thrInBlock = threadIdx.x + threadIdx.y * 32; + int emitShift = CR_BIN_LOG2 * 2 + 5; // We scan ((numEmits << emitShift) | numAllocs) over tiles. + + if (atomics.numSubtris > p.maxSubtris || atomics.numBinSegs > p.maxBinSegs) + return; + + // Initialize sharedmem arrays. + + if (thrInBlock == 0) + { + s_tileEmitPrefixSum[0] = 0; + s_tileAllocPrefixSum[0] = 0; + } + s_scanTemp[threadIdx.y][threadIdx.x] = 0; + + // Sort bins in descending order of triangle count. + + for (int binIdx = thrInBlock; binIdx < p.numBins; binIdx += CR_COARSE_WARPS * 32) + { + int count = 0; + for (int i = 0; i < CR_BIN_STREAMS_SIZE; i++) + count += binTotal[(binIdx << CR_BIN_STREAMS_LOG2) + i]; + s_binOrder[binIdx] = (~count << (CR_MAXBINS_LOG2 * 2)) | binIdx; + } + + __syncthreads(); + sortShared(s_binOrder, p.numBins); + + // Process each bin by one block. + + for (;;) + { + // Pick a bin for the block. + + if (thrInBlock == 0) + s_workCounter = atomicAdd(&atomics.coarseCounter, 1); + __syncthreads(); + + int workCounter = s_workCounter; + if (workCounter >= p.numBins) + break; + + U32 binOrder = s_binOrder[workCounter]; + bool binEmpty = ((~binOrder >> (CR_MAXBINS_LOG2 * 2)) == 0); + if (binEmpty && !p.deferredClear) + break; + + int binIdx = binOrder & (CR_MAXBINS_SQR - 1); + + // Initialize input/output streams. + + int triQueueWritePos = 0; + int triQueueReadPos = 0; + + if (thrInBlock < CR_BIN_STREAMS_SIZE) + { + int segIdx = binFirstSeg[(binIdx << CR_BIN_STREAMS_LOG2) + thrInBlock]; + s_binStreamCurrSeg[thrInBlock] = segIdx; + s_binStreamFirstTri[thrInBlock] = (segIdx == -1) ? ~0u : binSegData[segIdx << CR_BIN_SEG_LOG2]; + } + + for (int tileInBin = CR_COARSE_WARPS * 32 - 1 - thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + s_tileStreamCurrOfs[tileInBin] = -CR_TILE_SEG_SIZE; + + // Initialize per-bin state. + + int binY = idiv_fast(binIdx, p.widthBins); + int binX = binIdx - binY * p.widthBins; + int originX = (binX << (CR_BIN_LOG2 + tileLog)) - (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + int originY = (binY << (CR_BIN_LOG2 + tileLog)) - (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + int maxTileXInBin = ::min(p.widthTiles - (binX << CR_BIN_LOG2), CR_BIN_SIZE) - 1; + int maxTileYInBin = ::min(p.heightTiles - (binY << CR_BIN_LOG2), CR_BIN_SIZE) - 1; + int binTileIdx = (binX + binY * p.widthTiles) << CR_BIN_LOG2; + + // Entire block: Merge input streams and process triangles. 
+ + if (!binEmpty) + do + { + //------------------------------------------------------------------------ + // Merge. + //------------------------------------------------------------------------ + + // Entire block: Not enough triangles => merge and queue segments. + // NOTE: The bin exit criterion assumes that we queue more triangles than we actually need. + + while (triQueueWritePos - triQueueReadPos <= CR_COARSE_WARPS * 32) + { + // First warp: Choose the segment with the lowest initial triangle index. + + bool hasStream = (thrInBlock < CR_BIN_STREAMS_SIZE); + U32 hasStreamMask = __ballot_sync(~0u, hasStream); + if (hasStream) + { + // Find the stream with the lowest triangle index. + + U32 firstTri = s_binStreamFirstTri[thrInBlock]; + U32 t = firstTri; + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + + #if (CR_BIN_STREAMS_SIZE > 1) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-1]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 2) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-2]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 4) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-4]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 8) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-8]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 16) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-16]); __syncwarp(hasStreamMask); + #endif + v[0] = t; __syncwarp(hasStreamMask); + + // Consume and broadcast. + + bool first = (s_scanTemp[0][CR_BIN_STREAMS_SIZE - 1 + 16] == firstTri); + U32 firstMask = __ballot_sync(hasStreamMask, first); + if (first && (firstMask >> threadIdx.x) == 1u) + { + int segIdx = s_binStreamCurrSeg[thrInBlock]; + s_binStreamSelectedOfs = segIdx << CR_BIN_SEG_LOG2; + if (segIdx != -1) + { + int segSize = binSegCount[segIdx]; + int segNext = binSegNext[segIdx]; + s_binStreamSelectedSize = segSize; + s_triQueueWritePos = triQueueWritePos + segSize; + s_binStreamCurrSeg[thrInBlock] = segNext; + s_binStreamFirstTri[thrInBlock] = (segNext == -1) ? ~0u : binSegData[segNext << CR_BIN_SEG_LOG2]; + } + } + } + + // No more segments => break. + + __syncthreads(); + triQueueWritePos = s_triQueueWritePos; + int segOfs = s_binStreamSelectedOfs; + if (segOfs < 0) + break; + + int segSize = s_binStreamSelectedSize; + __syncthreads(); + + // Fetch triangles into the queue. + + for (int idxInSeg = CR_COARSE_WARPS * 32 - 1 - thrInBlock; idxInSeg < segSize; idxInSeg += CR_COARSE_WARPS * 32) + { + S32 triIdx = binSegData[segOfs + idxInSeg]; + s_triQueue[(triQueueWritePos - segSize + idxInSeg) & (CR_COARSE_QUEUE_SIZE - 1)] = triIdx; + } + } + + // All threads: Clear emit masks. + + for (int maskIdx = thrInBlock; maskIdx < CR_COARSE_WARPS * CR_BIN_SQR; maskIdx += CR_COARSE_WARPS * 32) + s_warpEmitMask[maskIdx >> (CR_BIN_LOG2 * 2)][maskIdx & (CR_BIN_SQR - 1)] = 0; + + __syncthreads(); + + //------------------------------------------------------------------------ + // Raster. + //------------------------------------------------------------------------ + + // Triangle per thread: Read from the queue. 
+ + int triIdx = -1; + if (triQueueReadPos + thrInBlock < triQueueWritePos) + triIdx = s_triQueue[(triQueueReadPos + thrInBlock) & (CR_COARSE_QUEUE_SIZE - 1)]; + + uint4 triData = make_uint4(0, 0, 0, 0); + if (triIdx != -1) + { + int dataIdx = triIdx >> 3; + int subtriIdx = triIdx & 7; + if (subtriIdx != 7) + dataIdx = triHeader[dataIdx].misc + subtriIdx; + triData = *((uint4*)triHeader + dataIdx); + } + + // 32 triangles per warp: Record emits (= tile intersections). + + if (__any_sync(~0u, triIdx != -1)) + { + S32 v0x = sub_s16lo_s16lo(triData.x, originX); + S32 v0y = sub_s16hi_s16lo(triData.x, originY); + S32 d01x = sub_s16lo_s16lo(triData.y, triData.x); + S32 d01y = sub_s16hi_s16hi(triData.y, triData.x); + S32 d02x = sub_s16lo_s16lo(triData.z, triData.x); + S32 d02y = sub_s16hi_s16hi(triData.z, triData.x); + + // Compute tile-based AABB. + + int lox = add_clamp_0_x((v0x + min_min(d01x, 0, d02x)) >> tileLog, 0, maxTileXInBin); + int loy = add_clamp_0_x((v0y + min_min(d01y, 0, d02y)) >> tileLog, 0, maxTileYInBin); + int hix = add_clamp_0_x((v0x + max_max(d01x, 0, d02x)) >> tileLog, 0, maxTileXInBin); + int hiy = add_clamp_0_x((v0y + max_max(d01y, 0, d02y)) >> tileLog, 0, maxTileYInBin); + int sizex = add_sub(hix, 1, lox); + int sizey = add_sub(hiy, 1, loy); + int area = sizex * sizey; + + // Miscellaneous init. + + U8* currPtr = (U8*)&s_warpEmitMask[threadIdx.y][lox + (loy << CR_BIN_LOG2)]; + int ptrYInc = CR_BIN_SIZE * 4 - (sizex << 2); + U32 maskBit = 1 << threadIdx.x; + + // Case A: All AABBs are small => record the full AABB using atomics. + + if (__all_sync(~0u, sizex <= 2 && sizey <= 2)) + { + if (triIdx != -1) + { + atomicOr((U32*)currPtr, maskBit); + if (sizex == 2) atomicOr((U32*)(currPtr + 4), maskBit); + if (sizey == 2) atomicOr((U32*)(currPtr + CR_BIN_SIZE * 4), maskBit); + if (sizex == 2 && sizey == 2) atomicOr((U32*)(currPtr + 4 + CR_BIN_SIZE * 4), maskBit); + } + } + else + { + // Compute warp-AABB (scan-32). + + U32 aabbMask = add_sub(2 << hix, 0x20000 << hiy, 1 << lox) - (0x10000 << loy); + if (triIdx == -1) + aabbMask = 0; + + volatile U32* v = &s_scanTemp[threadIdx.y][threadIdx.x + 16]; + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-1]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-2]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-4]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-8]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-16]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask = s_scanTemp[threadIdx.y][47]; + + U32 maskX = aabbMask & 0xFFFF; + U32 maskY = aabbMask >> 16; + int wlox = findLeadingOne(maskX ^ (maskX - 1)); + int wloy = findLeadingOne(maskY ^ (maskY - 1)); + int whix = findLeadingOne(maskX); + int whiy = findLeadingOne(maskY); + int warea = (add_sub(whix, 1, wlox)) * (add_sub(whiy, 1, wloy)); + + // Initialize edge functions. + + S32 d12x = d02x - d01x; + S32 d12y = d02y - d01y; + v0x -= lox << tileLog; + v0y -= loy << tileLog; + + S32 t01 = v0x * d01y - v0y * d01x; + S32 t02 = v0y * d02x - v0x * d02y; + S32 t12 = d01x * d12y - d01y * d12x - t01 - t02; + S32 b01 = add_sub(t01 >> tileLog, ::max(d01x, 0), ::min(d01y, 0)); + S32 b02 = add_sub(t02 >> tileLog, ::max(d02y, 0), ::min(d02x, 0)); + S32 b12 = add_sub(t12 >> tileLog, ::max(d12x, 0), ::min(d12y, 0)); + + d01x += sizex * d01y; + d02x += sizex * d02y; + d12x += sizex * d12y; + + // Case B: Warp-AABB is not much larger than largest AABB => Check tiles in warp-AABB, record using ballots. 
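+                    // 'area' is each lane's own AABB in tiles and 'warea' the warp-wide union;
+                    // when the union is at most twice some lane's AABB, sweeping the union once
+                    // lets each tile's 32-bit emit mask be built with one ballot and a single
+                    // store, instead of one atomicOr per lane as in Case C below.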
+ if (__any_sync(~0u, warea * 4 <= area * 8)) + { + // Not sure if this is any faster than Case C after all the post-Volta ballot mask tracking. + bool act = (triIdx != -1); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + for (int y = wloy; y <= whiy; y++) + { + bool yIn = (y >= loy && y <= hiy); + U32 yMask = __ballot_sync(actMask, yIn); + if (yIn) + { + for (int x = wlox; x <= whix; x++) + { + bool xyIn = (x >= lox && x <= hix); + U32 xyMask = __ballot_sync(yMask, xyIn); + if (xyIn) + { + U32 res = __ballot_sync(xyMask, b01 >= 0 && b02 >= 0 && b12 >= 0); + if (threadIdx.x == 31 - __clz(xyMask)) + *(U32*)currPtr = res; + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + } + } + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x; + } + } + } + } + + // Case C: General case => Check tiles in AABB, record using atomics. + + else + { + if (triIdx != -1) + { + U8* skipPtr = currPtr + (sizex << 2); + U8* endPtr = currPtr + (sizey << (CR_BIN_LOG2 + 2)); + do + { + if (b01 >= 0 && b02 >= 0 && b12 >= 0) + atomicOr((U32*)currPtr, maskBit); + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + if (currPtr == skipPtr) + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x, skipPtr += CR_BIN_SIZE * 4; + } + while (currPtr != endPtr); + } + } + } + } + + __syncthreads(); + + //------------------------------------------------------------------------ + // Count. + //------------------------------------------------------------------------ + + // Tile per thread: Initialize prefix sums. + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + // Compute prefix sum of emits over warps. + + U8* srcPtr = (U8*)&s_warpEmitMask[0][tileInBin]; + U8* dstPtr = (U8*)&s_warpEmitPrefixSum[0][tileInBin]; + int tileEmits = 0; + for (int i = 0; i < CR_COARSE_WARPS; i++) + { + tileEmits += __popc(*(U32*)srcPtr); + *(U32*)dstPtr = tileEmits; + srcPtr += (CR_BIN_SQR + 1) * 4; + dstPtr += (CR_BIN_SQR + 1) * 4; + } + + // Determine the number of segments to allocate. + + int spaceLeft = -s_tileStreamCurrOfs[tileInBin] & (CR_TILE_SEG_SIZE - 1); + int tileAllocs = (tileEmits - spaceLeft + CR_TILE_SEG_SIZE - 1) >> CR_TILE_SEG_LOG2; + volatile U32* v = &s_tileEmitPrefixSum[tileInBin + 1]; + + // All counters within the warp are small => compute prefix sum using ballot. + + if (!__any_sync(actMask, tileEmits >= 2)) + { + U32 m = getLaneMaskLe(); + *v = (__popc(__ballot_sync(actMask, tileEmits & 1) & m) << emitShift) | __popc(__ballot_sync(actMask, tileAllocs & 1) & m); + } + + // Otherwise => scan-32 within the warp. + + else + { + U32 sum = (tileEmits << emitShift) | tileAllocs; + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 1) sum += v[-1]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 2) sum += v[-2]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 4) sum += v[-4]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 8) sum += v[-8]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 16) sum += v[-16]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); + } + } + } + + // First warp: Scan-8. 
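+            // The first CR_BIN_SQR / 32 threads combine the last entry of each 32-tile
+            // group so that the per-tile prefix sums computed above can be promoted to
+            // bin-wide sums in the finalize loop below.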
+ + __syncthreads(); + + bool scan8 = (thrInBlock < CR_BIN_SQR / 32); + U32 scan8Mask = __ballot_sync(~0u, scan8); + if (scan8) + { + int sum = s_tileEmitPrefixSum[(thrInBlock << 5) + 32]; + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + v[0] = sum; __syncwarp(scan8Mask); + #if (CR_BIN_SQR > 1 * 32) + sum += v[-1]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 2 * 32) + sum += v[-2]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 4 * 32) + sum += v[-4]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + } + + __syncthreads(); + + // Tile per thread: Finalize prefix sums. + // Single thread: Allocate segments. + + for (int tileInBin = thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + { + int sum = s_tileEmitPrefixSum[tileInBin + 1] + s_scanTemp[0][(tileInBin >> 5) + 15]; + int numEmits = sum >> emitShift; + int numAllocs = sum & ((1 << emitShift) - 1); + s_tileEmitPrefixSum[tileInBin + 1] = numEmits; + s_tileAllocPrefixSum[tileInBin + 1] = numAllocs; + + if (tileInBin == CR_BIN_SQR - 1 && numAllocs != 0) + { + int t = atomicAdd(&atomics.numTileSegs, numAllocs); + s_firstAllocSeg = (t + numAllocs <= p.maxTileSegs) ? t : 0; + } + } + + __syncthreads(); + int firstAllocSeg = s_firstAllocSeg; + int totalEmits = s_tileEmitPrefixSum[CR_BIN_SQR]; + int totalAllocs = s_tileAllocPrefixSum[CR_BIN_SQR]; + + //------------------------------------------------------------------------ + // Emit. + //------------------------------------------------------------------------ + + // Emit per thread: Write triangle index to globalmem. + + for (int emitInBin = thrInBlock; emitInBin < totalEmits; emitInBin += CR_COARSE_WARPS * 32) + { + // Find tile in bin. + + U8* tileBase = (U8*)&s_tileEmitPrefixSum[0]; + U8* tilePtr = tileBase; + U8* ptr; + + #if (CR_BIN_SQR > 128) + ptr = tilePtr + 0x80 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 64) + ptr = tilePtr + 0x40 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 32) + ptr = tilePtr + 0x20 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 16) + ptr = tilePtr + 0x10 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 8) + ptr = tilePtr + 0x08 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 4) + ptr = tilePtr + 0x04 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 2) + ptr = tilePtr + 0x02 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 1) + ptr = tilePtr + 0x01 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + + int tileInBin = (tilePtr - tileBase) >> 2; + int emitInTile = emitInBin - *(U32*)tilePtr; + + // Find warp in tile. 
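+            // Same branchless binary search as above, now over this tile's per-warp emit
+            // prefix sums, followed by popcounts over the winning warp's 32-bit emit mask
+            // to identify the exact lane that produced this emit.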
+ + int warpStep = (CR_BIN_SQR + 1) * 4; + U8* warpBase = (U8*)&s_warpEmitPrefixSum[0][tileInBin] - warpStep; + U8* warpPtr = warpBase; + + #if (CR_COARSE_WARPS > 8) + ptr = warpPtr + 0x08 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 4) + ptr = warpPtr + 0x04 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 2) + ptr = warpPtr + 0x02 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 1) + ptr = warpPtr + 0x01 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + + int warpInTile = (warpPtr - warpBase) >> (CR_BIN_LOG2 * 2 + 2); + U32 emitMask = *(U32*)(warpPtr + warpStep + ((U8*)s_warpEmitMask - (U8*)s_warpEmitPrefixSum)); + int emitInWarp = emitInTile - *(U32*)(warpPtr + warpStep) + __popc(emitMask); + + // Find thread in warp. + + int threadInWarp = 0; + int pop = __popc(emitMask & 0xFFFF); + bool pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x10; + if (pred) threadInWarp += 0x10; + + pop = __popc(emitMask & 0xFF); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x08; + if (pred) threadInWarp += 0x08; + + pop = __popc(emitMask & 0xF); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x04; + if (pred) threadInWarp += 0x04; + + pop = __popc(emitMask & 0x3); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x02; + if (pred) threadInWarp += 0x02; + + if (emitInWarp >= (emitMask & 1)) + threadInWarp++; + + // Figure out where to write. + + int currOfs = s_tileStreamCurrOfs[tileInBin]; + int spaceLeft = -currOfs & (CR_TILE_SEG_SIZE - 1); + int outOfs = emitInTile; + + if (outOfs < spaceLeft) + outOfs += currOfs; + else + { + int allocLo = firstAllocSeg + s_tileAllocPrefixSum[tileInBin]; + outOfs += (allocLo << CR_TILE_SEG_LOG2) - spaceLeft; + } + + // Write. + + int queueIdx = warpInTile * 32 + threadInWarp; + int triIdx = s_triQueue[(triQueueReadPos + queueIdx) & (CR_COARSE_QUEUE_SIZE - 1)]; + + tileSegData[outOfs] = triIdx; + } + + //------------------------------------------------------------------------ + // Patch. + //------------------------------------------------------------------------ + + // Allocated segment per thread: Initialize next-pointer and count. + + for (int i = CR_COARSE_WARPS * 32 - 1 - thrInBlock; i < totalAllocs; i += CR_COARSE_WARPS * 32) + { + int segIdx = firstAllocSeg + i; + tileSegNext[segIdx] = segIdx + 1; + tileSegCount[segIdx] = CR_TILE_SEG_SIZE; + } + + // Tile per thread: Fix previous segment's next-pointer and update s_tileStreamCurrOfs. + + __syncthreads(); + for (int tileInBin = CR_COARSE_WARPS * 32 - 1 - thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + { + int oldOfs = s_tileStreamCurrOfs[tileInBin]; + int newOfs = oldOfs + s_warpEmitPrefixSum[CR_COARSE_WARPS - 1][tileInBin]; + int allocLo = s_tileAllocPrefixSum[tileInBin]; + int allocHi = s_tileAllocPrefixSum[tileInBin + 1]; + + if (allocLo != allocHi) + { + S32* nextPtr = &tileSegNext[(oldOfs - 1) >> CR_TILE_SEG_LOG2]; + if (oldOfs < 0) + nextPtr = &tileFirstSeg[binTileIdx + globalTileIdx(tileInBin, p.widthTiles)]; + *nextPtr = firstAllocSeg + allocLo; + + newOfs--; + newOfs &= CR_TILE_SEG_SIZE - 1; + newOfs |= (firstAllocSeg + allocHi - 1) << CR_TILE_SEG_LOG2; + newOfs++; + } + s_tileStreamCurrOfs[tileInBin] = newOfs; + } + + // Advance queue read pointer. + // Queue became empty => bin done. 
+ + triQueueReadPos += CR_COARSE_WARPS * 32; + } + while (triQueueReadPos < triQueueWritePos); + + // Tile per thread: Fix next-pointer and count of the last segment. + // 32 tiles per warp: Count active tiles. + + __syncthreads(); + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + int tileX = tileInBin & (CR_BIN_SIZE - 1); + int tileY = tileInBin >> CR_BIN_LOG2; + bool force = (p.deferredClear & tileX <= maxTileXInBin & tileY <= maxTileYInBin); + + int ofs = s_tileStreamCurrOfs[tileInBin]; + int segIdx = (ofs - 1) >> CR_TILE_SEG_LOG2; + int segCount = ofs & (CR_TILE_SEG_SIZE - 1); + + if (ofs >= 0) + tileSegNext[segIdx] = -1; + else if (force) + { + s_tileStreamCurrOfs[tileInBin] = 0; + tileFirstSeg[binTileIdx + tileX + tileY * p.widthTiles] = -1; + } + + if (segCount != 0) + tileSegCount[segIdx] = segCount; + + U32 res = __ballot_sync(actMask, ofs >= 0 | force); + if (threadIdx.x == 0) + s_scanTemp[0][(tileInBin >> 5) + 16] = __popc(res); + } + } + + // First warp: Scan-8. + // One thread: Allocate space for active tiles. + + __syncthreads(); + + bool scan8 = (thrInBlock < CR_BIN_SQR / 32); + U32 scan8Mask = __ballot_sync(~0u, scan8); + if (scan8) + { + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + U32 sum = v[0]; + #if (CR_BIN_SQR > 1 * 32) + sum += v[-1]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 2 * 32) + sum += v[-2]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 4 * 32) + sum += v[-4]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + + if (thrInBlock == CR_BIN_SQR / 32 - 1) + s_firstActiveIdx = atomicAdd(&atomics.numActiveTiles, sum); + } + + // Tile per thread: Output active tiles. + + __syncthreads(); + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR) && (s_tileStreamCurrOfs[tileInBin] >= 0); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + int activeIdx = s_firstActiveIdx; + activeIdx += s_scanTemp[0][(tileInBin >> 5) + 15]; + activeIdx += __popc(actMask & getLaneMaskLt()); + activeTiles[activeIdx] = binTileIdx + globalTileIdx(tileInBin, p.widthTiles); + } + } + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp new file mode 100644 index 0000000000000000000000000000000000000000..916315cdec21948632ce8b3b383ee654225aad9c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp @@ -0,0 +1,73 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ + +#define CR_MAXVIEWPORT_LOG2 11 // ViewportSize / PixelSize. 
+#define CR_SUBPIXEL_LOG2 4 // PixelSize / SubpixelSize. + +#define CR_MAXBINS_LOG2 4 // ViewportSize / BinSize. +#define CR_BIN_LOG2 4 // BinSize / TileSize. +#define CR_TILE_LOG2 3 // TileSize / PixelSize. + +#define CR_COVER8X8_LUT_SIZE 768 // 64-bit entries. +#define CR_FLIPBIT_FLIP_Y 2 +#define CR_FLIPBIT_FLIP_X 3 +#define CR_FLIPBIT_SWAP_XY 4 +#define CR_FLIPBIT_COMPL 5 + +#define CR_BIN_STREAMS_LOG2 4 +#define CR_BIN_SEG_LOG2 9 // 32-bit entries. +#define CR_TILE_SEG_LOG2 5 // 32-bit entries. + +#define CR_MAXSUBTRIS_LOG2 24 // Triangle structs. Dictated by CoarseRaster. +#define CR_COARSE_QUEUE_LOG2 10 // Triangles. + +#define CR_SETUP_WARPS 2 +#define CR_SETUP_OPT_BLOCKS 8 +#define CR_BIN_WARPS 16 +#define CR_COARSE_WARPS 16 // Must be a power of two. +#define CR_FINE_MAX_WARPS 20 + +#define CR_EMBED_IMAGE_PARAMS 32 // Number of per-image parameter structs embedded in kernel launch parameter block. + +//------------------------------------------------------------------------ + +#define CR_MAXVIEWPORT_SIZE (1 << CR_MAXVIEWPORT_LOG2) +#define CR_SUBPIXEL_SIZE (1 << CR_SUBPIXEL_LOG2) +#define CR_SUBPIXEL_SQR (1 << (CR_SUBPIXEL_LOG2 * 2)) + +#define CR_MAXBINS_SIZE (1 << CR_MAXBINS_LOG2) +#define CR_MAXBINS_SQR (1 << (CR_MAXBINS_LOG2 * 2)) +#define CR_BIN_SIZE (1 << CR_BIN_LOG2) +#define CR_BIN_SQR (1 << (CR_BIN_LOG2 * 2)) + +#define CR_MAXTILES_LOG2 (CR_MAXBINS_LOG2 + CR_BIN_LOG2) +#define CR_MAXTILES_SIZE (1 << CR_MAXTILES_LOG2) +#define CR_MAXTILES_SQR (1 << (CR_MAXTILES_LOG2 * 2)) +#define CR_TILE_SIZE (1 << CR_TILE_LOG2) +#define CR_TILE_SQR (1 << (CR_TILE_LOG2 * 2)) + +#define CR_BIN_STREAMS_SIZE (1 << CR_BIN_STREAMS_LOG2) +#define CR_BIN_SEG_SIZE (1 << CR_BIN_SEG_LOG2) +#define CR_TILE_SEG_SIZE (1 << CR_TILE_SEG_LOG2) + +#define CR_MAXSUBTRIS_SIZE (1 << CR_MAXSUBTRIS_LOG2) +#define CR_COARSE_QUEUE_SIZE (1 << CR_COARSE_QUEUE_LOG2) + +//------------------------------------------------------------------------ +// When evaluating interpolated Z pixel centers, we may introduce an error +// of (+-CR_LERP_ERROR) ULPs. + +#define CR_LERP_ERROR(SAMPLES_LOG2) (2200u << (SAMPLES_LOG2)) +#define CR_DEPTH_MIN CR_LERP_ERROR(3) +#define CR_DEPTH_MAX (CR_U32_MAX - CR_LERP_ERROR(3)) + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db8bf31434bf2ac1ba420e9aa0fc3a14c05f5c73 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "Defs.hpp" +#include "../CudaRaster.hpp" +#include "RasterImpl.hpp" + +using namespace CR; + +//------------------------------------------------------------------------ +// Stub interface implementation. 
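+//
+// Illustrative call sequence (a sketch only, not part of the original file; the
+// identifiers d_vertices, d_indices, width, height, numVertices, numTriangles and
+// stream are placeholders):
+//
+//     CR::CudaRaster cr;
+//     cr.setBufferSize(width, height, /*numImages*/ 1);
+//     cr.setViewport(width, height, 0, 0);
+//     cr.deferredClear(0x00000000u);
+//     cr.setVertexBuffer(d_vertices, numVertices);    // GPU pointer, float4 per vertex.
+//     cr.setIndexBuffer(d_indices, numTriangles);     // GPU pointer, int3 per triangle.
+//     bool ok = cr.drawTriangles(nullptr, /*peel*/ false, stream);
+//     void* d_color = cr.getColorBuffer();            // U32 per pixel.
+//     void* d_depth = cr.getDepthBuffer();            // U32 per pixel.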
+//------------------------------------------------------------------------ + +CudaRaster::CudaRaster() +{ + m_impl = new RasterImpl(); +} + +CudaRaster::~CudaRaster() +{ + delete m_impl; +} + +void CudaRaster::setBufferSize(int width, int height, int numImages) +{ + m_impl->setBufferSize(Vec3i(width, height, numImages)); +} + +void CudaRaster::setViewport(int width, int height, int offsetX, int offsetY) +{ + m_impl->setViewport(Vec2i(width, height), Vec2i(offsetX, offsetY)); +} + +void CudaRaster::setRenderModeFlags(U32 flags) +{ + m_impl->setRenderModeFlags(flags); +} + +void CudaRaster::deferredClear(U32 clearColor) +{ + m_impl->deferredClear(clearColor); +} + +void CudaRaster::setVertexBuffer(void* vertices, int numVertices) +{ + m_impl->setVertexBuffer(vertices, numVertices); +} + +void CudaRaster::setIndexBuffer(void* indices, int numTriangles) +{ + m_impl->setIndexBuffer(indices, numTriangles); +} + +bool CudaRaster::drawTriangles(const int* ranges, bool peel, cudaStream_t stream) +{ + return m_impl->drawTriangles((const Vec2i*)ranges, peel, stream); +} + +void* CudaRaster::getColorBuffer(void) +{ + return m_impl->getColorBuffer(); +} + +void* CudaRaster::getDepthBuffer(void) +{ + return m_impl->getDepthBuffer(); +} + +void CudaRaster::swapDepthAndPeel(void) +{ + m_impl->swapDepthAndPeel(); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7aa7774c652954dc975b48f1f6f839369d191e4c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp @@ -0,0 +1,90 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
+
+#pragma once
+#include <cuda_runtime.h>
+#include <cstdint>
+
+namespace CR
+{
+//------------------------------------------------------------------------
+
+#ifndef NULL
+#   define NULL 0
+#endif
+
+#ifdef __CUDACC__
+#   define CR_CUDA 1
+#else
+#   define CR_CUDA 0
+#endif
+
+#if CR_CUDA
+#   define CR_CUDA_FUNC     __device__ __inline__
+#   define CR_CUDA_CONST    __constant__
+#else
+#   define CR_CUDA_FUNC     inline
+#   define CR_CUDA_CONST    static const
+#endif
+
+#define CR_UNREF(X)         ((void)(X))
+#define CR_ARRAY_SIZE(X)    ((int)(sizeof(X) / sizeof((X)[0])))
+
+//------------------------------------------------------------------------
+
+typedef uint8_t     U8;
+typedef uint16_t    U16;
+typedef uint32_t    U32;
+typedef uint64_t    U64;
+typedef int8_t      S8;
+typedef int16_t     S16;
+typedef int32_t     S32;
+typedef int64_t     S64;
+typedef float       F32;
+typedef double      F64;
+typedef void        (*FuncPtr)(void);
+
+//------------------------------------------------------------------------
+
+#define CR_U32_MAX  (0xFFFFFFFFu)
+#define CR_S32_MIN  (~0x7FFFFFFF)
+#define CR_S32_MAX  (0x7FFFFFFF)
+#define CR_U64_MAX  ((U64)(S64)-1)
+#define CR_S64_MIN  ((S64)-1 << 63)
+#define CR_S64_MAX  (~((S64)-1 << 63))
+#define CR_F32_MIN  (1.175494351e-38f)
+#define CR_F32_MAX  (3.402823466e+38f)
+#define CR_F64_MIN  (2.2250738585072014e-308)
+#define CR_F64_MAX  (1.7976931348623158e+308)
+
+//------------------------------------------------------------------------
+// Misc types.
+
+class Vec2i
+{
+public:
+    Vec2i(int x_, int y_) : x(x_), y(y_) {}
+    int x, y;
+};
+
+class Vec3i
+{
+public:
+    Vec3i(int x_, int y_, int z_) : x(x_), y(y_), z(z_) {}
+    int x, y, z;
+};
+
+//------------------------------------------------------------------------
+// CUDA utilities.
+
+#if CR_CUDA
+#   define globalThreadIdx (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * (blockIdx.x + gridDim.x * blockIdx.y)))
+#endif
+
+//------------------------------------------------------------------------
+} // namespace CR
diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl
new file mode 100644
index 0000000000000000000000000000000000000000..720e9997cf04265a6e1a28f8f0cd2d7b34a25e28
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl
@@ -0,0 +1,385 @@
+// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// Utility funcs.
+//------------------------------------------------------------------------
+
+__device__ __inline__ void initTileZMax(U32& tileZMax, bool& tileZUpd, volatile U32* tileDepth)
+{
+    tileZMax = CR_DEPTH_MAX;
+    tileZUpd = (::min(tileDepth[threadIdx.x], tileDepth[threadIdx.x + 32]) < tileZMax);
+}
+
+__device__ __inline__ void updateTileZMax(U32& tileZMax, bool& tileZUpd, volatile U32* tileDepth, volatile U32* temp)
+{
+    // Entry is warp-coherent.
+ if (__any_sync(~0u, tileZUpd)) + { + U32 z = ::max(tileDepth[threadIdx.x], tileDepth[threadIdx.x + 32]); __syncwarp(); + temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 1]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 2]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 4]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 8]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 16]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + tileZMax = temp[47]; + tileZUpd = false; + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void getTriangle(const CRParams& p, S32& triIdx, S32& dataIdx, uint4& triHeader, S32& segment) +{ + const CRTriangleHeader* triHeaderPtr = (const CRTriangleHeader*)p.triHeader + blockIdx.z * p.maxSubtris;; + const S32* tileSegData = (const S32*)p.tileSegData + p.maxTileSegs * CR_TILE_SEG_SIZE * blockIdx.z; + const S32* tileSegNext = (const S32*)p.tileSegNext + p.maxTileSegs * blockIdx.z; + const S32* tileSegCount = (const S32*)p.tileSegCount + p.maxTileSegs * blockIdx.z; + + if (threadIdx.x >= tileSegCount[segment]) + { + triIdx = -1; + dataIdx = -1; + } + else + { + int subtriIdx = tileSegData[segment * CR_TILE_SEG_SIZE + threadIdx.x]; + triIdx = subtriIdx >> 3; + dataIdx = triIdx; + subtriIdx &= 7; + if (subtriIdx != 7) + dataIdx = triHeaderPtr[triIdx].misc + subtriIdx; + triHeader = *((uint4*)triHeaderPtr + dataIdx); + } + + // advance to next segment + segment = tileSegNext[segment]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ bool earlyZCull(uint4 triHeader, U32 tileZMax) +{ + U32 zmin = triHeader.w & 0xFFFFF000; + return (zmin > tileZMax); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 trianglePixelCoverage(const CRParams& p, const uint4& triHeader, int tileX, int tileY, volatile U64* s_cover8x8_lut) +{ + int baseX = (tileX << (CR_TILE_LOG2 + CR_SUBPIXEL_LOG2)) - ((p.widthPixelsVp - 1) << (CR_SUBPIXEL_LOG2 - 1)); + int baseY = (tileY << (CR_TILE_LOG2 + CR_SUBPIXEL_LOG2)) - ((p.heightPixelsVp - 1) << (CR_SUBPIXEL_LOG2 - 1)); + + // extract S16 vertex positions while subtracting tile coordinates + S32 v0x = sub_s16lo_s16lo(triHeader.x, baseX); + S32 v0y = sub_s16hi_s16lo(triHeader.x, baseY); + S32 v01x = sub_s16lo_s16lo(triHeader.y, triHeader.x); + S32 v01y = sub_s16hi_s16hi(triHeader.y, triHeader.x); + S32 v20x = sub_s16lo_s16lo(triHeader.x, triHeader.z); + S32 v20y = sub_s16hi_s16hi(triHeader.x, triHeader.z); + + // extract flipbits + U32 f01 = (triHeader.w >> 6) & 0x3C; + U32 f12 = (triHeader.w >> 2) & 0x3C; + U32 f20 = (triHeader.w << 2) & 0x3C; + + // compute per-edge coverage masks + U64 c01, c12, c20; + c01 = cover8x8_exact_fast(v0x, v0y, v01x, v01y, f01, s_cover8x8_lut); + c12 = cover8x8_exact_fast(v0x + v01x, v0y + v01y, -v01x - v20x, -v01y - v20y, f12, s_cover8x8_lut); + c20 = cover8x8_exact_fast(v0x, v0y, v20x, v20y, f20, s_cover8x8_lut); + + // combine masks + return c01 & c12 & c20; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 scan32_value(U32 value, volatile U32* temp) +{ + __syncwarp(); + temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 1]; 
__syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 2]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 4]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 8]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 16]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + return value; +} + +__device__ __inline__ volatile const U32& scan32_total(volatile U32* temp) +{ + return temp[47]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ S32 findBit(U64 mask, int idx) +{ + U32 x = getLo(mask); + int pop = __popc(x); + bool p = (pop <= idx); + if (p) x = getHi(mask); + if (p) idx -= pop; + int bit = p ? 32 : 0; + + pop = __popc(x & 0x0000ffffu); + p = (pop <= idx); + if (p) x >>= 16; + if (p) bit += 16; + if (p) idx -= pop; + + U32 tmp = x & 0x000000ffu; + pop = __popc(tmp); + p = (pop <= idx); + if (p) tmp = x & 0x0000ff00u; + if (p) idx -= pop; + + return findLeadingOne(tmp) + bit - idx; +} + +//------------------------------------------------------------------------ +// Single-sample implementation. +//------------------------------------------------------------------------ + +__device__ __inline__ void executeROP(U32 color, U32 depth, volatile U32* pColor, volatile U32* pDepth, U32 ropMask) +{ + atomicMin((U32*)pDepth, depth); + __syncwarp(ropMask); + bool act = (depth == *pDepth); + __syncwarp(ropMask); + U32 actMask = __ballot_sync(ropMask, act); + if (act) + { + *pDepth = 0; + __syncwarp(actMask); + atomicMax((U32*)pDepth, threadIdx.x); + __syncwarp(actMask); + if (*pDepth == threadIdx.x) + { + *pDepth = depth; + *pColor = color; + } + __syncwarp(actMask); + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void fineRasterImpl(const CRParams p) +{ + // for 20 warps: + __shared__ volatile U64 s_cover8x8_lut[CR_COVER8X8_LUT_SIZE]; // 6KB + __shared__ volatile U32 s_tileColor [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_tileDepth [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_tilePeel [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_triDataIdx [CR_FINE_MAX_WARPS][64]; // 5KB CRTriangleData index + __shared__ volatile U64 s_triangleCov [CR_FINE_MAX_WARPS][64]; // 10KB coverage mask + __shared__ volatile U32 s_triangleFrag[CR_FINE_MAX_WARPS][64]; // 5KB fragment index + __shared__ volatile U32 s_temp [CR_FINE_MAX_WARPS][80]; // 6.25KB + // = 47.25KB total + + CRAtomics& atomics = p.atomics[blockIdx.z]; + const CRTriangleData* triData = (const CRTriangleData*)p.triData + blockIdx.z * p.maxSubtris; + + const S32* activeTiles = (const S32*)p.activeTiles + CR_MAXTILES_SQR * blockIdx.z; + const S32* tileFirstSeg = (const S32*)p.tileFirstSeg + CR_MAXTILES_SQR * blockIdx.z; + + volatile U32* tileColor = s_tileColor[threadIdx.y]; + volatile U32* tileDepth = s_tileDepth[threadIdx.y]; + volatile U32* tilePeel = s_tilePeel[threadIdx.y]; + volatile U32* triDataIdx = s_triDataIdx[threadIdx.y]; + volatile U64* triangleCov = s_triangleCov[threadIdx.y]; + volatile U32* triangleFrag = s_triangleFrag[threadIdx.y]; + volatile U32* temp = s_temp[threadIdx.y]; + + if (atomics.numSubtris > p.maxSubtris || atomics.numBinSegs > p.maxBinSegs || atomics.numTileSegs > p.maxTileSegs) + return; + + temp[threadIdx.x] = 0; // first 16 elements of temp are always 
zero + cover8x8_setupLUT(s_cover8x8_lut); + __syncthreads(); + + // loop over tiles + for (;;) + { + // pick a tile + if (threadIdx.x == 0) + temp[16] = atomicAdd(&atomics.fineCounter, 1); + __syncwarp(); + int activeIdx = temp[16]; + if (activeIdx >= atomics.numActiveTiles) + break; + + int tileIdx = activeTiles[activeIdx]; + S32 segment = tileFirstSeg[tileIdx]; + int tileY = tileIdx / p.widthTiles; + int tileX = tileIdx - tileY * p.widthTiles; + int px = (tileX << CR_TILE_LOG2) + (threadIdx.x & (CR_TILE_SIZE - 1)); + int py = (tileY << CR_TILE_LOG2) + (threadIdx.x >> CR_TILE_LOG2); + + // initialize per-tile state + int triRead = 0, triWrite = 0; + int fragRead = 0, fragWrite = 0; + if (threadIdx.x == 0) + triangleFrag[63] = 0; // "previous triangle" + + // deferred clear => clear tile + if (p.deferredClear) + { + tileColor[threadIdx.x] = p.clearColor; + tileDepth[threadIdx.x] = p.clearDepth; + tileColor[threadIdx.x + 32] = p.clearColor; + tileDepth[threadIdx.x + 32] = p.clearDepth; + } + else // otherwise => read tile from framebuffer + { + U32* pColor = (U32*)p.colorBuffer + p.strideX * p.strideY * blockIdx.z; + U32* pDepth = (U32*)p.depthBuffer + p.strideX * p.strideY * blockIdx.z; + tileColor[threadIdx.x] = pColor[px + p.strideX * py]; + tileDepth[threadIdx.x] = pDepth[px + p.strideX * py]; + tileColor[threadIdx.x + 32] = pColor[px + p.strideX * (py + 4)]; + tileDepth[threadIdx.x + 32] = pDepth[px + p.strideX * (py + 4)]; + } + + // read peeling inputs if enabled + if (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) + { + U32* pPeel = (U32*)p.peelBuffer + p.strideX * p.strideY * blockIdx.z; + tilePeel[threadIdx.x] = pPeel[px + p.strideX * py]; + tilePeel[threadIdx.x + 32] = pPeel[px + p.strideX * (py + 4)]; + } + + U32 tileZMax; + bool tileZUpd; + initTileZMax(tileZMax, tileZUpd, tileDepth); + + // process fragments + for(;;) + { + // need to queue more fragments? + if (fragWrite - fragRead < 32 && segment >= 0) + { + // update tile z - coherent over warp + updateTileZMax(tileZMax, tileZUpd, tileDepth, temp); + + // read triangles + do + { + // read triangle index and data, advance to next segment + S32 triIdx, dataIdx; + uint4 triHeader; + getTriangle(p, triIdx, dataIdx, triHeader, segment); + + // early z cull + if (triIdx >= 0 && earlyZCull(triHeader, tileZMax)) + triIdx = -1; + + // determine coverage + U64 coverage = trianglePixelCoverage(p, triHeader, tileX, tileY, s_cover8x8_lut); + S32 pop = (triIdx == -1) ? 0 : __popcll(coverage); + + // fragment count scan + U32 frag = scan32_value(pop, temp); + frag += fragWrite; // frag now holds cumulative fragment count + fragWrite += scan32_total(temp); + + // queue non-empty triangles + U32 goodMask = __ballot_sync(~0u, pop != 0); + if (pop != 0) + { + int idx = (triWrite + __popc(goodMask & getLaneMaskLt())) & 63; + triDataIdx [idx] = dataIdx; + triangleFrag[idx] = frag; + triangleCov [idx] = coverage; + } + triWrite += __popc(goodMask); + } + while (fragWrite - fragRead < 32 && segment >= 0); + } + __syncwarp(); + + // end of segment? 
+ if (fragRead == fragWrite) + break; + + // clear triangle boundaries + temp[threadIdx.x + 16] = 0; + __syncwarp(); + + // tag triangle boundaries + if (triRead + threadIdx.x < triWrite) + { + int idx = triangleFrag[(triRead + threadIdx.x) & 63] - fragRead; + if (idx <= 32) + temp[idx + 16 - 1] = 1; + } + __syncwarp(); + + int ropLaneIdx = threadIdx.x; + U32 boundaryMask = __ballot_sync(~0u, temp[ropLaneIdx + 16]); + + // distribute fragments + bool hasFragment = (ropLaneIdx < fragWrite - fragRead); + U32 fragmentMask = __ballot_sync(~0u, hasFragment); + if (hasFragment) + { + int triBufIdx = (triRead + __popc(boundaryMask & getLaneMaskLt())) & 63; + int fragIdx = add_sub(fragRead, ropLaneIdx, triangleFrag[(triBufIdx - 1) & 63]); + U64 coverage = triangleCov[triBufIdx]; + int pixelInTile = findBit(coverage, fragIdx); + int dataIdx = triDataIdx[triBufIdx]; + + // determine pixel position + U32 pixelX = (tileX << CR_TILE_LOG2) + (pixelInTile & 7); + U32 pixelY = (tileY << CR_TILE_LOG2) + (pixelInTile >> 3); + + // depth test + U32 depth = 0; + uint4 td = *((uint4*)triData + dataIdx * (sizeof(CRTriangleData) >> 4)); + + depth = td.x * pixelX + td.y * pixelY + td.z; + bool zkill = (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) && (depth <= tilePeel[pixelInTile]); + if (!zkill) + { + U32 oldDepth = tileDepth[pixelInTile]; + if (depth > oldDepth) + zkill = true; + else if (oldDepth == tileZMax) + tileZUpd = true; // we are replacing previous zmax => need to update + } + + U32 ropMask = __ballot_sync(fragmentMask, !zkill); + if (!zkill) + executeROP(td.w, depth, &tileColor[pixelInTile], &tileDepth[pixelInTile], ropMask); + } + // no need to sync, as next up is updateTileZMax that does internal warp sync + + // update counters + fragRead = ::min(fragRead + 32, fragWrite); + triRead += __popc(boundaryMask); + } + + // Write tile back to the framebuffer. + if (true) + { + int px = (tileX << CR_TILE_LOG2) + (threadIdx.x & (CR_TILE_SIZE - 1)); + int py = (tileY << CR_TILE_LOG2) + (threadIdx.x >> CR_TILE_LOG2); + U32* pColor = (U32*)p.colorBuffer + p.strideX * p.strideY * blockIdx.z; + U32* pDepth = (U32*)p.depthBuffer + p.strideX * p.strideY * blockIdx.z; + pColor[px + p.strideX * py] = tileColor[threadIdx.x]; + pDepth[px + p.strideX * py] = tileDepth[threadIdx.x]; + pColor[px + p.strideX * (py + 4)] = tileColor[threadIdx.x + 32]; + pDepth[px + p.strideX * (py + 4)] = tileDepth[threadIdx.x + 32]; + } + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26133c97d0479c19a61d757c9eac19618dbc8729 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp @@ -0,0 +1,153 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "Defs.hpp" +#include "Constants.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ +// Projected triangle. 
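+//
+// Vertex positions below are 16-bit subpixel coordinates relative to the viewport
+// center (CR_SUBPIXEL_LOG2 fractional bits per pixel). For an unsplit triangle,
+// 'misc' packs the 20-bit minimum depth together with the three 4-bit edge flip
+// codes consumed by the fine-stage coverage LUT; for a split triangle it stores
+// the base index of its subtriangle headers.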
+//------------------------------------------------------------------------ + +struct CRTriangleHeader +{ + S16 v0x; // Subpixels relative to viewport center. Valid if triSubtris = 1. + S16 v0y; + S16 v1x; + S16 v1y; + S16 v2x; + S16 v2y; + + U32 misc; // triSubtris=1: (zmin:20, f01:4, f12:4, f20:4), triSubtris>=2: (subtriBase) +}; + +//------------------------------------------------------------------------ + +struct CRTriangleData +{ + U32 zx; // zx * sampleX + zy * sampleY + zb = lerp(CR_DEPTH_MIN, CR_DEPTH_MAX, (clipZ / clipW + 1) / 2) + U32 zy; + U32 zb; + U32 id; // Triangle id. +}; + +//------------------------------------------------------------------------ +// Device-side structures. +//------------------------------------------------------------------------ + +struct CRAtomics +{ + // Setup. + S32 numSubtris; // = numTris + + // Bin. + S32 binCounter; // = 0 + S32 numBinSegs; // = 0 + + // Coarse. + S32 coarseCounter; // = 0 + S32 numTileSegs; // = 0 + S32 numActiveTiles; // = 0 + + // Fine. + S32 fineCounter; // = 0 +}; + +//------------------------------------------------------------------------ + +struct CRImageParams +{ + S32 triOffset; // First triangle index to draw. + S32 triCount; // Number of triangles to draw. + S32 binBatchSize; // Number of triangles per batch. +}; + +//------------------------------------------------------------------------ + +struct CRParams +{ + // Common. + + CRAtomics* atomics; // Work counters. Per-image. + S32 numImages; // Batch size. + S32 totalCount; // In range mode, total number of triangles to render. + S32 instanceMode; // 0 = range mode, 1 = instance mode. + + S32 numVertices; // Number of vertices in input buffer, not counting multiples in instance mode. + S32 numTriangles; // Number of triangles in input buffer. + void* vertexBuffer; // numVertices * float4(x, y, z, w) + void* indexBuffer; // numTriangles * int3(vi0, vi1, vi2) + + S32 widthPixels; // Render buffer size in pixels. Must be multiple of tile size (8x8). + S32 heightPixels; + S32 widthPixelsVp; // Viewport size in pixels. + S32 heightPixelsVp; + S32 widthBins; // widthPixels / CR_BIN_SIZE + S32 heightBins; // heightPixels / CR_BIN_SIZE + S32 numBins; // widthBins * heightBins + + F32 xs; // Vertex position adjustments for tiled rendering. + F32 ys; + F32 xo; + F32 yo; + + S32 widthTiles; // widthPixels / CR_TILE_SIZE + S32 heightTiles; // heightPixels / CR_TILE_SIZE + S32 numTiles; // widthTiles * heightTiles + + U32 renderModeFlags; + S32 deferredClear; // 1 = Clear framebuffer before rendering triangles. + U32 clearColor; + U32 clearDepth; + + // These are uniform across batch. + + S32 maxSubtris; + S32 maxBinSegs; + S32 maxTileSegs; + + // Setup output / bin input. + + void* triSubtris; // maxSubtris * U8 + void* triHeader; // maxSubtris * CRTriangleHeader + void* triData; // maxSubtris * CRTriangleData + + // Bin output / coarse input. + + void* binSegData; // maxBinSegs * CR_BIN_SEG_SIZE * S32 + void* binSegNext; // maxBinSegs * S32 + void* binSegCount; // maxBinSegs * S32 + void* binFirstSeg; // CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * (S32 segIdx), -1 = none + void* binTotal; // CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * (S32 numTris) + + // Coarse output / fine input. + + void* tileSegData; // maxTileSegs * CR_TILE_SEG_SIZE * S32 + void* tileSegNext; // maxTileSegs * S32 + void* tileSegCount; // maxTileSegs * S32 + void* activeTiles; // CR_MAXTILES_SQR * (S32 tileIdx) + void* tileFirstSeg; // CR_MAXTILES_SQR * (S32 segIdx), -1 = none + + // Surface buffers. 
Outer tile offset is baked into pointers.
+
+    void*           colorBuffer;        // sizePixels.x * sizePixels.y * numImages * U32
+    void*           depthBuffer;        // sizePixels.x * sizePixels.y * numImages * U32
+    void*           peelBuffer;         // sizePixels.x * sizePixels.y * numImages * U32, only if peeling enabled.
+    S32             strideX;            // horizontal size in pixels
+    S32             strideY;            // vertical stride in pixels
+
+    // Per-image parameters for first images are embedded here to avoid extra memcpy for small batches.
+
+    CRImageParams   imageParamsFirst[CR_EMBED_IMAGE_PARAMS];
+    const CRImageParams* imageParamsExtra; // After CR_EMBED_IMAGE_PARAMS.
+};
+
+//------------------------------------------------------------------------
+}
diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f7f05d57f56ed033b34f0bbcef412297b01f5abc
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp
@@ -0,0 +1,370 @@
+// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "../../framework.h"
+#include "PrivateDefs.hpp"
+#include "Constants.hpp"
+#include "RasterImpl.hpp"
+#include <algorithm>
+
+using namespace CR;
+using std::min;
+using std::max;
+
+//------------------------------------------------------------------------
+// Kernel prototypes and variables.
+
+void triangleSetupKernel    (const CRParams p);
+void binRasterKernel        (const CRParams p);
+void coarseRasterKernel     (const CRParams p);
+void fineRasterKernel       (const CRParams p);
+
+//------------------------------------------------------------------------
+
+RasterImpl::RasterImpl(void)
+:   m_renderModeFlags       (0),
+    m_deferredClear         (false),
+    m_clearColor            (0),
+    m_vertexPtr             (NULL),
+    m_indexPtr              (NULL),
+    m_numVertices           (0),
+    m_numTriangles          (0),
+    m_bufferSizesReported   (0),
+
+    m_numImages             (0),
+    m_bufferSizePixels      (0, 0),
+    m_bufferSizeVp          (0, 0),
+    m_sizePixels            (0, 0),
+    m_sizeVp                (0, 0),
+    m_offsetPixels          (0, 0),
+    m_sizeBins              (0, 0),
+    m_numBins               (0),
+    m_sizeTiles             (0, 0),
+    m_numTiles              (0),
+
+    m_numSMs                (1),
+    m_numCoarseBlocksPerSM  (1),
+    m_numFineBlocksPerSM    (1),
+    m_numFineWarpsPerBlock  (1),
+
+    m_maxSubtris            (1),
+    m_maxBinSegs            (1),
+    m_maxTileSegs           (1)
+{
+    // Query relevant device attributes.
+
+    int currentDevice = 0;
+    NVDR_CHECK_CUDA_ERROR(cudaGetDevice(&currentDevice));
+    NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&m_numSMs, cudaDevAttrMultiProcessorCount, currentDevice));
+    cudaFuncAttributes attr;
+    NVDR_CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (void*)fineRasterKernel));
+    m_numFineWarpsPerBlock = min(attr.maxThreadsPerBlock / 32, CR_FINE_MAX_WARPS);
+    NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&m_numCoarseBlocksPerSM, (void*)coarseRasterKernel, 32 * CR_COARSE_WARPS, 0));
+    NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&m_numFineBlocksPerSM, (void*)fineRasterKernel, 32 * m_numFineWarpsPerBlock, 0));
+
+    // Setup functions.
+ + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)triangleSetupKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)binRasterKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)coarseRasterKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)fineRasterKernel, cudaFuncCachePreferShared)); +} + +//------------------------------------------------------------------------ + +RasterImpl::~RasterImpl(void) +{ + // Empty. +} + +//------------------------------------------------------------------------ + +void RasterImpl::setBufferSize(Vec3i size) +{ + // Internal buffer width and height must be divisible by tile size. + int w = (size.x + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int h = (size.y + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + m_bufferSizePixels = Vec2i(w, h); + m_bufferSizeVp = Vec2i(size.x, size.y); + m_numImages = size.z; + + m_colorBuffer.reset(w * h * size.z * sizeof(U32)); + m_depthBuffer.reset(w * h * size.z * sizeof(U32)); +} + +//------------------------------------------------------------------------ + +void RasterImpl::setViewport(Vec2i size, Vec2i offset) +{ + // Offset must be divisible by tile size. + NVDR_CHECK((offset.x & (CR_TILE_SIZE - 1)) == 0 && (offset.y & (CR_TILE_SIZE - 1)) == 0, "invalid viewport offset"); + + // Round internal viewport size to multiples of tile size. + int w = (size.x + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int h = (size.y + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + m_sizePixels = Vec2i(w, h); + m_offsetPixels = offset; + m_sizeVp = Vec2i(size.x, size.y); + m_sizeTiles.x = m_sizePixels.x >> CR_TILE_LOG2; + m_sizeTiles.y = m_sizePixels.y >> CR_TILE_LOG2; + m_numTiles = m_sizeTiles.x * m_sizeTiles.y; + m_sizeBins.x = (m_sizeTiles.x + CR_BIN_SIZE - 1) >> CR_BIN_LOG2; + m_sizeBins.y = (m_sizeTiles.y + CR_BIN_SIZE - 1) >> CR_BIN_LOG2; + m_numBins = m_sizeBins.x * m_sizeBins.y; +} + +void RasterImpl::swapDepthAndPeel(void) +{ + m_peelBuffer.reset(m_depthBuffer.getSize()); // Ensure equal size and valid pointer. + + void* tmp = m_depthBuffer.getPtr(); + m_depthBuffer.setPtr(m_peelBuffer.getPtr()); + m_peelBuffer.setPtr(tmp); +} + +//------------------------------------------------------------------------ + +bool RasterImpl::drawTriangles(const Vec2i* ranges, bool peel, cudaStream_t stream) +{ + bool instanceMode = (!ranges); + + int maxSubtrisSlack = 4096; // x 81B = 324KB + int maxBinSegsSlack = 256; // x 2137B = 534KB + int maxTileSegsSlack = 4096; // x 136B = 544KB + + // Resize atomics as needed. + m_crAtomics .grow(m_numImages * sizeof(CRAtomics)); + m_crAtomicsHost.grow(m_numImages * sizeof(CRAtomics)); + + // Size of these buffers doesn't depend on input. + m_binFirstSeg .grow(m_numImages * CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * sizeof(S32)); + m_binTotal .grow(m_numImages * CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * sizeof(S32)); + m_activeTiles .grow(m_numImages * CR_MAXTILES_SQR * sizeof(S32)); + m_tileFirstSeg .grow(m_numImages * CR_MAXTILES_SQR * sizeof(S32)); + + // Construct per-image parameters and determine worst-case buffer sizes. + m_crImageParamsHost.grow(m_numImages * sizeof(CRImageParams)); + CRImageParams* imageParams = (CRImageParams*)m_crImageParamsHost.getPtr(); + for (int i=0; i < m_numImages; i++) + { + CRImageParams& ip = imageParams[i]; + + int roundSize = CR_BIN_WARPS * 32; + int minBatches = CR_BIN_STREAMS_SIZE * 2; + int maxRounds = 32; + + ip.triOffset = instanceMode ? 
0 : ranges[i].x; + ip.triCount = instanceMode ? m_numTriangles : ranges[i].y; + ip.binBatchSize = min(max(ip.triCount / (roundSize * minBatches), 1), maxRounds) * roundSize; + + m_maxSubtris = max(m_maxSubtris, min(ip.triCount + maxSubtrisSlack, CR_MAXSUBTRIS_SIZE)); + m_maxBinSegs = max(m_maxBinSegs, max(m_numBins * CR_BIN_STREAMS_SIZE, (ip.triCount - 1) / CR_BIN_SEG_SIZE + 1) + maxBinSegsSlack); + m_maxTileSegs = max(m_maxTileSegs, max(m_numTiles, (ip.triCount - 1) / CR_TILE_SEG_SIZE + 1) + maxTileSegsSlack); + } + + // Retry until successful. + + for (;;) + { + // Allocate buffers. + m_triSubtris.reset(m_numImages * m_maxSubtris * sizeof(U8)); + m_triHeader .reset(m_numImages * m_maxSubtris * sizeof(CRTriangleHeader)); + m_triData .reset(m_numImages * m_maxSubtris * sizeof(CRTriangleData)); + + m_binSegData .reset(m_numImages * m_maxBinSegs * CR_BIN_SEG_SIZE * sizeof(S32)); + m_binSegNext .reset(m_numImages * m_maxBinSegs * sizeof(S32)); + m_binSegCount.reset(m_numImages * m_maxBinSegs * sizeof(S32)); + + m_tileSegData .reset(m_numImages * m_maxTileSegs * CR_TILE_SEG_SIZE * sizeof(S32)); + m_tileSegNext .reset(m_numImages * m_maxTileSegs * sizeof(S32)); + m_tileSegCount.reset(m_numImages * m_maxTileSegs * sizeof(S32)); + + // Report if buffers grow from last time. + size_t sizesTotal = getTotalBufferSizes(); + if (sizesTotal > m_bufferSizesReported) + { + size_t sizesMB = ((sizesTotal - 1) >> 20) + 1; // Round up. + sizesMB = ((sizesMB + 9) / 10) * 10; // 10MB granularity enough in this day and age. + LOG(INFO) << "Internal buffers grown to " << sizesMB << " MB"; + m_bufferSizesReported = sizesMB << 20; + } + + // Launch stages. Blocks until everything is done. + launchStages(instanceMode, peel, stream); + + // Peeling iteration cannot fail, so no point checking things further. + if (peel) + break; + + // Atomics after coarse stage are now available. + CRAtomics* atomics = (CRAtomics*)m_crAtomicsHost.getPtr(); + + // Success? + bool failed = false; + for (int i=0; i < m_numImages; i++) + { + const CRAtomics& a = atomics[i]; + failed = failed || (a.numSubtris > m_maxSubtris) || (a.numBinSegs > m_maxBinSegs) || (a.numTileSegs > m_maxTileSegs); + } + if (!failed) + break; // Success! + + // If we were already at maximum capacity, no can do. + if (m_maxSubtris == CR_MAXSUBTRIS_SIZE) + return false; + + // Enlarge buffers and try again. + for (int i=0; i < m_numImages; i++) + { + const CRAtomics& a = atomics[i]; + m_maxSubtris = max(m_maxSubtris, min(a.numSubtris + maxSubtrisSlack, CR_MAXSUBTRIS_SIZE)); + m_maxBinSegs = max(m_maxBinSegs, a.numBinSegs + maxBinSegsSlack); + m_maxTileSegs = max(m_maxTileSegs, a.numTileSegs + maxTileSegsSlack); + } + } + + m_deferredClear = false; + return true; // Success. +} + +//------------------------------------------------------------------------ + +size_t RasterImpl::getTotalBufferSizes(void) const +{ + return + m_colorBuffer.getSize() + m_depthBuffer.getSize() + // Don't include atomics and image params. 
+ m_triSubtris.getSize() + m_triHeader.getSize() + m_triData.getSize() + + m_binFirstSeg.getSize() + m_binTotal.getSize() + m_binSegData.getSize() + m_binSegNext.getSize() + m_binSegCount.getSize() + + m_activeTiles.getSize() + m_tileFirstSeg.getSize() + m_tileSegData.getSize() + m_tileSegNext.getSize() + m_tileSegCount.getSize(); +} + +//------------------------------------------------------------------------ + +void RasterImpl::launchStages(bool instanceMode, bool peel, cudaStream_t stream) +{ + CRImageParams* imageParams = (CRImageParams*)m_crImageParamsHost.getPtr(); + + // Unless peeling, initialize atomics to mostly zero. + CRAtomics* atomics = (CRAtomics*)m_crAtomicsHost.getPtr(); + if (!peel) + { + memset(atomics, 0, m_numImages * sizeof(CRAtomics)); + for (int i=0; i < m_numImages; i++) + atomics[i].numSubtris = imageParams[i].triCount; + } + + // Copy to device. If peeling, this is the state after coarse raster launch on first iteration. + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crAtomics.getPtr(), atomics, m_numImages * sizeof(CRAtomics), cudaMemcpyHostToDevice, stream)); + + // Copy per-image parameters if there are more than fits in launch parameter block and we haven't done it already. + if (!peel && m_numImages > CR_EMBED_IMAGE_PARAMS) + { + int numImageParamsExtra = m_numImages - CR_EMBED_IMAGE_PARAMS; + m_crImageParamsExtra.grow(numImageParamsExtra * sizeof(CRImageParams)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crImageParamsExtra.getPtr(), imageParams + CR_EMBED_IMAGE_PARAMS, numImageParamsExtra * sizeof(CRImageParams), cudaMemcpyHostToDevice, stream)); + } + + // Set global parameters. + CRParams p; + { + p.atomics = (CRAtomics*)m_crAtomics.getPtr(); + p.numImages = m_numImages; + p.totalCount = 0; // Only relevant in range mode. + p.instanceMode = instanceMode ? 1 : 0; + + p.numVertices = m_numVertices; + p.numTriangles = m_numTriangles; + p.vertexBuffer = m_vertexPtr; + p.indexBuffer = m_indexPtr; + + p.widthPixels = m_sizePixels.x; + p.heightPixels = m_sizePixels.y; + p.widthPixelsVp = m_sizeVp.x; + p.heightPixelsVp = m_sizeVp.y; + p.widthBins = m_sizeBins.x; + p.heightBins = m_sizeBins.y; + p.numBins = m_numBins; + + p.xs = (float)m_bufferSizeVp.x / (float)m_sizeVp.x; + p.ys = (float)m_bufferSizeVp.y / (float)m_sizeVp.y; + p.xo = (float)(m_bufferSizeVp.x - m_sizeVp.x - 2 * m_offsetPixels.x) / (float)m_sizeVp.x; + p.yo = (float)(m_bufferSizeVp.y - m_sizeVp.y - 2 * m_offsetPixels.y) / (float)m_sizeVp.y; + + p.widthTiles = m_sizeTiles.x; + p.heightTiles = m_sizeTiles.y; + p.numTiles = m_numTiles; + + p.renderModeFlags = m_renderModeFlags; + p.deferredClear = m_deferredClear ? 
1 : 0; + p.clearColor = m_clearColor; + p.clearDepth = CR_DEPTH_MAX; + + p.maxSubtris = m_maxSubtris; + p.maxBinSegs = m_maxBinSegs; + p.maxTileSegs = m_maxTileSegs; + + p.triSubtris = m_triSubtris.getPtr(); + p.triHeader = m_triHeader.getPtr(); + p.triData = m_triData.getPtr(); + p.binSegData = m_binSegData.getPtr(); + p.binSegNext = m_binSegNext.getPtr(); + p.binSegCount = m_binSegCount.getPtr(); + p.binFirstSeg = m_binFirstSeg.getPtr(); + p.binTotal = m_binTotal.getPtr(); + p.tileSegData = m_tileSegData.getPtr(); + p.tileSegNext = m_tileSegNext.getPtr(); + p.tileSegCount = m_tileSegCount.getPtr(); + p.activeTiles = m_activeTiles.getPtr(); + p.tileFirstSeg = m_tileFirstSeg.getPtr(); + + size_t byteOffset = ((size_t)m_offsetPixels.x + (size_t)m_offsetPixels.y * (size_t)p.strideX) * sizeof(U32); + p.colorBuffer = m_colorBuffer.getPtr(byteOffset); + p.depthBuffer = m_depthBuffer.getPtr(byteOffset); + p.peelBuffer = (m_renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) ? m_peelBuffer.getPtr(byteOffset) : 0; + p.strideX = m_bufferSizePixels.x; + p.strideY = m_bufferSizePixels.y; + + memcpy(&p.imageParamsFirst, imageParams, min(m_numImages, CR_EMBED_IMAGE_PARAMS) * sizeof(CRImageParams)); + p.imageParamsExtra = (CRImageParams*)m_crImageParamsExtra.getPtr(); + } + + // Setup block sizes. + + dim3 brBlock(32, CR_BIN_WARPS); + dim3 crBlock(32, CR_COARSE_WARPS); + dim3 frBlock(32, m_numFineWarpsPerBlock); + void* args[] = {&p}; + + // Launch stages from setup to coarse and copy atomics to host only if this is not a single-tile peeling iteration. + if (!peel) + { + if (instanceMode) + { + int setupBlocks = (m_numTriangles - 1) / (32 * CR_SETUP_WARPS) + 1; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)triangleSetupKernel, dim3(setupBlocks, 1, m_numImages), dim3(32, CR_SETUP_WARPS), args, 0, stream)); + } + else + { + for (int i=0; i < m_numImages; i++) + p.totalCount += imageParams[i].triCount; + int setupBlocks = (p.totalCount - 1) / (32 * CR_SETUP_WARPS) + 1; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)triangleSetupKernel, dim3(setupBlocks, 1, 1), dim3(32, CR_SETUP_WARPS), args, 0, stream)); + } + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)binRasterKernel, dim3(CR_BIN_STREAMS_SIZE, 1, m_numImages), brBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)coarseRasterKernel, dim3(m_numSMs * m_numCoarseBlocksPerSM, 1, m_numImages), crBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crAtomicsHost.getPtr(), m_crAtomics.getPtr(), sizeof(CRAtomics) * m_numImages, cudaMemcpyDeviceToHost, stream)); + } + + // Fine rasterizer is launched always. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)fineRasterKernel, dim3(m_numSMs * m_numFineBlocksPerSM, 1, m_numImages), frBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d594acdfeb2a83133726a6dfd594b3ccad0d74cc --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp @@ -0,0 +1,102 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "PrivateDefs.hpp" +#include "Buffer.hpp" +#include "../CudaRaster.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +class RasterImpl +{ +public: + RasterImpl (void); + ~RasterImpl (void); + + void setBufferSize (Vec3i size); + void setViewport (Vec2i size, Vec2i offset); + void setRenderModeFlags (U32 flags) { m_renderModeFlags = flags; } + void deferredClear (U32 color) { m_deferredClear = true; m_clearColor = color; } + void setVertexBuffer (void* ptr, int numVertices) { m_vertexPtr = ptr; m_numVertices = numVertices; } // GPU pointer. + void setIndexBuffer (void* ptr, int numTriangles) { m_indexPtr = ptr; m_numTriangles = numTriangles; } // GPU pointer. + bool drawTriangles (const Vec2i* ranges, bool peel, cudaStream_t stream); + void* getColorBuffer (void) { return m_colorBuffer.getPtr(); } // GPU pointer. + void* getDepthBuffer (void) { return m_depthBuffer.getPtr(); } // GPU pointer. + void swapDepthAndPeel (void); + size_t getTotalBufferSizes (void) const; + +private: + void launchStages (bool instanceMode, bool peel, cudaStream_t stream); + + // State. + + unsigned int m_renderModeFlags; + bool m_deferredClear; + unsigned int m_clearColor; + void* m_vertexPtr; + void* m_indexPtr; + int m_numVertices; // Input buffer size. + int m_numTriangles; // Input buffer size. + size_t m_bufferSizesReported; // Previously reported buffer sizes. + + // Surfaces. + + Buffer m_colorBuffer; + Buffer m_depthBuffer; + Buffer m_peelBuffer; + int m_numImages; + Vec2i m_bufferSizePixels; // Internal buffer size. + Vec2i m_bufferSizeVp; // Total viewport size. + Vec2i m_sizePixels; // Internal size at which all computation is done, buffers reserved, etc. + Vec2i m_sizeVp; // Size to which output will be cropped outside, determines viewport size. + Vec2i m_offsetPixels; // Viewport offset for tiled rendering. + Vec2i m_sizeBins; + S32 m_numBins; + Vec2i m_sizeTiles; + S32 m_numTiles; + + // Launch sizes etc. + + S32 m_numSMs; + S32 m_numCoarseBlocksPerSM; + S32 m_numFineBlocksPerSM; + S32 m_numFineWarpsPerBlock; + + // Global intermediate buffers. Individual images have offsets to these. + + Buffer m_crAtomics; + HostBuffer m_crAtomicsHost; + HostBuffer m_crImageParamsHost; + Buffer m_crImageParamsExtra; + Buffer m_triSubtris; + Buffer m_triHeader; + Buffer m_triData; + Buffer m_binFirstSeg; + Buffer m_binTotal; + Buffer m_binSegData; + Buffer m_binSegNext; + Buffer m_binSegCount; + Buffer m_activeTiles; + Buffer m_tileFirstSeg; + Buffer m_tileSegData; + Buffer m_tileSegNext; + Buffer m_tileSegCount; + + // Actual buffer sizes. + + S32 m_maxSubtris; + S32 m_maxBinSegs; + S32 m_maxTileSegs; +}; + +//------------------------------------------------------------------------ +} // namespace CR + diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu new file mode 100644 index 0000000000000000000000000000000000000000..43b1edf04a36d52d22aac8465b584e576ecb723b --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu @@ -0,0 +1,37 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "../CudaRaster.hpp" +#include "PrivateDefs.hpp" +#include "Constants.hpp" +#include "Util.inl" + +namespace CR +{ + +//------------------------------------------------------------------------ +// Stage implementations. +//------------------------------------------------------------------------ + +#include "TriangleSetup.inl" +#include "BinRaster.inl" +#include "CoarseRaster.inl" +#include "FineRaster.inl" + +} + +//------------------------------------------------------------------------ +// Stage entry points. +//------------------------------------------------------------------------ + +__global__ void __launch_bounds__(CR_SETUP_WARPS * 32, CR_SETUP_OPT_BLOCKS) triangleSetupKernel (const CR::CRParams p) { CR::triangleSetupImpl(p); } +__global__ void __launch_bounds__(CR_BIN_WARPS * 32, 1) binRasterKernel (const CR::CRParams p) { CR::binRasterImpl(p); } +__global__ void __launch_bounds__(CR_COARSE_WARPS * 32, 1) coarseRasterKernel (const CR::CRParams p) { CR::coarseRasterImpl(p); } +__global__ void __launch_bounds__(CR_FINE_MAX_WARPS * 32, 1) fineRasterKernel (const CR::CRParams p) { CR::fineRasterImpl(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl new file mode 100644 index 0000000000000000000000000000000000000000..276f0a40ee7ddd3010fed13aebc2cf4fd37011a9 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl @@ -0,0 +1,402 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
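+//------------------------------------------------------------------------
+// How the stage kernels above are driven (condensed from
+// RasterImpl::launchStages() in RasterImpl.cpp earlier in this diff):
+// block shapes match the __launch_bounds__ declarations, the rasterization
+// kernels use the grid's z dimension to enumerate images, and every stage
+// takes the same CRParams argument, roughly
+//
+//     dim3 brBlock(32, CR_BIN_WARPS);                  // binRasterKernel
+//     dim3 crBlock(32, CR_COARSE_WARPS);               // coarseRasterKernel
+//     dim3 frBlock(32, m_numFineWarpsPerBlock);        // fineRasterKernel
+//     void* args[] = { &p };                           // single CRParams argument
+//     cudaLaunchKernel((void*)binRasterKernel,
+//                      dim3(CR_BIN_STREAMS_SIZE, 1, m_numImages), brBlock,
+//                      args, 0, stream);
+//
+// Setup, bin and coarse stages are skipped on depth-peeling iterations;
+// only the fine rasterizer is launched unconditionally.  The triangle
+// setup stage, the first of the four, is implemented below.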
+ +//------------------------------------------------------------------------ + +__device__ __inline__ void snapTriangle( + const CRParams& p, + float4 v0, float4 v1, float4 v2, + int2& p0, int2& p1, int2& p2, float3& rcpW, int2& lo, int2& hi) +{ + F32 viewScaleX = (F32)(p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + F32 viewScaleY = (F32)(p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + rcpW = make_float3(1.0f / v0.w, 1.0f / v1.w, 1.0f / v2.w); + p0 = make_int2(f32_to_s32_sat(v0.x * rcpW.x * viewScaleX), f32_to_s32_sat(v0.y * rcpW.x * viewScaleY)); + p1 = make_int2(f32_to_s32_sat(v1.x * rcpW.y * viewScaleX), f32_to_s32_sat(v1.y * rcpW.y * viewScaleY)); + p2 = make_int2(f32_to_s32_sat(v2.x * rcpW.z * viewScaleX), f32_to_s32_sat(v2.y * rcpW.z * viewScaleY)); + lo = make_int2(min_min(p0.x, p1.x, p2.x), min_min(p0.y, p1.y, p2.y)); + hi = make_int2(max_max(p0.x, p1.x, p2.x), max_max(p0.y, p1.y, p2.y)); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 cover8x8_selectFlips(S32 dx, S32 dy) // 10 instr +{ + U32 flips = 0; + if (dy > 0 || (dy == 0 && dx <= 0)) + flips ^= (1 << CR_FLIPBIT_FLIP_X) ^ (1 << CR_FLIPBIT_FLIP_Y) ^ (1 << CR_FLIPBIT_COMPL); + if (dx > 0) + flips ^= (1 << CR_FLIPBIT_FLIP_X) ^ (1 << CR_FLIPBIT_FLIP_Y); + if (::abs(dx) < ::abs(dy)) + flips ^= (1 << CR_FLIPBIT_SWAP_XY) ^ (1 << CR_FLIPBIT_FLIP_Y); + return flips; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ bool prepareTriangle( + const CRParams& p, + int2 p0, int2 p1, int2 p2, int2 lo, int2 hi, + int2& d1, int2& d2, S32& area) +{ + // Backfacing or degenerate => cull. + + d1 = make_int2(p1.x - p0.x, p1.y - p0.y); + d2 = make_int2(p2.x - p0.x, p2.y - p0.y); + area = d1.x * d2.y - d1.y * d2.x; + + if (area == 0) + return false; // Degenerate. + + if (area < 0 && (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableBackfaceCulling) != 0) + return false; // Backfacing. + + // AABB falls between samples => cull. + + int sampleSize = 1 << CR_SUBPIXEL_LOG2; + int biasX = (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)) - (sampleSize >> 1); + int biasY = (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)) - (sampleSize >> 1); + int lox = (int)add_add(lo.x, sampleSize - 1, biasX) & -sampleSize; + int loy = (int)add_add(lo.y, sampleSize - 1, biasY) & -sampleSize; + int hix = (hi.x + biasX) & -sampleSize; + int hiy = (hi.y + biasY) & -sampleSize; + + if (lox > hix || loy > hiy) + return false; // Between pixels. + + // AABB covers 1 or 2 samples => cull if they are not covered. + + int diff = add_sub(hix, hiy, lox) - loy; + if (diff <= sampleSize) + { + int2 t0 = make_int2(add_sub(p0.x, biasX, lox), add_sub(p0.y, biasY, loy)); + int2 t1 = make_int2(add_sub(p1.x, biasX, lox), add_sub(p1.y, biasY, loy)); + int2 t2 = make_int2(add_sub(p2.x, biasX, lox), add_sub(p2.y, biasY, loy)); + S32 e0 = t0.x * t1.y - t0.y * t1.x; + S32 e1 = t1.x * t2.y - t1.y * t2.x; + S32 e2 = t2.x * t0.y - t2.y * t0.x; + if (area < 0) + { + e0 = -e0; + e1 = -e1; + e2 = -e2; + } + + if (e0 < 0 || e1 < 0 || e2 < 0) + { + if (diff == 0) + return false; // Between pixels. 
+ + t0 = make_int2(add_sub(p0.x, biasX, hix), add_sub(p0.y, biasY, hiy)); + t1 = make_int2(add_sub(p1.x, biasX, hix), add_sub(p1.y, biasY, hiy)); + t2 = make_int2(add_sub(p2.x, biasX, hix), add_sub(p2.y, biasY, hiy)); + e0 = t0.x * t1.y - t0.y * t1.x; + e1 = t1.x * t2.y - t1.y * t2.x; + e2 = t2.x * t0.y - t2.y * t0.x; + if (area < 0) + { + e0 = -e0; + e1 = -e1; + e2 = -e2; + } + + if (e0 < 0 || e1 < 0 || e2 < 0) + return false; // Between pixels. + } + } + + // Otherwise => proceed to output the triangle. + + return true; // Visible. +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void setupTriangle( + const CRParams& p, + CRTriangleHeader* th, CRTriangleData* td, int triId, + float v0z, float v1z, float v2z, + int2 p0, int2 p1, int2 p2, float3 rcpW, + int2 d1, int2 d2, S32 area) +{ + // Swap vertices 1 and 2 if area is negative. Only executed if backface culling is + // disabled (if it is enabled, we never come here with area < 0). + + if (area < 0) + { + swap(d1, d2); + swap(p1, p2); + swap(v1z, v2z); + swap(rcpW.y, rcpW.z); + area = -area; + } + + int2 wv0; + wv0.x = p0.x + (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + wv0.y = p0.y + (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + + // Setup depth plane equation. + + F32 zcoef = (F32)(CR_DEPTH_MAX - CR_DEPTH_MIN) * 0.5f; + F32 zbias = (F32)(CR_DEPTH_MAX + CR_DEPTH_MIN) * 0.5f; + float3 zvert = make_float3( + (v0z * zcoef) * rcpW.x + zbias, + (v1z * zcoef) * rcpW.y + zbias, + (v2z * zcoef) * rcpW.z + zbias + ); + int2 zv0 = make_int2( + wv0.x - (1 << (CR_SUBPIXEL_LOG2 - 1)), + wv0.y - (1 << (CR_SUBPIXEL_LOG2 - 1)) + ); + uint3 zpleq = setupPleq(zvert, zv0, d1, d2, 1.0f / (F32)area); + + U32 zmin = f32_to_u32_sat(fminf(fminf(zvert.x, zvert.y), zvert.z) - (F32)CR_LERP_ERROR(0)); + + // Write CRTriangleData. + + *(uint4*)td = make_uint4(zpleq.x, zpleq.y, zpleq.z, triId); + + // Determine flipbits. + + U32 f01 = cover8x8_selectFlips(d1.x, d1.y); + U32 f12 = cover8x8_selectFlips(d2.x - d1.x, d2.y - d1.y); + U32 f20 = cover8x8_selectFlips(-d2.x, -d2.y); + + // Write CRTriangleHeader. + + *(uint4*)th = make_uint4( + prmt(p0.x, p0.y, 0x5410), + prmt(p1.x, p1.y, 0x5410), + prmt(p2.x, p2.y, 0x5410), + (zmin & 0xfffff000u) | (f01 << 6) | (f12 << 2) | (f20 >> 2)); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void triangleSetupImpl(const CRParams p) +{ + __shared__ F32 s_bary[CR_SETUP_WARPS * 32][18]; + F32* bary = s_bary[threadIdx.x + threadIdx.y * 32]; + + // Compute task and image indices. + + int taskIdx = threadIdx.x + 32 * (threadIdx.y + CR_SETUP_WARPS * blockIdx.x); + int imageIdx = 0; + if (p.instanceMode) + { + imageIdx = blockIdx.z; + if (taskIdx >= p.numTriangles) + return; + } + else + { + while (imageIdx < p.numImages) + { + int count = getImageParams(p, imageIdx).triCount; + if (taskIdx < count) + break; + taskIdx -= count; + imageIdx += 1; + } + if (imageIdx == p.numImages) + return; + } + + // Per-image data structures. + + const CRImageParams& ip = getImageParams(p, imageIdx); + CRAtomics& atomics = p.atomics[imageIdx]; + + const int* indexBuffer = (const int*)p.indexBuffer; + U8* triSubtris = (U8*)p.triSubtris + imageIdx * p.maxSubtris; + CRTriangleHeader* triHeader = (CRTriangleHeader*)p.triHeader + imageIdx * p.maxSubtris; + CRTriangleData* triData = (CRTriangleData*)p.triData + imageIdx * p.maxSubtris; + + // Determine triangle index. 
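+    // (In instance mode all images share one index buffer, so taskIdx is
+    //  used directly; otherwise ip.triOffset shifts it into this image's
+    //  slice of the index buffer.)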
+ + int triIdx = taskIdx; + if (!p.instanceMode) + triIdx += ip.triOffset; + + // Read vertex indices. + + if ((U32)triIdx >= (U32)p.numTriangles) + { + // Bad triangle index. + triSubtris[taskIdx] = 0; + return; + } + + uint4 vidx; + vidx.x = indexBuffer[triIdx * 3 + 0]; + vidx.y = indexBuffer[triIdx * 3 + 1]; + vidx.z = indexBuffer[triIdx * 3 + 2]; + vidx.w = triIdx + 1; // Triangle index. + + if (vidx.x >= (U32)p.numVertices || + vidx.y >= (U32)p.numVertices || + vidx.z >= (U32)p.numVertices) + { + // Bad vertex index. + triSubtris[taskIdx] = 0; + return; + } + + // Read vertex positions. + + const float4* vertexBuffer = (const float4*)p.vertexBuffer; + if (p.instanceMode) + vertexBuffer += p.numVertices * imageIdx; // Instance offset. + + float4 v0 = vertexBuffer[vidx.x]; + float4 v1 = vertexBuffer[vidx.y]; + float4 v2 = vertexBuffer[vidx.z]; + + // Adjust vertex positions according to current viewport size and offset. + + v0.x = v0.x * p.xs + v0.w * p.xo; + v0.y = v0.y * p.ys + v0.w * p.yo; + v1.x = v1.x * p.xs + v1.w * p.xo; + v1.y = v1.y * p.ys + v1.w * p.yo; + v2.x = v2.x * p.xs + v2.w * p.xo; + v2.y = v2.y * p.ys + v2.w * p.yo; + + // Outside view frustum => cull. + + if (v0.w < fabsf(v0.x) | v0.w < fabsf(v0.y) | v0.w < fabsf(v0.z)) + { + if ((v0.w < +v0.x & v1.w < +v1.x & v2.w < +v2.x) | + (v0.w < -v0.x & v1.w < -v1.x & v2.w < -v2.x) | + (v0.w < +v0.y & v1.w < +v1.y & v2.w < +v2.y) | + (v0.w < -v0.y & v1.w < -v1.y & v2.w < -v2.y) | + (v0.w < +v0.z & v1.w < +v1.z & v2.w < +v2.z) | + (v0.w < -v0.z & v1.w < -v1.z & v2.w < -v2.z)) + { + triSubtris[taskIdx] = 0; + return; + } + } + + // Inside depth range => try to snap vertices. + + if (v0.w >= fabsf(v0.z) & v1.w >= fabsf(v1.z) & v2.w >= fabsf(v2.z)) + { + // Inside S16 range and small enough => fast path. + // Note: aabbLimit comes from the fact that cover8x8 + // does not support guardband with maximal viewport. + + int2 p0, p1, p2, lo, hi; + float3 rcpW; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + S32 loxy = ::min(lo.x, lo.y); + S32 hixy = ::max(hi.x, hi.y); + S32 aabbLimit = (1 << (CR_MAXVIEWPORT_LOG2 + CR_SUBPIXEL_LOG2)) - 1; + + if (loxy >= -32768 && hixy <= 32767 && hixy - loxy <= aabbLimit) + { + int2 d1, d2; + S32 area; + bool res = prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area); + triSubtris[taskIdx] = res ? 1 : 0; + + if (res) + setupTriangle( + p, + &triHeader[taskIdx], &triData[taskIdx], vidx.w, + v0.z, v1.z, v2.z, + p0, p1, p2, rcpW, + d1, d2, area); + + return; + } + } + + // Clip to view frustum. + + float4 ov0 = v0; + float4 od1 = make_float4(v1.x - v0.x, v1.y - v0.y, v1.z - v0.z, v1.w - v0.w); + float4 od2 = make_float4(v2.x - v0.x, v2.y - v0.y, v2.z - v0.z, v2.w - v0.w); + int numVerts = clipTriangleWithFrustum(bary, &ov0.x, &v1.x, &v2.x, &od1.x, &od2.x); + + // Count non-culled subtriangles. 
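+    // bary[] now holds one (u, v) pair per clipped-polygon vertex, expressed
+    // relative to the original triangle as  pos = ov0 + od1*u + od2*v.  The
+    // loop below rebuilds each vertex in clip space and triangulates the
+    // polygon as a fan (vertex 0, vertex i-1, vertex i), counting how many
+    // of those fan triangles survive prepareTriangle().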
+ + v0.x = ov0.x + od1.x * bary[0] + od2.x * bary[1]; + v0.y = ov0.y + od1.y * bary[0] + od2.y * bary[1]; + v0.z = ov0.z + od1.z * bary[0] + od2.z * bary[1]; + v0.w = ov0.w + od1.w * bary[0] + od2.w * bary[1]; + v1.x = ov0.x + od1.x * bary[2] + od2.x * bary[3]; + v1.y = ov0.y + od1.y * bary[2] + od2.y * bary[3]; + v1.z = ov0.z + od1.z * bary[2] + od2.z * bary[3]; + v1.w = ov0.w + od1.w * bary[2] + od2.w * bary[3]; + float4 tv1 = v1; + + int numSubtris = 0; + for (int i = 2; i < numVerts; i++) + { + v2.x = ov0.x + od1.x * bary[i * 2 + 0] + od2.x * bary[i * 2 + 1]; + v2.y = ov0.y + od1.y * bary[i * 2 + 0] + od2.y * bary[i * 2 + 1]; + v2.z = ov0.z + od1.z * bary[i * 2 + 0] + od2.z * bary[i * 2 + 1]; + v2.w = ov0.w + od1.w * bary[i * 2 + 0] + od2.w * bary[i * 2 + 1]; + + int2 p0, p1, p2, lo, hi, d1, d2; + float3 rcpW; + S32 area; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + if (prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area)) + numSubtris++; + + v1 = v2; + } + + triSubtris[taskIdx] = numSubtris; + + // Multiple subtriangles => allocate. + + int subtriBase = taskIdx; + if (numSubtris > 1) + { + subtriBase = atomicAdd(&atomics.numSubtris, numSubtris); + triHeader[taskIdx].misc = subtriBase; + if (subtriBase + numSubtris > p.maxSubtris) + numVerts = 0; + } + + // Setup subtriangles. + + v1 = tv1; + for (int i = 2; i < numVerts; i++) + { + v2.x = ov0.x + od1.x * bary[i * 2 + 0] + od2.x * bary[i * 2 + 1]; + v2.y = ov0.y + od1.y * bary[i * 2 + 0] + od2.y * bary[i * 2 + 1]; + v2.z = ov0.z + od1.z * bary[i * 2 + 0] + od2.z * bary[i * 2 + 1]; + v2.w = ov0.w + od1.w * bary[i * 2 + 0] + od2.w * bary[i * 2 + 1]; + + int2 p0, p1, p2, lo, hi, d1, d2; + float3 rcpW; + S32 area; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + if (prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area)) + { + setupTriangle( + p, + &triHeader[subtriBase], &triData[subtriBase], vidx.w, + v0.z, v1.z, v2.z, + p0, p1, p2, rcpW, + d1, d2, area); + + subtriBase++; + } + + v1 = v2; + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl new file mode 100644 index 0000000000000000000000000000000000000000..f8faeba7ba2d0634a80d92869b286d48d3071722 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl @@ -0,0 +1,452 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
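+//------------------------------------------------------------------------
+// Restating the subtriangle allocation at the end of triangleSetupImpl()
+// above: when clipping yields more than one subtriangle, a contiguous range
+// of output slots is reserved with a single per-image atomic,
+//
+//     int base = atomicAdd(&atomics.numSubtris, numSubtris);  // reserve a range
+//     triHeader[taskIdx].misc = base;                          // remember where it starts
+//     if (base + numSubtris > p.maxSubtris)                    // overflow: write nothing
+//         numVerts = 0;
+//
+// and the per-image CRAtomics counters are copied back to the host after the
+// coarse stage (see RasterImpl.cpp), so an overflow of maxSubtris can be
+// detected there.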
+ +#include "PrivateDefs.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +template __device__ __inline__ void swap(T& a, T& b) { T t = a; a = b; b = t; } + +__device__ __inline__ U32 getLo (U64 a) { return __double2loint(__longlong_as_double(a)); } +__device__ __inline__ S32 getLo (S64 a) { return __double2loint(__longlong_as_double(a)); } +__device__ __inline__ U32 getHi (U64 a) { return __double2hiint(__longlong_as_double(a)); } +__device__ __inline__ S32 getHi (S64 a) { return __double2hiint(__longlong_as_double(a)); } +__device__ __inline__ U64 combineLoHi (U32 lo, U32 hi) { return __double_as_longlong(__hiloint2double(hi, lo)); } +__device__ __inline__ S64 combineLoHi (S32 lo, S32 hi) { return __double_as_longlong(__hiloint2double(hi, lo)); } +__device__ __inline__ U32 getLaneMaskLt (void) { U32 r; asm("mov.u32 %0, %lanemask_lt;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskLe (void) { U32 r; asm("mov.u32 %0, %lanemask_le;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskGt (void) { U32 r; asm("mov.u32 %0, %lanemask_gt;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskGe (void) { U32 r; asm("mov.u32 %0, %lanemask_ge;" : "=r"(r)); return r; } +__device__ __inline__ int findLeadingOne (U32 v) { U32 r; asm("bfind.u32 %0, %1;" : "=r"(r) : "r"(v)); return r; } +__device__ __inline__ bool singleLane (void) { return ((::__ballot_sync(~0u, true) & getLaneMaskLt()) == 0); } + +__device__ __inline__ void add_add_carry (U32& rlo, U32 alo, U32 blo, U32& rhi, U32 ahi, U32 bhi) { U64 r = combineLoHi(alo, ahi) + combineLoHi(blo, bhi); rlo = getLo(r); rhi = getHi(r); } +__device__ __inline__ S32 f32_to_s32_sat (F32 a) { S32 v; asm("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u32_sat (F32 a) { U32 v; asm("cvt.rni.sat.u32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u32_sat_rmi (F32 a) { U32 v; asm("cvt.rmi.sat.u32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u8_sat (F32 a) { U32 v; asm("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ S64 f32_to_s64 (F32 a) { S64 v; asm("cvt.rni.s64.f32 %0, %1;" : "=l"(v) : "f"(a)); return v; } +__device__ __inline__ S32 add_s16lo_s16lo (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16hi_s16lo (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16lo_s16hi (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16hi_s16hi (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16lo_s16lo (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16hi_s16lo (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16lo_s16hi (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16hi_s16hi (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16lo_u16lo (U32 a, U32 
b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16hi_u16lo (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16lo_u16hi (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16hi_u16hi (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b0 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b0, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b1 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b1, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b2 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b2, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b3 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b3, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 vmad_b0 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b0, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b1 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b2 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b2, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b3, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b0_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b0, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b1_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b1, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b2_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b2, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b3_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b3, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_mask8 (U32 a, U32 b) { U32 v; U32 z=0; asm("vadd.u32.u32.u32 %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(z)); return v; } +__device__ __inline__ U32 sub_mask8 (U32 a, U32 b) { U32 v; U32 z=0; asm("vsub.u32.u32.u32 %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(z)); return v; } +__device__ __inline__ S32 max_max (S32 a, S32 b, S32 c) { S32 v; asm("vmax.s32.s32.s32.max %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 min_min (S32 a, S32 b, S32 c) { S32 v; asm("vmin.s32.s32.s32.min %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 max_add (S32 a, S32 b, S32 c) { S32 v; asm("vmax.s32.s32.s32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 min_add (S32 a, S32 b, S32 c) { S32 v; asm("vmin.s32.s32.s32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_add (U32 a, U32 b, U32 c) { U32 v; asm("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 sub_add (U32 a, U32 b, U32 c) { U32 v; 
asm("vsub.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_sub (U32 a, U32 b, U32 c) { U32 v; asm("vsub.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(c), "r"(b)); return v; } +__device__ __inline__ S32 add_clamp_0_x (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat.min %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 add_clamp_b0 (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 add_clamp_b2 (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat %0.b2, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 prmt (U32 a, U32 b, U32 c) { U32 v; asm("prmt.b32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 u32lo_sext (U32 a) { U32 v; asm("cvt.s16.u32 %0, %1;" : "=r"(v) : "r"(a)); return v; } +__device__ __inline__ U32 slct (U32 a, U32 b, S32 c) { U32 v; asm("slct.u32.s32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 slct (S32 a, S32 b, S32 c) { S32 v; asm("slct.s32.s32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ F32 slct (F32 a, F32 b, S32 c) { F32 v; asm("slct.f32.s32 %0, %1, %2, %3;" : "=f"(v) : "f"(a), "f"(b), "r"(c)); return v; } +__device__ __inline__ U32 isetge (S32 a, S32 b) { U32 v; asm("set.ge.u32.s32 %0, %1, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ F64 rcp_approx (F64 a) { F64 v; asm("rcp.approx.ftz.f64 %0, %1;" : "=d"(v) : "d"(a)); return v; } +__device__ __inline__ F32 fma_rm (F32 a, F32 b, F32 c) { F32 v; asm("fma.rm.f32 %0, %1, %2, %3;" : "=f"(v) : "f"(a), "f"(b), "f"(c)); return v; } +__device__ __inline__ U32 idiv_fast (U32 a, U32 b); + +__device__ __inline__ uint3 setupPleq (float3 values, int2 v0, int2 d1, int2 d2, F32 areaRcp); + +__device__ __inline__ void cover8x8_setupLUT (volatile U64* lut); +__device__ __inline__ U64 cover8x8_exact_fast (S32 ox, S32 oy, S32 dx, S32 dy, U32 flips, volatile const U64* lut); // Assumes viewport <= 2^11, subpixels <= 2^4, no guardband. +__device__ __inline__ U64 cover8x8_lookupMask (S64 yinit, U32 yinc, U32 flips, volatile const U64* lut); + +__device__ __inline__ U64 cover8x8_exact_noLUT (S32 ox, S32 oy, S32 dx, S32 dy); // optimized reference implementation, does not require look-up table +__device__ __inline__ U64 cover8x8_conservative_noLUT (S32 ox, S32 oy, S32 dx, S32 dy); +__device__ __inline__ U64 cover8x8_generateMask_noLUT (S32 curr, S32 dx, S32 dy); + +template __device__ __inline__ void sortShared(T* ptr, int numItems); // Assumes that numItems <= threadsInBlock. Must sync before & after the call. + +__device__ __inline__ const CRImageParams& getImageParams(const CRParams& p, int idx) +{ + return (idx < CR_EMBED_IMAGE_PARAMS) ? 
p.imageParamsFirst[idx] : p.imageParamsExtra[idx - CR_EMBED_IMAGE_PARAMS]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ int clipPolygonWithPlane(F32* baryOut, const F32* baryIn, int numIn, F32 v0, F32 v1, F32 v2) +{ + int numOut = 0; + if (numIn >= 3) + { + int ai = (numIn - 1) * 2; + F32 av = v0 + v1 * baryIn[ai + 0] + v2 * baryIn[ai + 1]; + for (int bi = 0; bi < numIn * 2; bi += 2) + { + F32 bv = v0 + v1 * baryIn[bi + 0] + v2 * baryIn[bi + 1]; + if (av * bv < 0.0f) + { + F32 bc = av / (av - bv); + F32 ac = 1.0f - bc; + baryOut[numOut + 0] = baryIn[ai + 0] * ac + baryIn[bi + 0] * bc; + baryOut[numOut + 1] = baryIn[ai + 1] * ac + baryIn[bi + 1] * bc; + numOut += 2; + } + if (bv >= 0.0f) + { + baryOut[numOut + 0] = baryIn[bi + 0]; + baryOut[numOut + 1] = baryIn[bi + 1]; + numOut += 2; + } + ai = bi; + av = bv; + } + } + return (numOut >> 1); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ int clipTriangleWithFrustum(F32* bary, const F32* v0, const F32* v1, const F32* v2, const F32* d1, const F32* d2) +{ + int num = 3; + bary[0] = 0.0f, bary[1] = 0.0f; + bary[2] = 1.0f, bary[3] = 0.0f; + bary[4] = 0.0f, bary[5] = 1.0f; + + if ((v0[3] < fabsf(v0[0])) | (v1[3] < fabsf(v1[0])) | (v2[3] < fabsf(v2[0]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[0], d1[3] + d1[0], d2[3] + d2[0]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[0], d1[3] - d1[0], d2[3] - d2[0]); + } + if ((v0[3] < fabsf(v0[1])) | (v1[3] < fabsf(v1[1])) | (v2[3] < fabsf(v2[1]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[1], d1[3] + d1[1], d2[3] + d2[1]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[1], d1[3] - d1[1], d2[3] - d2[1]); + } + if ((v0[3] < fabsf(v0[2])) | (v1[3] < fabsf(v1[2])) | (v2[3] < fabsf(v2[2]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[2], d1[3] + d1[2], d2[3] + d2[2]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[2], d1[3] - d1[2], d2[3] - d2[2]); + } + return num; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 idiv_fast(U32 a, U32 b) +{ + return f32_to_u32_sat_rmi(((F32)a + 0.5f) / (F32)b); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 toABGR(float4 color) +{ + // 11 instructions: 4*FFMA, 4*F2I, 3*PRMT + U32 x = f32_to_u32_sat_rmi(fma_rm(color.x, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 y = f32_to_u32_sat_rmi(fma_rm(color.y, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 z = f32_to_u32_sat_rmi(fma_rm(color.z, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 w = f32_to_u32_sat_rmi(fma_rm(color.w, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + return prmt(prmt(x, y, 0x0073), prmt(z, w, 0x0073), 0x5410); +} + +//------------------------------------------------------------------------ +// v0 = subpixels relative to the bottom-left sampling point + +__device__ __inline__ uint3 setupPleq(float3 values, int2 v0, int2 d1, int2 d2, F32 areaRcp) +{ + F32 mx = fmaxf(fmaxf(values.x, values.y), values.z); + int sh = ::min(::max((__float_as_int(mx) >> 23) - (127 + 22), 0), 8); + S32 t0 = (U32)values.x >> sh; + S32 t1 = ((U32)values.y >> sh) - t0; + S32 t2 = ((U32)values.z >> sh) - t0; + + U32 rcpMant = (__float_as_int(areaRcp) & 0x007FFFFF) | 0x00800000; + int rcpShift = (23 + 127) - (__float_as_int(areaRcp) >> 23); + + uint3 pleq; + S64 xc = 
((S64)t1 * d2.y - (S64)t2 * d1.y) * rcpMant; + S64 yc = ((S64)t2 * d1.x - (S64)t1 * d2.x) * rcpMant; + pleq.x = (U32)(xc >> (rcpShift - (sh + CR_SUBPIXEL_LOG2))); + pleq.y = (U32)(yc >> (rcpShift - (sh + CR_SUBPIXEL_LOG2))); + + S32 centerX = (v0.x * 2 + min_min(d1.x, d2.x, 0) + max_max(d1.x, d2.x, 0)) >> (CR_SUBPIXEL_LOG2 + 1); + S32 centerY = (v0.y * 2 + min_min(d1.y, d2.y, 0) + max_max(d1.y, d2.y, 0)) >> (CR_SUBPIXEL_LOG2 + 1); + S32 vcx = v0.x - (centerX << CR_SUBPIXEL_LOG2); + S32 vcy = v0.y - (centerY << CR_SUBPIXEL_LOG2); + + pleq.z = t0 << sh; + pleq.z -= (U32)(((xc >> 13) * vcx + (yc >> 13) * vcy) >> (rcpShift - (sh + 13))); + pleq.z -= pleq.x * centerX + pleq.y * centerY; + return pleq; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void cover8x8_setupLUT(volatile U64* lut) +{ + for (S32 lutIdx = threadIdx.x + blockDim.x * threadIdx.y; lutIdx < CR_COVER8X8_LUT_SIZE; lutIdx += blockDim.x * blockDim.y) + { + int half = (lutIdx < (12 << 5)) ? 0 : 1; + int yint = (lutIdx >> 5) - half * 12 - 3; + U32 shape = ((lutIdx >> 2) & 7) << (31 - 2); + S32 slctSwapXY = lutIdx << (31 - 1); + S32 slctNegX = lutIdx << (31 - 0); + S32 slctCompl = slctSwapXY ^ slctNegX; + + U64 mask = 0; + int xlo = half * 4; + int xhi = xlo + 4; + for (int x = xlo; x < xhi; x++) + { + int ylo = slct(0, ::max(yint, 0), slctCompl); + int yhi = slct(::min(yint, 8), 8, slctCompl); + for (int y = ylo; y < yhi; y++) + { + int xx = slct(x, y, slctSwapXY); + int yy = slct(y, x, slctSwapXY); + xx = slct(xx, 7 - xx, slctNegX); + mask |= (U64)1 << (xx + yy * 8); + } + yint += shape >> 31; + shape <<= 1; + } + lut[lutIdx] = mask; + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_exact_fast(S32 ox, S32 oy, S32 dx, S32 dy, U32 flips, volatile const U64* lut) // 52 instr +{ + F32 yinitBias = (F32)(1 << (31 - CR_MAXVIEWPORT_LOG2 - CR_SUBPIXEL_LOG2 * 2)); + F32 yinitScale = (F32)(1 << (32 - CR_SUBPIXEL_LOG2)); + F32 yincScale = 65536.0f * 65536.0f; + + S32 slctFlipY = flips << (31 - CR_FLIPBIT_FLIP_Y); + S32 slctFlipX = flips << (31 - CR_FLIPBIT_FLIP_X); + S32 slctSwapXY = flips << (31 - CR_FLIPBIT_SWAP_XY); + + // Evaluate cross product. + + S32 t = ox * dy - oy * dx; + F32 det = (F32)slct(t, t - dy * (7 << CR_SUBPIXEL_LOG2), slctFlipX); + if (flips >= (1 << CR_FLIPBIT_COMPL)) + det = -det; + + // Represent Y as a function of X. + + F32 xrcp = 1.0f / (F32)::abs(slct(dx, dy, slctSwapXY)); + F32 yzero = det * yinitScale * xrcp + yinitBias; + S64 yinit = f32_to_s64(slct(yzero, -yzero, slctFlipY)); + U32 yinc = f32_to_u32_sat((F32)::abs(slct(dy, dx, slctSwapXY)) * xrcp * yincScale); + + // Lookup. + + return cover8x8_lookupMask(yinit, yinc, flips, lut); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_lookupMask(S64 yinit, U32 yinc, U32 flips, volatile const U64* lut) +{ + // First half. + + U32 yfrac = getLo(yinit); + U32 shape = add_clamp_0_x(getHi(yinit) + 4, 0, 11); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + int oct = flips & ((1 << CR_FLIPBIT_FLIP_X) | (1 << CR_FLIPBIT_SWAP_XY)); + U64 mask = *(U64*)((U8*)lut + oct + (shape << 5)); + + // Second half. 
+ + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + shape = add_clamp_0_x(getHi(yinit) + 4, __popc(shape & 15), 11); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + mask |= *(U64*)((U8*)lut + oct + (shape << 5) + (12 << 8)); + return (flips >= (1 << CR_FLIPBIT_COMPL)) ? ~mask : mask; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_exact_noLUT(S32 ox, S32 oy, S32 dx, S32 dy) +{ + S32 curr = ox * dy - oy * dx; + if (dy > 0 || (dy == 0 && dx <= 0)) curr--; // exclusive + return cover8x8_generateMask_noLUT(curr, dx, dy); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_conservative_noLUT(S32 ox, S32 oy, S32 dx, S32 dy) +{ + S32 curr = ox * dy - oy * dx; + if (dy > 0 || (dy == 0 && dx <= 0)) curr--; // exclusive + curr += (::abs(dx) + ::abs(dy)) << (CR_SUBPIXEL_LOG2 - 1); + return cover8x8_generateMask_noLUT(curr, dx, dy); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_generateMask_noLUT(S32 curr, S32 dx, S32 dy) +{ + curr += (dx - dy) * (7 << CR_SUBPIXEL_LOG2); + S32 stepX = dy << (CR_SUBPIXEL_LOG2 + 1); + S32 stepYorig = -dx - dy * 7; + S32 stepY = stepYorig << (CR_SUBPIXEL_LOG2 + 1); + + U32 hi = isetge(curr, 0); + U32 frac = curr + curr; + for (int i = 62; i >= 32; i--) + add_add_carry(frac, frac, ((i & 7) == 7) ? stepY : stepX, hi, hi, hi); + + U32 lo = 0; + for (int i = 31; i >= 0; i--) + add_add_carry(frac, frac, ((i & 7) == 7) ? stepY : stepX, lo, lo, lo); + + lo ^= lo >> 1, hi ^= hi >> 1; + lo ^= lo >> 2, hi ^= hi >> 2; + lo ^= lo >> 4, hi ^= hi >> 4; + lo ^= lo >> 8, hi ^= hi >> 8; + lo ^= lo >> 16, hi ^= hi >> 16; + + if (dy < 0) + { + lo ^= 0x55AA55AA; + hi ^= 0x55AA55AA; + } + if (stepYorig < 0) + { + lo ^= 0xFF00FF00; + hi ^= 0x00FF00FF; + } + if ((hi & 1) != 0) + lo = ~lo; + + return combineLoHi(lo, hi); +} + +//------------------------------------------------------------------------ + +template __device__ __inline__ void sortShared(T* ptr, int numItems) +{ + int thrInBlock = threadIdx.x + threadIdx.y * blockDim.x; + int range = 16; + + // Use transposition sort within each 16-wide subrange. + + int base = thrInBlock * 2; + bool act = (base < numItems - 1); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + bool tryOdd = (base < numItems - 2 && (~base & (range - 2)) != 0); + T mid = ptr[base + 1]; + + for (int iter = 0; iter < range; iter += 2) + { + // Evens. + + T tmp = ptr[base + 0]; + if (tmp > mid) + { + ptr[base + 0] = mid; + mid = tmp; + } + __syncwarp(actMask); + + // Odds. + + if (tryOdd) + { + tmp = ptr[base + 2]; + if (mid > tmp) + { + ptr[base + 2] = mid; + mid = tmp; + } + } + __syncwarp(actMask); + } + ptr[base + 1] = mid; + } + + // Multiple subranges => Merge hierarchically. + + for (; range < numItems; range <<= 1) + { + // Assuming that we would insert the current item into the other + // subrange, use binary search to find the appropriate slot. 
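+        // Equivalently: an item's final position inside the merged 2*range
+        // block is (its offset within its own sorted run) + (the number of
+        // items in the sibling run that sort before it); the binary search
+        // below computes the second term.  The 'inclusive' flag breaks ties
+        // so that equal items from the lower run land first, which keeps
+        // the merge stable.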
+ + __syncthreads(); + + T item; + int slot; + if (thrInBlock < numItems) + { + item = ptr[thrInBlock]; + slot = (thrInBlock & -range) ^ range; + if (slot < numItems) + { + T tmp = ptr[slot]; + bool inclusive = ((thrInBlock & range) != 0); + if (tmp < item || (inclusive && tmp == item)) + { + for (int step = (range >> 1); step != 0; step >>= 1) + { + int probe = slot + step; + if (probe < numItems) + { + tmp = ptr[probe]; + if (tmp < item || (inclusive && tmp == item)) + slot = probe; + } + } + slot++; + } + } + } + + // Store the item at an appropriate place. + + __syncthreads(); + + if (thrInBlock < numItems) + ptr[slot + (thrInBlock & (range * 2 - 1)) - range] = item; + } +} + +//------------------------------------------------------------------------ +} diff --git a/extensions/nvdiffrast/nvdiffrast/common/framework.h b/extensions/nvdiffrast/nvdiffrast/common/framework.h new file mode 100644 index 0000000000000000000000000000000000000000..12d803caaf3210c45808dee41217c4c6c6edfe6e --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/framework.h @@ -0,0 +1,49 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +// Framework-specific macros to enable code sharing. + +//------------------------------------------------------------------------ +// Tensorflow. + +#ifdef NVDR_TENSORFLOW +#define EIGEN_USE_GPU +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/default/logging.h" +using namespace tensorflow; +using namespace tensorflow::shape_inference; +#define NVDR_CTX_ARGS OpKernelContext* _nvdr_ctx +#define NVDR_CTX_PARAMS _nvdr_ctx +#define NVDR_CHECK(COND, ERR) OP_REQUIRES(_nvdr_ctx, COND, errors::Internal(ERR)) +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) OP_CHECK_CUDA_ERROR(_nvdr_ctx, CUDA_CALL) +#define NVDR_CHECK_GL_ERROR(GL_CALL) OP_CHECK_GL_ERROR(_nvdr_ctx, GL_CALL) +#endif + +//------------------------------------------------------------------------ +// PyTorch. + +#ifdef NVDR_TORCH +#ifndef __CUDACC__ +#include +#include +#include +#include +#include +#endif +#define NVDR_CTX_ARGS int _nvdr_ctx_dummy +#define NVDR_CTX_PARAMS 0 +#define NVDR_CHECK(COND, ERR) do { TORCH_CHECK(COND, ERR) } while(0) +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) do { cudaError_t err = CUDA_CALL; TORCH_CHECK(!err, "Cuda error: ", cudaGetLastError(), "[", #CUDA_CALL, ";]"); } while(0) +#define NVDR_CHECK_GL_ERROR(GL_CALL) do { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } while(0) +#endif + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp b/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2af3e931b6808e2575d8a209d5485746499b3374 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp @@ -0,0 +1,403 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common. +//------------------------------------------------------------------------ + +#include "framework.h" +#include "glutil.h" +#include +#include + +// Create the function pointers. +#define GLUTIL_EXT(return_type, name, ...) return_type (GLAPIENTRY* name)(__VA_ARGS__) = 0; +#include "glutil_extlist.h" +#undef GLUTIL_EXT + +// Track initialization status. +static volatile bool s_glExtInitialized = false; + +// Error strings. +const char* getGLErrorString(GLenum err) +{ + switch(err) + { + case GL_NO_ERROR: return "GL_NO_ERROR"; + case GL_INVALID_ENUM: return "GL_INVALID_ENUM"; + case GL_INVALID_VALUE: return "GL_INVALID_VALUE"; + case GL_INVALID_OPERATION: return "GL_INVALID_OPERATION"; + case GL_STACK_OVERFLOW: return "GL_STACK_OVERFLOW"; + case GL_STACK_UNDERFLOW: return "GL_STACK_UNDERFLOW"; + case GL_OUT_OF_MEMORY: return "GL_OUT_OF_MEMORY"; + case GL_INVALID_FRAMEBUFFER_OPERATION: return "GL_INVALID_FRAMEBUFFER_OPERATION"; + case GL_TABLE_TOO_LARGE: return "GL_TABLE_TOO_LARGE"; + case GL_CONTEXT_LOST: return "GL_CONTEXT_LOST"; + } + return "Unknown error"; +} + +//------------------------------------------------------------------------ +// Windows. +//------------------------------------------------------------------------ + +#ifdef _WIN32 + +static CRITICAL_SECTION getInitializedCriticalSection(void) +{ + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + return cs; +} + +static CRITICAL_SECTION s_getProcAddressMutex = getInitializedCriticalSection(); + +static void safeGetProcAddress(const char* name, PROC* pfn) +{ + PROC result = wglGetProcAddress(name); + if (!result) + { + LeaveCriticalSection(&s_getProcAddressMutex); // Prepare for thread exit. + LOG(FATAL) << "wglGetProcAddress() failed for '" << name << "'"; + exit(1); // Should never get here but make sure we exit. + } + *pfn = result; +} + +static void initializeGLExtensions(void) +{ + // Use critical section for thread safety. + EnterCriticalSection(&s_getProcAddressMutex); + + // Only dig function pointers if not done already. + if (!s_glExtInitialized) + { + // Generate code to populate the function pointers. +#define GLUTIL_EXT(return_type, name, ...) safeGetProcAddress(#name, (PROC*)&name); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + + // Mark as initialized. + s_glExtInitialized = true; + } + + // Done. 
+ LeaveCriticalSection(&s_getProcAddressMutex); + return; +} + +void setGLContext(GLContext& glctx) +{ + if (!glctx.hglrc) + LOG(FATAL) << "setGLContext() called with null gltcx"; + if (!wglMakeCurrent(glctx.hdc, glctx.hglrc)) + LOG(FATAL) << "wglMakeCurrent() failed when setting GL context"; + + if (glctx.extInitialized) + return; + initializeGLExtensions(); + glctx.extInitialized = 1; +} + +void releaseGLContext(void) +{ + if (!wglMakeCurrent(NULL, NULL)) + LOG(FATAL) << "wglMakeCurrent() failed when releasing GL context"; +} + +extern "C" int set_gpu(const char*); // In setgpu.lib +GLContext createGLContext(int cudaDeviceIdx) +{ + if (cudaDeviceIdx >= 0) + { + char pciBusId[256] = ""; + LOG(INFO) << "Creating GL context for Cuda device " << cudaDeviceIdx; + if (cudaDeviceGetPCIBusId(pciBusId, 255, cudaDeviceIdx)) + { + LOG(INFO) << "PCI bus id query failed"; + } + else + { + int res = set_gpu(pciBusId); + LOG(INFO) << "Selecting device with PCI bus id " << pciBusId << " - " << (res ? "failed, expect crash or major slowdown" : "success"); + } + } + + HINSTANCE hInstance = GetModuleHandle(NULL); + WNDCLASS wc = {}; + wc.style = CS_OWNDC; + wc.lpfnWndProc = DefWindowProc; + wc.hInstance = hInstance; + wc.lpszClassName = "__DummyGLClassCPP"; + int res = RegisterClass(&wc); + + HWND hwnd = CreateWindow( + "__DummyGLClassCPP", // lpClassName + "__DummyGLWindowCPP", // lpWindowName + WS_OVERLAPPEDWINDOW, // dwStyle + CW_USEDEFAULT, // x + CW_USEDEFAULT, // y + 0, 0, // nWidth, nHeight + NULL, NULL, // hWndParent, hMenu + hInstance, // hInstance + NULL // lpParam + ); + + PIXELFORMATDESCRIPTOR pfd = {}; + pfd.dwFlags = PFD_SUPPORT_OPENGL; + pfd.iPixelType = PFD_TYPE_RGBA; + pfd.iLayerType = PFD_MAIN_PLANE; + pfd.cColorBits = 32; + pfd.cDepthBits = 24; + pfd.cStencilBits = 8; + + HDC hdc = GetDC(hwnd); + int pixelformat = ChoosePixelFormat(hdc, &pfd); + SetPixelFormat(hdc, pixelformat, &pfd); + + HGLRC hglrc = wglCreateContext(hdc); + LOG(INFO) << std::hex << std::setfill('0') + << "WGL OpenGL context created (hdc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)hdc + << ", hglrc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)hglrc << ")"; + + GLContext glctx = {hdc, hglrc, 0}; + return glctx; +} + +void destroyGLContext(GLContext& glctx) +{ + if (!glctx.hglrc) + LOG(FATAL) << "destroyGLContext() called with null gltcx"; + + // If this is the current context, release it. + if (wglGetCurrentContext() == glctx.hglrc) + releaseGLContext(); + + HWND hwnd = WindowFromDC(glctx.hdc); + if (!hwnd) + LOG(FATAL) << "WindowFromDC() failed"; + if (!ReleaseDC(hwnd, glctx.hdc)) + LOG(FATAL) << "ReleaseDC() failed"; + if (!wglDeleteContext(glctx.hglrc)) + LOG(FATAL) << "wglDeleteContext() failed"; + if (!DestroyWindow(hwnd)) + LOG(FATAL) << "DestroyWindow() failed"; + + LOG(INFO) << std::hex << std::setfill('0') + << "WGL OpenGL context destroyed (hdc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)glctx.hdc + << ", hglrc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)glctx.hglrc << ")"; + + memset(&glctx, 0, sizeof(GLContext)); +} + +#endif // _WIN32 + +//------------------------------------------------------------------------ +// Linux. +//------------------------------------------------------------------------ + +#ifdef __linux__ + +static pthread_mutex_t s_getProcAddressMutex; + +typedef void (*PROCFN)(); + +static void safeGetProcAddress(const char* name, PROCFN* pfn) +{ + PROCFN result = eglGetProcAddress(name); + if (!result) + { + pthread_mutex_unlock(&s_getProcAddressMutex); // Prepare for thread exit. 
+ LOG(FATAL) << "wglGetProcAddress() failed for '" << name << "'"; + exit(1); // Should never get here but make sure we exit. + } + *pfn = result; +} + +static void initializeGLExtensions(void) +{ + pthread_mutex_lock(&s_getProcAddressMutex); + + // Only dig function pointers if not done already. + if (!s_glExtInitialized) + { + // Generate code to populate the function pointers. +#define GLUTIL_EXT(return_type, name, ...) safeGetProcAddress(#name, (PROCFN*)&name); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + + // Mark as initialized. + s_glExtInitialized = true; + } + + pthread_mutex_unlock(&s_getProcAddressMutex); + return; +} + +void setGLContext(GLContext& glctx) +{ + if (!glctx.context) + LOG(FATAL) << "setGLContext() called with null gltcx"; + + if (!eglMakeCurrent(glctx.display, EGL_NO_SURFACE, EGL_NO_SURFACE, glctx.context)) + LOG(ERROR) << "eglMakeCurrent() failed when setting GL context"; + + if (glctx.extInitialized) + return; + initializeGLExtensions(); + glctx.extInitialized = 1; +} + +void releaseGLContext(void) +{ + EGLDisplay display = eglGetCurrentDisplay(); + if (display == EGL_NO_DISPLAY) + LOG(WARNING) << "releaseGLContext() called with no active display"; + if (!eglMakeCurrent(display, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT)) + LOG(FATAL) << "eglMakeCurrent() failed when releasing GL context"; +} + +static EGLDisplay getCudaDisplay(int cudaDeviceIdx) +{ + typedef EGLBoolean (*eglQueryDevicesEXT_t)(EGLint, EGLDeviceEXT, EGLint*); + typedef EGLBoolean (*eglQueryDeviceAttribEXT_t)(EGLDeviceEXT, EGLint, EGLAttrib*); + typedef EGLDisplay (*eglGetPlatformDisplayEXT_t)(EGLenum, void*, const EGLint*); + + eglQueryDevicesEXT_t eglQueryDevicesEXT = (eglQueryDevicesEXT_t)eglGetProcAddress("eglQueryDevicesEXT"); + if (!eglQueryDevicesEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglQueryDevicesEXT\") failed"; + return 0; + } + + eglQueryDeviceAttribEXT_t eglQueryDeviceAttribEXT = (eglQueryDeviceAttribEXT_t)eglGetProcAddress("eglQueryDeviceAttribEXT"); + if (!eglQueryDeviceAttribEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglQueryDeviceAttribEXT\") failed"; + return 0; + } + + eglGetPlatformDisplayEXT_t eglGetPlatformDisplayEXT = (eglGetPlatformDisplayEXT_t)eglGetProcAddress("eglGetPlatformDisplayEXT"); + if (!eglGetPlatformDisplayEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglGetPlatformDisplayEXT\") failed"; + return 0; + } + + int num_devices = 0; + eglQueryDevicesEXT(0, 0, &num_devices); + if (!num_devices) + return 0; + + EGLDisplay display = 0; + EGLDeviceEXT* devices = (EGLDeviceEXT*)malloc(num_devices * sizeof(void*)); + eglQueryDevicesEXT(num_devices, devices, &num_devices); + for (int i=0; i < num_devices; i++) + { + EGLDeviceEXT device = devices[i]; + intptr_t value = -1; + if (eglQueryDeviceAttribEXT(device, EGL_CUDA_DEVICE_NV, &value) && value == cudaDeviceIdx) + { + display = eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, device, 0); + break; + } + } + + free(devices); + return display; +} + +GLContext createGLContext(int cudaDeviceIdx) +{ + EGLDisplay display = 0; + + if (cudaDeviceIdx >= 0) + { + char pciBusId[256] = ""; + LOG(INFO) << "Creating GL context for Cuda device " << cudaDeviceIdx; + display = getCudaDisplay(cudaDeviceIdx); + if (!display) + LOG(INFO) << "Failed, falling back to default display"; + } + + if (!display) + { + display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (display == EGL_NO_DISPLAY) + LOG(FATAL) << "eglGetDisplay() failed"; + } + + EGLint major; + EGLint minor; + if (!eglInitialize(display, &major, &minor)) + LOG(FATAL) 
<< "eglInitialize() failed"; + + // Choose configuration. + + const EGLint context_attribs[] = { + EGL_RED_SIZE, 8, + EGL_GREEN_SIZE, 8, + EGL_BLUE_SIZE, 8, + EGL_ALPHA_SIZE, 8, + EGL_DEPTH_SIZE, 24, + EGL_STENCIL_SIZE, 8, + EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, + EGL_NONE + }; + + EGLConfig config; + EGLint num_config; + if (!eglChooseConfig(display, context_attribs, &config, 1, &num_config)) + LOG(FATAL) << "eglChooseConfig() failed"; + + // Create GL context. + + if (!eglBindAPI(EGL_OPENGL_API)) + LOG(FATAL) << "eglBindAPI() failed"; + + EGLContext context = eglCreateContext(display, config, EGL_NO_CONTEXT, NULL); + if (context == EGL_NO_CONTEXT) + LOG(FATAL) << "eglCreateContext() failed"; + + // Done. + + LOG(INFO) << "EGL " << (int)minor << "." << (int)major << " OpenGL context created (disp: 0x" + << std::hex << std::setfill('0') + << std::setw(16) << (uintptr_t)display + << ", ctx: 0x" << std::setw(16) << (uintptr_t)context << ")"; + + GLContext glctx = {display, context, 0}; + return glctx; +} + +void destroyGLContext(GLContext& glctx) +{ + if (!glctx.context) + LOG(FATAL) << "destroyGLContext() called with null gltcx"; + + // If this is the current context, release it. + if (eglGetCurrentContext() == glctx.context) + releaseGLContext(); + + if (!eglDestroyContext(glctx.display, glctx.context)) + LOG(ERROR) << "eglDestroyContext() failed"; + + LOG(INFO) << "EGL OpenGL context destroyed (disp: 0x" + << std::hex << std::setfill('0') + << std::setw(16) << (uintptr_t)glctx.display + << ", ctx: 0x" << std::setw(16) << (uintptr_t)glctx.context << ")"; + + memset(&glctx, 0, sizeof(GLContext)); +} + +//------------------------------------------------------------------------ + +#endif // __linux__ + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil.h b/extensions/nvdiffrast/nvdiffrast/common/glutil.h new file mode 100644 index 0000000000000000000000000000000000000000..e9a3a7d95a5af4a808a25097cc055b699024409e --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil.h @@ -0,0 +1,113 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Windows-specific headers and types. +//------------------------------------------------------------------------ + +#ifdef _WIN32 +#define NOMINMAX +#include // Required by gl.h in Windows. +#define GLAPIENTRY APIENTRY + +struct GLContext +{ + HDC hdc; + HGLRC hglrc; + int extInitialized; +}; + +#endif // _WIN32 + +//------------------------------------------------------------------------ +// Linux-specific headers and types. +//------------------------------------------------------------------------ + +#ifdef __linux__ +#define EGL_NO_X11 // X11/Xlib.h has "#define Status int" which breaks Tensorflow. Avoid it. 
+#define MESA_EGL_NO_X11_HEADERS +#include +#include +#define GLAPIENTRY + +struct GLContext +{ + EGLDisplay display; + EGLContext context; + int extInitialized; +}; + +#endif // __linux__ + +//------------------------------------------------------------------------ +// OpenGL, CUDA interop, GL extensions. +//------------------------------------------------------------------------ +#define GL_GLEXT_LEGACY +#include +#include + +// Constants. +#ifndef GL_VERSION_1_2 +#define GL_CLAMP_TO_EDGE 0x812F +#define GL_TEXTURE_3D 0x806F +#endif +#ifndef GL_VERSION_1_5 +#define GL_ARRAY_BUFFER 0x8892 +#define GL_DYNAMIC_DRAW 0x88E8 +#define GL_ELEMENT_ARRAY_BUFFER 0x8893 +#endif +#ifndef GL_VERSION_2_0 +#define GL_FRAGMENT_SHADER 0x8B30 +#define GL_INFO_LOG_LENGTH 0x8B84 +#define GL_LINK_STATUS 0x8B82 +#define GL_VERTEX_SHADER 0x8B31 +#endif +#ifndef GL_VERSION_3_0 +#define GL_MAJOR_VERSION 0x821B +#define GL_MINOR_VERSION 0x821C +#define GL_RGBA32F 0x8814 +#define GL_TEXTURE_2D_ARRAY 0x8C1A +#endif +#ifndef GL_VERSION_3_2 +#define GL_GEOMETRY_SHADER 0x8DD9 +#endif +#ifndef GL_ARB_framebuffer_object +#define GL_COLOR_ATTACHMENT0 0x8CE0 +#define GL_COLOR_ATTACHMENT1 0x8CE1 +#define GL_DEPTH_STENCIL 0x84F9 +#define GL_DEPTH_STENCIL_ATTACHMENT 0x821A +#define GL_DEPTH24_STENCIL8 0x88F0 +#define GL_FRAMEBUFFER 0x8D40 +#define GL_INVALID_FRAMEBUFFER_OPERATION 0x0506 +#define GL_UNSIGNED_INT_24_8 0x84FA +#endif +#ifndef GL_ARB_imaging +#define GL_TABLE_TOO_LARGE 0x8031 +#endif +#ifndef GL_KHR_robustness +#define GL_CONTEXT_LOST 0x0507 +#endif + +// Declare function pointers to OpenGL extension functions. +#define GLUTIL_EXT(return_type, name, ...) extern return_type (GLAPIENTRY* name)(__VA_ARGS__); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + +//------------------------------------------------------------------------ +// Common functions. +//------------------------------------------------------------------------ + +void setGLContext (GLContext& glctx); +void releaseGLContext (void); +GLContext createGLContext (int cudaDeviceIdx); +void destroyGLContext (GLContext& glctx); +const char* getGLErrorString (GLenum err); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h b/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h new file mode 100644 index 0000000000000000000000000000000000000000..afa08f399ad59e635b055548aec04cc661e28485 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h @@ -0,0 +1,48 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
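+//------------------------------------------------------------------------
+// The list below is an X-macro: this header is included several times with
+// different definitions of GLUTIL_EXT(...): once in glutil.h to declare
+// extern function pointers, once in glutil.cpp to define them, and once
+// inside initializeGLExtensions() to resolve each one through
+// wglGetProcAddress()/eglGetProcAddress().  Sketch of the pattern as used
+// above:
+//
+//     #define GLUTIL_EXT(ret, name, ...) extern ret (GLAPIENTRY* name)(__VA_ARGS__);  // declare (glutil.h)
+//     #include "glutil_extlist.h"
+//     #undef GLUTIL_EXT
+//
+//     #define GLUTIL_EXT(ret, name, ...) ret (GLAPIENTRY* name)(__VA_ARGS__) = 0;     // define (glutil.cpp)
+//     #include "glutil_extlist.h"
+//     #undef GLUTIL_EXT
+//
+//     #define GLUTIL_EXT(ret, name, ...) safeGetProcAddress(#name, (PROCFN*)&name);   // resolve
+//     #include "glutil_extlist.h"
+//     #undef GLUTIL_EXT
+//
+// Adding a new GL entry point therefore only needs one GLUTIL_EXT line here.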
+ +#ifndef GL_VERSION_1_2 +GLUTIL_EXT(void, glTexImage3D, GLenum target, GLint level, GLint internalFormat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLenum format, GLenum type, const void *pixels); +#endif +#ifndef GL_VERSION_1_5 +GLUTIL_EXT(void, glBindBuffer, GLenum target, GLuint buffer); +GLUTIL_EXT(void, glBufferData, GLenum target, ptrdiff_t size, const void* data, GLenum usage); +GLUTIL_EXT(void, glGenBuffers, GLsizei n, GLuint* buffers); +#endif +#ifndef GL_VERSION_2_0 +GLUTIL_EXT(void, glAttachShader, GLuint program, GLuint shader); +GLUTIL_EXT(void, glCompileShader, GLuint shader); +GLUTIL_EXT(GLuint, glCreateProgram, void); +GLUTIL_EXT(GLuint, glCreateShader, GLenum type); +GLUTIL_EXT(void, glDrawBuffers, GLsizei n, const GLenum* bufs); +GLUTIL_EXT(void, glEnableVertexAttribArray, GLuint index); +GLUTIL_EXT(void, glGetProgramInfoLog, GLuint program, GLsizei bufSize, GLsizei* length, char* infoLog); +GLUTIL_EXT(void, glGetProgramiv, GLuint program, GLenum pname, GLint* param); +GLUTIL_EXT(void, glLinkProgram, GLuint program); +GLUTIL_EXT(void, glShaderSource, GLuint shader, GLsizei count, const char *const* string, const GLint* length); +GLUTIL_EXT(void, glUniform1f, GLint location, GLfloat v0); +GLUTIL_EXT(void, glUniform2f, GLint location, GLfloat v0, GLfloat v1); +GLUTIL_EXT(void, glUseProgram, GLuint program); +GLUTIL_EXT(void, glVertexAttribPointer, GLuint index, GLint size, GLenum type, GLboolean normalized, GLsizei stride, const void* pointer); +#endif +#ifndef GL_VERSION_3_2 +GLUTIL_EXT(void, glFramebufferTexture, GLenum target, GLenum attachment, GLuint texture, GLint level); +#endif +#ifndef GL_ARB_framebuffer_object +GLUTIL_EXT(void, glBindFramebuffer, GLenum target, GLuint framebuffer); +GLUTIL_EXT(void, glGenFramebuffers, GLsizei n, GLuint* framebuffers); +#endif +#ifndef GL_ARB_vertex_array_object +GLUTIL_EXT(void, glBindVertexArray, GLuint array); +GLUTIL_EXT(void, glGenVertexArrays, GLsizei n, GLuint* arrays); +#endif +#ifndef GL_ARB_multi_draw_indirect +GLUTIL_EXT(void, glMultiDrawElementsIndirect, GLenum mode, GLenum type, const void *indirect, GLsizei primcount, GLsizei stride); +#endif + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu new file mode 100644 index 0000000000000000000000000000000000000000..3bd2a7a7ab3111ae12f6cdce73906eeb9bbf6935 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu @@ -0,0 +1,276 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "common.h" +#include "interpolate.h" + +//------------------------------------------------------------------------ +// Forward kernel. + +template +static __forceinline__ __device__ void InterpolateFwdKernelTemplate(const InterpolateKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. 
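+    // Note: pidx linearizes (px, py, pz) with px varying fastest, so consecutive
+    // threads of a warp address consecutive pixels; the float4 read of p.rast
+    // below is therefore coalesced, and p.out / p.outDA are addressed with
+    // per-pixel strides of numAttr floats and numDiffAttr float2 values.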
+ int pidx = px + p.width * (py + p.height * pz); + + // Output ptrs. + float* out = p.out + pidx * p.numAttr; + float2* outDA = ENABLE_DA ? (((float2*)p.outDA) + pidx * p.numDiffAttr) : 0; + + // Fetch rasterizer output. + float4 r = ((float4*)p.rast)[pidx]; + int triIdx = float_to_triidx(r.w) - 1; + bool triValid = (triIdx >= 0 && triIdx < p.numTriangles); + + // If no geometry in entire warp, zero the output and exit. + // Otherwise force barys to zero and output with live threads. + if (__all_sync(0xffffffffu, !triValid)) + { + for (int i=0; i < p.numAttr; i++) + out[i] = 0.f; + if (ENABLE_DA) + for (int i=0; i < p.numDiffAttr; i++) + outDA[i] = make_float2(0.f, 0.f); + return; + } + + // Fetch vertex indices. + int vi0 = triValid ? p.tri[triIdx * 3 + 0] : 0; + int vi1 = triValid ? p.tri[triIdx * 3 + 1] : 0; + int vi2 = triValid ? p.tri[triIdx * 3 + 2] : 0; + + // Bail out if corrupt indices. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index unless broadcasting. + if (p.instance_mode && !p.attrBC) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Pointers to attributes. + const float* a0 = p.attr + vi0 * p.numAttr; + const float* a1 = p.attr + vi1 * p.numAttr; + const float* a2 = p.attr + vi2 * p.numAttr; + + // Barys. If no triangle, force all to zero -> output is zero. + float b0 = triValid ? r.x : 0.f; + float b1 = triValid ? r.y : 0.f; + float b2 = triValid ? (1.f - r.x - r.y) : 0.f; + + // Interpolate and write attributes. + for (int i=0; i < p.numAttr; i++) + out[i] = b0*a0[i] + b1*a1[i] + b2*a2[i]; + + // No diff attrs? Exit. + if (!ENABLE_DA) + return; + + // Read bary pixel differentials if we have a triangle. + float4 db = make_float4(0.f, 0.f, 0.f, 0.f); + if (triValid) + db = ((float4*)p.rastDB)[pidx]; + + // Unpack a bit. + float dudx = db.x; + float dudy = db.y; + float dvdx = db.z; + float dvdy = db.w; + + // Calculate the pixel differentials of chosen attributes. + for (int i=0; i < p.numDiffAttr; i++) + { + // Input attribute index. + int j = p.diff_attrs_all ? i : p.diffAttrs[i]; + if (j < 0) + j += p.numAttr; // Python-style negative indices. + + // Zero output if invalid index. + float dsdx = 0.f; + float dsdy = 0.f; + if (j >= 0 && j < p.numAttr) + { + float s0 = a0[j]; + float s1 = a1[j]; + float s2 = a2[j]; + float dsdu = s0 - s2; + float dsdv = s1 - s2; + dsdx = dudx*dsdu + dvdx*dsdv; + dsdy = dudy*dsdu + dvdy*dsdv; + } + + // Write. + outDA[i] = make_float2(dsdx, dsdy); + } +} + +// Template specializations. +__global__ void InterpolateFwdKernel (const InterpolateKernelParams p) { InterpolateFwdKernelTemplate(p); } +__global__ void InterpolateFwdKernelDa(const InterpolateKernelParams p) { InterpolateFwdKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient kernel. + +template +static __forceinline__ __device__ void InterpolateGradKernelTemplate(const InterpolateKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH * IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. + int pidx = px + p.width * (py + p.height * pz); + + // Fetch triangle ID. 
If none, output zero bary/db gradients and exit. + float4 r = ((float4*)p.rast)[pidx]; + int triIdx = float_to_triidx(r.w) - 1; + if (triIdx < 0 || triIdx >= p.numTriangles) + { + ((float4*)p.gradRaster)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + if (ENABLE_DA) + ((float4*)p.gradRasterDB)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + return; + } + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if corrupt indices. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index unless broadcasting. + if (p.instance_mode && !p.attrBC) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Initialize coalesced atomics. + CA_SET_GROUP(triIdx); + + // Pointers to inputs. + const float* a0 = p.attr + vi0 * p.numAttr; + const float* a1 = p.attr + vi1 * p.numAttr; + const float* a2 = p.attr + vi2 * p.numAttr; + const float* pdy = p.dy + pidx * p.numAttr; + + // Pointers to outputs. + float* ga0 = p.gradAttr + vi0 * p.numAttr; + float* ga1 = p.gradAttr + vi1 * p.numAttr; + float* ga2 = p.gradAttr + vi2 * p.numAttr; + + // Barys and bary gradient accumulators. + float b0 = r.x; + float b1 = r.y; + float b2 = 1.f - r.x - r.y; + float gb0 = 0.f; + float gb1 = 0.f; + + // Loop over attributes and accumulate attribute gradients. + for (int i=0; i < p.numAttr; i++) + { + float y = pdy[i]; + float s0 = a0[i]; + float s1 = a1[i]; + float s2 = a2[i]; + gb0 += y * (s0 - s2); + gb1 += y * (s1 - s2); + caAtomicAdd(ga0 + i, b0 * y); + caAtomicAdd(ga1 + i, b1 * y); + caAtomicAdd(ga2 + i, b2 * y); + } + + // Write the bary gradients. + ((float4*)p.gradRaster)[pidx] = make_float4(gb0, gb1, 0.f, 0.f); + + // If pixel differentials disabled, we're done. + if (!ENABLE_DA) + return; + + // Calculate gradients based on attribute pixel differentials. + const float2* dda = ((float2*)p.dda) + pidx * p.numDiffAttr; + float gdudx = 0.f; + float gdudy = 0.f; + float gdvdx = 0.f; + float gdvdy = 0.f; + + // Read bary pixel differentials. + float4 db = ((float4*)p.rastDB)[pidx]; + float dudx = db.x; + float dudy = db.y; + float dvdx = db.z; + float dvdy = db.w; + + for (int i=0; i < p.numDiffAttr; i++) + { + // Input attribute index. + int j = p.diff_attrs_all ? i : p.diffAttrs[i]; + if (j < 0) + j += p.numAttr; // Python-style negative indices. + + // Check that index is valid. + if (j >= 0 && j < p.numAttr) + { + float2 dsdxy = dda[i]; + float dsdx = dsdxy.x; + float dsdy = dsdxy.y; + + float s0 = a0[j]; + float s1 = a1[j]; + float s2 = a2[j]; + + // Gradients of db. + float dsdu = s0 - s2; + float dsdv = s1 - s2; + gdudx += dsdu * dsdx; + gdudy += dsdu * dsdy; + gdvdx += dsdv * dsdx; + gdvdy += dsdv * dsdy; + + // Gradients of attributes. + float du = dsdx*dudx + dsdy*dudy; + float dv = dsdx*dvdx + dsdy*dvdy; + caAtomicAdd(ga0 + j, du); + caAtomicAdd(ga1 + j, dv); + caAtomicAdd(ga2 + j, -du - dv); + } + } + + // Write. + ((float4*)p.gradRasterDB)[pidx] = make_float4(gdudx, gdudy, gdvdx, gdvdy); +} + +// Template specializations. 
+__global__ void InterpolateGradKernel (const InterpolateKernelParams p) { InterpolateGradKernelTemplate(p); } +__global__ void InterpolateGradKernelDa(const InterpolateKernelParams p) { InterpolateGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/interpolate.h b/extensions/nvdiffrast/nvdiffrast/common/interpolate.h new file mode 100644 index 0000000000000000000000000000000000000000..d35d8388240e97c255c837446609d8ae00cd78d9 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/interpolate.h @@ -0,0 +1,49 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Constants and helpers. + +#define IP_FWD_MAX_KERNEL_BLOCK_WIDTH 8 +#define IP_FWD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define IP_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define IP_MAX_DIFF_ATTRS 32 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct InterpolateKernelParams +{ + const int* tri; // Incoming triangle buffer. + const float* attr; // Incoming attribute buffer. + const float* rast; // Incoming rasterizer output buffer. + const float* rastDB; // Incoming rasterizer output buffer for bary derivatives. + const float* dy; // Incoming attribute gradients. + const float* dda; // Incoming attr diff gradients. + float* out; // Outgoing interpolated attributes. + float* outDA; // Outgoing texcoord major axis lengths. + float* gradAttr; // Outgoing attribute gradients. + float* gradRaster; // Outgoing rasterizer gradients. + float* gradRasterDB; // Outgoing rasterizer bary diff gradients. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int numAttr; // Number of total vertex attributes. + int numDiffAttr; // Number of attributes to differentiate. + int width; // Image width. + int height; // Image height. + int depth; // Minibatch size. + int attrBC; // 0=normal, 1=attr is broadcast. + int instance_mode; // 0=normal, 1=instance mode. + int diff_attrs_all; // 0=normal, 1=produce pixel differentials for all attributes. + int diffAttrs[IP_MAX_DIFF_ATTRS]; // List of attributes to differentiate. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu b/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu new file mode 100644 index 0000000000000000000000000000000000000000..455aca3e09064d1fbe25b406ff958ad8efb4dffe --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu @@ -0,0 +1,276 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
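+// Illustrative reference (editorial, not part of the library): the edge-function
+// barycentric computation that RasterizeCudaFwdShaderKernel below performs per
+// pixel, written as a scalar CPU helper. v0..v2 are clip-space vertices (x,y,z,w)
+// and (fx, fy) is the pixel center in clip space (the kernel derives it as
+// xs*px + xo, ys*py + yo). The helper name and the output struct are invented
+// for the sketch.
+struct BaryRef { float b0, b1, zw; };
+static inline BaryRef referenceBarycentrics(const float v0[4], const float v1[4], const float v2[4], float fx, float fy)
+{
+    // Shift the vertices so the pixel's ray becomes the origin in homogeneous 2D.
+    float p0x = v0[0] - fx * v0[3], p0y = v0[1] - fy * v0[3];
+    float p1x = v1[0] - fx * v1[3], p1y = v1[1] - fy * v1[3];
+    float p2x = v2[0] - fx * v2[3], p2y = v2[1] - fy * v2[3];
+
+    // Signed areas of the sub-triangles opposite each vertex (edge functions).
+    float a0 = p1x * p2y - p1y * p2x;
+    float a1 = p2x * p0y - p2y * p0x;
+    float a2 = p0x * p1y - p0y * p1x;
+
+    // Normalizing by the total area gives perspective-correct barycentrics; z/w is
+    // the depth value the kernel writes to the output and the depth buffer.
+    float iw = 1.f / (a0 + a1 + a2);
+    BaryRef r;
+    r.b0 = a0 * iw;
+    r.b1 = a1 * iw;
+    r.zw = (v0[2] * a0 + v1[2] * a1 + v2[2] * a2) / (v0[3] * a0 + v1[3] * a1 + v2[3] * a2);
+    return r;
+}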
+ +#include "common.h" +#include "rasterize.h" + +//------------------------------------------------------------------------ +// Cuda forward rasterizer pixel shader kernel. + +__global__ void RasterizeCudaFwdShaderKernel(const RasterizeCudaFwdShaderParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width_out || py >= p.height_out || pz >= p.depth) + return; + + // Pixel indices. + int pidx_in = px + p.width_in * (py + p.height_in * pz); + int pidx_out = px + p.width_out * (py + p.height_out * pz); + + // Fetch triangle idx. + int triIdx = p.in_idx[pidx_in] - 1; + if (triIdx < 0 || triIdx >= p.numTriangles) + { + // No or corrupt triangle. + ((float4*)p.out)[pidx_out] = make_float4(0.0, 0.0, 0.0, 0.0); // Clear out. + ((float4*)p.out_db)[pidx_out] = make_float4(0.0, 0.0, 0.0, 0.0); // Clear out_db. + return; + } + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index. + if (p.instance_mode) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Evaluate edge functions. + float fx = p.xs * (float)px + p.xo; + float fy = p.ys * (float)py + p.yo; + float p0x = p0.x - fx * p0.w; + float p0y = p0.y - fy * p0.w; + float p1x = p1.x - fx * p1.w; + float p1y = p1.y - fy * p1.w; + float p2x = p2.x - fx * p2.w; + float p2y = p2.y - fy * p2.w; + float a0 = p1x*p2y - p1y*p2x; + float a1 = p2x*p0y - p2y*p0x; + float a2 = p0x*p1y - p0y*p1x; + + // Perspective correct, normalized barycentrics. + float iw = 1.f / (a0 + a1 + a2); + float b0 = a0 * iw; + float b1 = a1 * iw; + + // Compute z/w for depth buffer. + float z = p0.z * a0 + p1.z * a1 + p2.z * a2; + float w = p0.w * a0 + p1.w * a1 + p2.w * a2; + float zw = z / w; + + // Clamps to avoid NaNs. + b0 = __saturatef(b0); // Clamp to [+0.0, 1.0]. + b1 = __saturatef(b1); // Clamp to [+0.0, 1.0]. + zw = fmaxf(fminf(zw, 1.f), -1.f); + + // Emit output. + ((float4*)p.out)[pidx_out] = make_float4(b0, b1, zw, triidx_to_float(triIdx + 1)); + + // Calculate bary pixel differentials. + float dfxdx = p.xs * iw; + float dfydy = p.ys * iw; + float da0dx = p2.y*p1.w - p1.y*p2.w; + float da0dy = p1.x*p2.w - p2.x*p1.w; + float da1dx = p0.y*p2.w - p2.y*p0.w; + float da1dy = p2.x*p0.w - p0.x*p2.w; + float da2dx = p1.y*p0.w - p0.y*p1.w; + float da2dy = p0.x*p1.w - p1.x*p0.w; + float datdx = da0dx + da1dx + da2dx; + float datdy = da0dy + da1dy + da2dy; + float dudx = dfxdx * (b0 * datdx - da0dx); + float dudy = dfydy * (b0 * datdy - da0dy); + float dvdx = dfxdx * (b1 * datdx - da1dx); + float dvdy = dfydy * (b1 * datdy - da1dy); + + // Emit bary pixel differentials. + ((float4*)p.out_db)[pidx_out] = make_float4(dudx, dudy, dvdx, dvdy); +} + +//------------------------------------------------------------------------ +// Gradient Cuda kernel. + +template +static __forceinline__ __device__ void RasterizeGradKernelTemplate(const RasterizeGradParams p) +{ + // Temporary space for coalesced atomics. 
+ CA_DECLARE_TEMP(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH * RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. + int pidx = px + p.width * (py + p.height * pz); + + // Read triangle idx and dy. + float2 dy = ((float2*)p.dy)[pidx * 2]; + float4 ddb = ENABLE_DB ? ((float4*)p.ddb)[pidx] : make_float4(0.f, 0.f, 0.f, 0.f); + int triIdx = float_to_triidx(((float*)p.out)[pidx * 4 + 3]) - 1; + + // Exit if nothing to do. + if (triIdx < 0 || triIdx >= p.numTriangles) + return; // No or corrupt triangle. + int grad_all_dy = __float_as_int(dy.x) | __float_as_int(dy.y); // Bitwise OR of all incoming gradients. + int grad_all_ddb = 0; + if (ENABLE_DB) + grad_all_ddb = __float_as_int(ddb.x) | __float_as_int(ddb.y) | __float_as_int(ddb.z) | __float_as_int(ddb.w); + if (((grad_all_dy | grad_all_ddb) << 1) == 0) + return; // All incoming gradients are +0/-0. + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index. + if (p.instance_mode) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Initialize coalesced atomics. + CA_SET_GROUP(triIdx); + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Evaluate edge functions. + float fx = p.xs * (float)px + p.xo; + float fy = p.ys * (float)py + p.yo; + float p0x = p0.x - fx * p0.w; + float p0y = p0.y - fy * p0.w; + float p1x = p1.x - fx * p1.w; + float p1y = p1.y - fy * p1.w; + float p2x = p2.x - fx * p2.w; + float p2y = p2.y - fy * p2.w; + float a0 = p1x*p2y - p1y*p2x; + float a1 = p2x*p0y - p2y*p0x; + float a2 = p0x*p1y - p0y*p1x; + + // Compute inverse area with epsilon. + float at = a0 + a1 + a2; + float ep = copysignf(1e-6f, at); // ~1 pixel in 1k x 1k image. + float iw = 1.f / (at + ep); + + // Perspective correct, normalized barycentrics. + float b0 = a0 * iw; + float b1 = a1 * iw; + + // Position gradients. + float gb0 = dy.x * iw; + float gb1 = dy.y * iw; + float gbb = gb0 * b0 + gb1 * b1; + float gp0x = gbb * (p2y - p1y) - gb1 * p2y; + float gp1x = gbb * (p0y - p2y) + gb0 * p2y; + float gp2x = gbb * (p1y - p0y) - gb0 * p1y + gb1 * p0y; + float gp0y = gbb * (p1x - p2x) + gb1 * p2x; + float gp1y = gbb * (p2x - p0x) - gb0 * p2x; + float gp2y = gbb * (p0x - p1x) + gb0 * p1x - gb1 * p0x; + float gp0w = -fx * gp0x - fy * gp0y; + float gp1w = -fx * gp1x - fy * gp1y; + float gp2w = -fx * gp2x - fy * gp2y; + + // Bary differential gradients. 
+ if (ENABLE_DB && ((grad_all_ddb) << 1) != 0) + { + float dfxdX = p.xs * iw; + float dfydY = p.ys * iw; + ddb.x *= dfxdX; + ddb.y *= dfydY; + ddb.z *= dfxdX; + ddb.w *= dfydY; + + float da0dX = p1.y * p2.w - p2.y * p1.w; + float da1dX = p2.y * p0.w - p0.y * p2.w; + float da2dX = p0.y * p1.w - p1.y * p0.w; + float da0dY = p2.x * p1.w - p1.x * p2.w; + float da1dY = p0.x * p2.w - p2.x * p0.w; + float da2dY = p1.x * p0.w - p0.x * p1.w; + float datdX = da0dX + da1dX + da2dX; + float datdY = da0dY + da1dY + da2dY; + + float x01 = p0.x - p1.x; + float x12 = p1.x - p2.x; + float x20 = p2.x - p0.x; + float y01 = p0.y - p1.y; + float y12 = p1.y - p2.y; + float y20 = p2.y - p0.y; + float w01 = p0.w - p1.w; + float w12 = p1.w - p2.w; + float w20 = p2.w - p0.w; + + float a0p1 = fy * p2.x - fx * p2.y; + float a0p2 = fx * p1.y - fy * p1.x; + float a1p0 = fx * p2.y - fy * p2.x; + float a1p2 = fy * p0.x - fx * p0.y; + + float wdudX = 2.f * b0 * datdX - da0dX; + float wdudY = 2.f * b0 * datdY - da0dY; + float wdvdX = 2.f * b1 * datdX - da1dX; + float wdvdY = 2.f * b1 * datdY - da1dY; + + float c0 = iw * (ddb.x * wdudX + ddb.y * wdudY + ddb.z * wdvdX + ddb.w * wdvdY); + float cx = c0 * fx - ddb.x * b0 - ddb.z * b1; + float cy = c0 * fy - ddb.y * b0 - ddb.w * b1; + float cxy = iw * (ddb.x * datdX + ddb.y * datdY); + float czw = iw * (ddb.z * datdX + ddb.w * datdY); + + gp0x += c0 * y12 - cy * w12 + czw * p2y + ddb.w * p2.w; + gp1x += c0 * y20 - cy * w20 - cxy * p2y - ddb.y * p2.w; + gp2x += c0 * y01 - cy * w01 + cxy * p1y - czw * p0y + ddb.y * p1.w - ddb.w * p0.w; + gp0y += cx * w12 - c0 * x12 - czw * p2x - ddb.z * p2.w; + gp1y += cx * w20 - c0 * x20 + cxy * p2x + ddb.x * p2.w; + gp2y += cx * w01 - c0 * x01 - cxy * p1x + czw * p0x - ddb.x * p1.w + ddb.z * p0.w; + gp0w += cy * x12 - cx * y12 - czw * a1p0 + ddb.z * p2.y - ddb.w * p2.x; + gp1w += cy * x20 - cx * y20 - cxy * a0p1 - ddb.x * p2.y + ddb.y * p2.x; + gp2w += cy * x01 - cx * y01 - cxy * a0p2 - czw * a1p2 + ddb.x * p1.y - ddb.y * p1.x - ddb.z * p0.y + ddb.w * p0.x; + } + + // Accumulate using coalesced atomics. + caAtomicAdd3_xyw(p.grad + 4 * vi0, gp0x, gp0y, gp0w); + caAtomicAdd3_xyw(p.grad + 4 * vi1, gp1x, gp1y, gp1w); + caAtomicAdd3_xyw(p.grad + 4 * vi2, gp2x, gp2y, gp2w); +} + +// Template specializations. +__global__ void RasterizeGradKernel (const RasterizeGradParams p) { RasterizeGradKernelTemplate(p); } +__global__ void RasterizeGradKernelDb(const RasterizeGradParams p) { RasterizeGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize.h b/extensions/nvdiffrast/nvdiffrast/common/rasterize.h new file mode 100644 index 0000000000000000000000000000000000000000..cb3104fae0e533e6da134e01c6020f70effb4964 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Constants and helpers. 
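+// Editorial note: the block-size constants below bound the 2D thread block of the
+// kernels in rasterize.cu, whose (px, py, pz) indexing implies a grid of
+// ceil(width / block.x) x ceil(height / block.y) x depth. A host-side launch
+// would look roughly like the sketch below; the actual launch code lives in the
+// framework bindings, not in this diff.
+//
+//     dim3 block(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH, RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT, 1);
+//     dim3 grid((p.width  + block.x - 1) / block.x,
+//               (p.height + block.y - 1) / block.y,
+//               p.depth);
+//     RasterizeGradKernel<<<grid, block, 0, stream>>>(p);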
+ +#define RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_WIDTH 8 +#define RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_HEIGHT 8 +#define RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 + +//------------------------------------------------------------------------ +// CUDA forward rasterizer shader kernel params. + +struct RasterizeCudaFwdShaderParams +{ + const float* pos; // Vertex positions. + const int* tri; // Triangle indices. + const int* in_idx; // Triangle idx buffer from rasterizer. + float* out; // Main output buffer. + float* out_db; // Bary pixel gradient output buffer. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width_in; // Input image width. + int height_in; // Input image height. + int width_out; // Output image width. + int height_out; // Output image height. + int depth; // Size of minibatch. + int instance_mode; // 1 if in instance rendering mode. + float xs, xo, ys, yo; // Pixel position to clip-space x, y transform. +}; + +//------------------------------------------------------------------------ +// Gradient CUDA kernel params. + +struct RasterizeGradParams +{ + const float* pos; // Incoming position buffer. + const int* tri; // Incoming triangle buffer. + const float* out; // Rasterizer output buffer. + const float* dy; // Incoming gradients of rasterizer output buffer. + const float* ddb; // Incoming gradients of bary diff output buffer. + float* grad; // Outgoing position gradients. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width; // Image width. + int height; // Image height. + int depth; // Size of minibatch. + int instance_mode; // 1 if in instance rendering mode. + float xs, xo, ys, yo; // Pixel position to clip-space x, y transform. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac71ccd8eb91740c4c8cacc21cb9fb00f452403c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp @@ -0,0 +1,644 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "rasterize_gl.h" +#include "glutil.h" +#include +#define STRINGIFY_SHADER_SOURCE(x) #x + +//------------------------------------------------------------------------ +// Helpers. + +#define ROUND_UP(x, y) ((((x) + ((y) - 1)) / (y)) * (y)) +static int ROUND_UP_BITS(uint32_t x, uint32_t y) +{ + // Round x up so that it has at most y bits of mantissa. + if (x < (1u << y)) + return x; + uint32_t m = 0; + while (x & ~m) + m = (m << 1) | 1u; + m >>= y; + if (!(x & m)) + return x; + return (x | m) + 1u; +} + +//------------------------------------------------------------------------ +// Draw command struct used by rasterizer. + +struct GLDrawCmd +{ + uint32_t count; + uint32_t instanceCount; + uint32_t firstIndex; + uint32_t baseVertex; + uint32_t baseInstance; +}; + +//------------------------------------------------------------------------ +// GL helpers. 
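+// Worked example for ROUND_UP_BITS above (editorial): with y = 2 the result keeps
+// at most the two leading bits set, so admissible sizes are of the form 2^n or
+// 3*2^(n-1); e.g. ROUND_UP_BITS(96, 2) == 96, ROUND_UP_BITS(97, 2) == 128 and
+// ROUND_UP_BITS(100, 2) == 128. rasterizeResizeBuffers() below therefore grows the
+// position and triangle buffers geometrically (at most ~1.5x per step), so GL
+// reallocation and CUDA re-registration happen only O(log N) times as inputs grow.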
+ +static void compileGLShader(NVDR_CTX_ARGS, const RasterizeGLState& s, GLuint* pShader, GLenum shaderType, const char* src_buf) +{ + std::string src(src_buf); + + // Set preprocessor directives. + int n = src.find('\n') + 1; // After first line containing #version directive. + if (s.enableZModify) + src.insert(n, "#define IF_ZMODIFY(x) x\n"); + else + src.insert(n, "#define IF_ZMODIFY(x)\n"); + + const char *cstr = src.c_str(); + *pShader = 0; + NVDR_CHECK_GL_ERROR(*pShader = glCreateShader(shaderType)); + NVDR_CHECK_GL_ERROR(glShaderSource(*pShader, 1, &cstr, 0)); + NVDR_CHECK_GL_ERROR(glCompileShader(*pShader)); +} + +static void constructGLProgram(NVDR_CTX_ARGS, GLuint* pProgram, GLuint glVertexShader, GLuint glGeometryShader, GLuint glFragmentShader) +{ + *pProgram = 0; + + GLuint glProgram = 0; + NVDR_CHECK_GL_ERROR(glProgram = glCreateProgram()); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glVertexShader)); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glGeometryShader)); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glFragmentShader)); + NVDR_CHECK_GL_ERROR(glLinkProgram(glProgram)); + + GLint linkStatus = 0; + NVDR_CHECK_GL_ERROR(glGetProgramiv(glProgram, GL_LINK_STATUS, &linkStatus)); + if (!linkStatus) + { + GLint infoLen = 0; + NVDR_CHECK_GL_ERROR(glGetProgramiv(glProgram, GL_INFO_LOG_LENGTH, &infoLen)); + if (infoLen) + { + const char* hdr = "glLinkProgram() failed:\n"; + std::vector info(strlen(hdr) + infoLen); + strcpy(&info[0], hdr); + NVDR_CHECK_GL_ERROR(glGetProgramInfoLog(glProgram, infoLen, &infoLen, &info[strlen(hdr)])); + NVDR_CHECK(0, &info[0]); + } + NVDR_CHECK(0, "glLinkProgram() failed"); + } + + *pProgram = glProgram; +} + +//------------------------------------------------------------------------ +// Shared C++ functions. + +void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceIdx) +{ + // Create GL context and set it current. + s.glctx = createGLContext(cudaDeviceIdx); + setGLContext(s.glctx); + + // Version check. + GLint vMajor = 0; + GLint vMinor = 0; + glGetIntegerv(GL_MAJOR_VERSION, &vMajor); + glGetIntegerv(GL_MINOR_VERSION, &vMinor); + glGetError(); // Clear possible GL_INVALID_ENUM error in version query. + LOG(INFO) << "OpenGL version reported as " << vMajor << "." << vMinor; + NVDR_CHECK((vMajor == 4 && vMinor >= 4) || vMajor > 4, "OpenGL 4.4 or later is required"); + + // Enable depth modification workaround on A100 and later. + int capMajor = 0; + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&capMajor, cudaDevAttrComputeCapabilityMajor, cudaDeviceIdx)); + s.enableZModify = (capMajor >= 8); + + // Number of output buffers. + int num_outputs = s.enableDB ? 2 : 1; + + // Set up vertex shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glVertexShader, GL_VERTEX_SHADER, + "#version 330\n" + "#extension GL_ARB_shader_draw_parameters : enable\n" + STRINGIFY_SHADER_SOURCE( + layout(location = 0) in vec4 in_pos; + out int v_layer; + out int v_offset; + void main() + { + int layer = gl_DrawIDARB; + gl_Position = in_pos; + v_layer = layer; + v_offset = gl_BaseInstanceARB; // Sneak in TriID offset here. + } + ) + ); + + // Geometry and fragment shaders depend on if bary differential output is enabled or not. + if (s.enableDB) + { + // Set up geometry shader. Calculation of per-pixel bary differentials is based on: + // u = (u/w) / (1/w) + // --> du/dX = d((u/w) / (1/w))/dX + // --> du/dX = [d(u/w)/dX - u*d(1/w)/dX] * w + // and we know both d(u/w)/dX and d(1/w)/dX are constant over triangle. 
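+    // Spelling the step out: with a = u/w and b = 1/w we have u = a/b, so the
+    // quotient rule gives du/dX = (da/dX * b - a * db/dX) / b^2
+    //                           = (da/dX - u * db/dX) * w,
+    // i.e. only the per-triangle constants d(u/w)/dX, d(v/w)/dX, d(1/w)/dX and the
+    // fragment's own w are needed; the geometry shader below bakes the constants
+    // into var_db and the fragment shader applies the final scale by var_uvzw.w.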
+ compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + layout(triangles) in; + layout(triangle_strip, max_vertices=3) out; + layout(location = 0) uniform vec2 vp_scale; + in int v_layer[]; + in int v_offset[]; + out vec4 var_uvzw; + out vec4 var_db; + void main() + { + // Plane equations for bary differentials. + float w0 = gl_in[0].gl_Position.w; + float w1 = gl_in[1].gl_Position.w; + float w2 = gl_in[2].gl_Position.w; + vec2 p0 = gl_in[0].gl_Position.xy; + vec2 p1 = gl_in[1].gl_Position.xy; + vec2 p2 = gl_in[2].gl_Position.xy; + vec2 e0 = p0*w2 - p2*w0; + vec2 e1 = p1*w2 - p2*w1; + float a = e0.x*e1.y - e0.y*e1.x; + + // Clamp area to an epsilon to avoid arbitrarily high bary differentials. + float eps = 1e-6f; // ~1 pixel in 1k x 1k image. + float ca = (abs(a) >= eps) ? a : (a < 0.f) ? -eps : eps; // Clamp with sign. + float ia = 1.f / ca; // Inverse area. + + vec2 ascl = ia * vp_scale; + float dudx = e1.y * ascl.x; + float dudy = -e1.x * ascl.y; + float dvdx = -e0.y * ascl.x; + float dvdy = e0.x * ascl.y; + + float duwdx = w2 * dudx; + float dvwdx = w2 * dvdx; + float duvdx = w0 * dudx + w1 * dvdx; + float duwdy = w2 * dudy; + float dvwdy = w2 * dvdy; + float duvdy = w0 * dudy + w1 * dvdy; + + vec4 db0 = vec4(duvdx - dvwdx, duvdy - dvwdy, dvwdx, dvwdy); + vec4 db1 = vec4(duwdx, duwdy, duvdx - duwdx, duvdy - duwdy); + vec4 db2 = vec4(duwdx, duwdy, dvwdx, dvwdy); + + int layer_id = v_layer[0]; + int prim_id = gl_PrimitiveIDIn + v_offset[0]; + + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[0].gl_Position.x, gl_in[0].gl_Position.y, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_uvzw = vec4(1.f, 0.f, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_db = db0; EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[1].gl_Position.x, gl_in[1].gl_Position.y, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_uvzw = vec4(0.f, 1.f, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_db = db1; EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[2].gl_Position.x, gl_in[2].gl_Position.y, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_uvzw = vec4(0.f, 0.f, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_db = db2; EmitVertex(); + } + ) + ); + + // Set up fragment shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + in vec4 var_db; + layout(location = 0) out vec4 out_raster; + layout(location = 1) out vec4 out_db; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + out_db = var_db * var_uvzw.w; + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + + // Set up fragment shader for depth peeling. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + in vec4 var_db; + layout(binding = 0) uniform sampler2DArray out_prev; + layout(location = 0) out vec4 out_raster; + layout(location = 1) out vec4 out_db; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? 
float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0); + float depth_new = var_uvzw.z / var_uvzw.w; + if (prev.w == 0 || depth_new <= prev.z) + discard; + out_raster = vec4(var_uvzw.x, var_uvzw.y, depth_new, id_float); + out_db = var_db * var_uvzw.w; + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + } + else + { + // Geometry shader without bary differential output. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER, + "#version 330\n" + STRINGIFY_SHADER_SOURCE( + layout(triangles) in; + layout(triangle_strip, max_vertices=3) out; + in int v_layer[]; + in int v_offset[]; + out vec4 var_uvzw; + void main() + { + int layer_id = v_layer[0]; + int prim_id = gl_PrimitiveIDIn + v_offset[0]; + + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[0].gl_Position.x, gl_in[0].gl_Position.y, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_uvzw = vec4(1.f, 0.f, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[1].gl_Position.x, gl_in[1].gl_Position.y, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_uvzw = vec4(0.f, 1.f, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[2].gl_Position.x, gl_in[2].gl_Position.y, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_uvzw = vec4(0.f, 0.f, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); EmitVertex(); + } + ) + ); + + // Fragment shader without bary differential output. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + layout(location = 0) out vec4 out_raster; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + + // Depth peeling variant of fragment shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + layout(binding = 0) uniform sampler2DArray out_prev; + layout(location = 0) out vec4 out_raster; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0); + float depth_new = var_uvzw.z / var_uvzw.w; + if (prev.w == 0 || depth_new <= prev.z) + discard; + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + } + + // Finalize programs. + constructGLProgram(NVDR_CTX_PARAMS, &s.glProgram, s.glVertexShader, s.glGeometryShader, s.glFragmentShader); + constructGLProgram(NVDR_CTX_PARAMS, &s.glProgramDP, s.glVertexShader, s.glGeometryShader, s.glFragmentShaderDP); + + // Construct main fbo and bind permanently. + NVDR_CHECK_GL_ERROR(glGenFramebuffers(1, &s.glFBO)); + NVDR_CHECK_GL_ERROR(glBindFramebuffer(GL_FRAMEBUFFER, s.glFBO)); + + // Enable two color attachments. 
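+    // Attachment 0 receives the main rasterizer output (u, v, z/w, triangle id)
+    // written to out_raster by the fragment shaders above; attachment 1 is used
+    // only when enableDB is set and carries the w-scaled bary differentials from
+    // out_db, which is why glDrawBuffers() below is passed num_outputs.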
+ GLenum draw_buffers[2] = { GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1 }; + NVDR_CHECK_GL_ERROR(glDrawBuffers(num_outputs, draw_buffers)); + + // Construct vertex array object. + NVDR_CHECK_GL_ERROR(glGenVertexArrays(1, &s.glVAO)); + NVDR_CHECK_GL_ERROR(glBindVertexArray(s.glVAO)); + + // Construct position buffer, bind permanently, enable, set ptr. + NVDR_CHECK_GL_ERROR(glGenBuffers(1, &s.glPosBuffer)); + NVDR_CHECK_GL_ERROR(glBindBuffer(GL_ARRAY_BUFFER, s.glPosBuffer)); + NVDR_CHECK_GL_ERROR(glEnableVertexAttribArray(0)); + NVDR_CHECK_GL_ERROR(glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, 0, 0)); + + // Construct index buffer and bind permanently. + NVDR_CHECK_GL_ERROR(glGenBuffers(1, &s.glTriBuffer)); + NVDR_CHECK_GL_ERROR(glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, s.glTriBuffer)); + + // Set up depth test. + NVDR_CHECK_GL_ERROR(glEnable(GL_DEPTH_TEST)); + NVDR_CHECK_GL_ERROR(glDepthFunc(GL_LESS)); + NVDR_CHECK_GL_ERROR(glClearDepth(1.0)); + + // Create and bind output buffers. Storage is allocated later. + NVDR_CHECK_GL_ERROR(glGenTextures(num_outputs, s.glColorBuffer)); + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[i])); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0 + i, s.glColorBuffer[i], 0)); + } + + // Create and bind depth/stencil buffer. Storage is allocated later. + NVDR_CHECK_GL_ERROR(glGenTextures(1, &s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_DEPTH_STENCIL_ATTACHMENT, s.glDepthStencilBuffer, 0)); + + // Create texture name for previous output buffer (depth peeling). + NVDR_CHECK_GL_ERROR(glGenTextures(1, &s.glPrevOutBuffer)); +} + +void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, bool& changes, int posCount, int triCount, int width, int height, int depth) +{ + changes = false; + + // Resize vertex buffer? + if (posCount > s.posCount) + { + if (s.cudaPosBuffer) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPosBuffer)); + s.posCount = (posCount > 64) ? ROUND_UP_BITS(posCount, 2) : 64; + LOG(INFO) << "Increasing position buffer size to " << s.posCount << " float32"; + NVDR_CHECK_GL_ERROR(glBufferData(GL_ARRAY_BUFFER, s.posCount * sizeof(float), NULL, GL_DYNAMIC_DRAW)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaPosBuffer, s.glPosBuffer, cudaGraphicsRegisterFlagsWriteDiscard)); + changes = true; + } + + // Resize triangle buffer? + if (triCount > s.triCount) + { + if (s.cudaTriBuffer) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaTriBuffer)); + s.triCount = (triCount > 64) ? ROUND_UP_BITS(triCount, 2) : 64; + LOG(INFO) << "Increasing triangle buffer size to " << s.triCount << " int32"; + NVDR_CHECK_GL_ERROR(glBufferData(GL_ELEMENT_ARRAY_BUFFER, s.triCount * sizeof(int32_t), NULL, GL_DYNAMIC_DRAW)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaTriBuffer, s.glTriBuffer, cudaGraphicsRegisterFlagsWriteDiscard)); + changes = true; + } + + // Resize framebuffer? + if (width > s.width || height > s.height || depth > s.depth) + { + int num_outputs = s.enableDB ? 2 : 1; + if (s.cudaColorBuffer[0]) + for (int i=0; i < num_outputs; i++) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaColorBuffer[i])); + + if (s.cudaPrevOutBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPrevOutBuffer)); + s.cudaPrevOutBuffer = 0; + } + + // New framebuffer size. 
+ s.width = (width > s.width) ? width : s.width; + s.height = (height > s.height) ? height : s.height; + s.depth = (depth > s.depth) ? depth : s.depth; + s.width = ROUND_UP(s.width, 32); + s.height = ROUND_UP(s.height, 32); + LOG(INFO) << "Increasing frame buffer size to (width, height, depth) = (" << s.width << ", " << s.height << ", " << s.depth << ")"; + + // Allocate color buffers. + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[i])); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_RGBA32F, s.width, s.height, s.depth, 0, GL_RGBA, GL_UNSIGNED_BYTE, 0)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + } + + // Allocate depth/stencil buffer. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_DEPTH24_STENCIL8, s.width, s.height, s.depth, 0, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, 0)); + + // (Re-)register all GL buffers into Cuda. + for (int i=0; i < num_outputs; i++) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(&s.cudaColorBuffer[i], s.glColorBuffer[i], GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly)); + + changes = true; + } +} + +void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx) +{ + // Only copy inputs if we are on first iteration of depth peeling or not doing it at all. + if (peeling_idx < 1) + { + if (triPtr) + { + // Copy both position and triangle buffers. + void* glPosPtr = NULL; + void* glTriPtr = NULL; + size_t posBytes = 0; + size_t triBytes = 0; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(2, &s.cudaPosBuffer, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glPosPtr, &posBytes, s.cudaPosBuffer)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glTriPtr, &triBytes, s.cudaTriBuffer)); + NVDR_CHECK(posBytes >= posCount * sizeof(float), "mapped GL position buffer size mismatch"); + NVDR_CHECK(triBytes >= triCount * sizeof(int32_t), "mapped GL triangle buffer size mismatch"); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glPosPtr, posPtr, posCount * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glTriPtr, triPtr, triCount * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(2, &s.cudaPosBuffer, stream)); + } + else + { + // Copy position buffer only. Triangles are already copied and known to be constant. 
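+            // (The two-resource map in the branch above relies on cudaPosBuffer and
+            // cudaTriBuffer being adjacent fields of RasterizeGLState, so
+            // &s.cudaPosBuffer can be passed as a two-element array; in this branch
+            // only the position buffer needs to be mapped.)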
+ void* glPosPtr = NULL; + size_t posBytes = 0; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(1, &s.cudaPosBuffer, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glPosPtr, &posBytes, s.cudaPosBuffer)); + NVDR_CHECK(posBytes >= posCount * sizeof(float), "mapped GL position buffer size mismatch"); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glPosPtr, posPtr, posCount * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(1, &s.cudaPosBuffer, stream)); + } + } + + // Select program based on whether we have a depth peeling input or not. + if (peeling_idx < 1) + { + // Normal case: No peeling, or peeling disabled. + NVDR_CHECK_GL_ERROR(glUseProgram(s.glProgram)); + } + else + { + // If we don't have a third buffer yet, create one. + if (!s.cudaPrevOutBuffer) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glPrevOutBuffer)); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_RGBA32F, s.width, s.height, s.depth, 0, GL_RGBA, GL_UNSIGNED_BYTE, 0)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(&s.cudaPrevOutBuffer, s.glPrevOutBuffer, GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly)); + } + + // Swap the GL buffers. + GLuint glTempBuffer = s.glPrevOutBuffer; + s.glPrevOutBuffer = s.glColorBuffer[0]; + s.glColorBuffer[0] = glTempBuffer; + + // Swap the Cuda buffers. + cudaGraphicsResource_t cudaTempBuffer = s.cudaPrevOutBuffer; + s.cudaPrevOutBuffer = s.cudaColorBuffer[0]; + s.cudaColorBuffer[0] = cudaTempBuffer; + + // Bind the new output buffer. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[0])); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, s.glColorBuffer[0], 0)); + + // Bind old buffer as the input texture. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glPrevOutBuffer)); + + // Activate the correct program. + NVDR_CHECK_GL_ERROR(glUseProgram(s.glProgramDP)); + } + + // Set viewport, clear color buffer(s) and depth/stencil buffer. + NVDR_CHECK_GL_ERROR(glViewport(0, 0, width, height)); + NVDR_CHECK_GL_ERROR(glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT | GL_STENCIL_BUFFER_BIT)); + + // If outputting bary differentials, set resolution uniform + if (s.enableDB) + NVDR_CHECK_GL_ERROR(glUniform2f(0, 2.f / (float)width, 2.f / (float)height)); + + // Set the dummy uniform if depth modification workaround is active. + if (s.enableZModify) + NVDR_CHECK_GL_ERROR(glUniform1f(1, 0.f)); + + // Render the meshes. + if (depth == 1 && !rangesPtr) + { + // Trivial case. + NVDR_CHECK_GL_ERROR(glDrawElements(GL_TRIANGLES, triCount, GL_UNSIGNED_INT, 0)); + } + else + { + // Populate a buffer for draw commands and execute it. + std::vector drawCmdBuffer(depth); + + if (!rangesPtr) + { + // Fill in range array to instantiate the same triangles for each output layer. + // Triangle IDs starts at zero (i.e., one) for each layer, so they correspond to + // the first dimension in addressing the triangle array. 
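+            // Each layer draws from its own vertex block via baseVertex =
+            // vtxPerInstance * i while baseInstance stays 0; the vertex shader
+            // forwards gl_BaseInstanceARB as v_offset and the geometry shader adds
+            // it to gl_PrimitiveIDIn, so triangle IDs restart at 1 in every layer
+            // here, whereas the ranges path below sets baseInstance = first to
+            // offset the reported IDs into the full triangle buffer.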
+ for (int i=0; i < depth; i++) + { + GLDrawCmd& cmd = drawCmdBuffer[i]; + cmd.firstIndex = 0; + cmd.count = triCount; + cmd.baseVertex = vtxPerInstance * i; + cmd.baseInstance = 0; + cmd.instanceCount = 1; + } + } + else + { + // Fill in the range array according to user-given ranges. Triangle IDs point + // to the input triangle array, NOT index within range, so they correspond to + // the first dimension in addressing the triangle array. + for (int i=0, j=0; i < depth; i++) + { + GLDrawCmd& cmd = drawCmdBuffer[i]; + int first = rangesPtr[j++]; + int count = rangesPtr[j++]; + NVDR_CHECK(first >= 0 && count >= 0, "range contains negative values"); + NVDR_CHECK((first + count) * 3 <= triCount, "range extends beyond end of triangle buffer"); + cmd.firstIndex = first * 3; + cmd.count = count * 3; + cmd.baseVertex = 0; + cmd.baseInstance = first; + cmd.instanceCount = 1; + } + } + + // Draw! + NVDR_CHECK_GL_ERROR(glMultiDrawElementsIndirect(GL_TRIANGLES, GL_UNSIGNED_INT, &drawCmdBuffer[0], depth, sizeof(GLDrawCmd))); + } +} + +void rasterizeCopyResults(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, float** outputPtr, int width, int height, int depth) +{ + // Copy color buffers to output tensors. + cudaArray_t array = 0; + cudaChannelFormatDesc arrayDesc = {}; // For error checking. + cudaExtent arrayExt = {}; // For error checking. + int num_outputs = s.enableDB ? 2 : 1; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(num_outputs, s.cudaColorBuffer, stream)); + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsSubResourceGetMappedArray(&array, s.cudaColorBuffer[i], 0, 0)); + NVDR_CHECK_CUDA_ERROR(cudaArrayGetInfo(&arrayDesc, &arrayExt, NULL, array)); + NVDR_CHECK(arrayDesc.f == cudaChannelFormatKindFloat, "CUDA mapped array data kind mismatch"); + NVDR_CHECK(arrayDesc.x == 32 && arrayDesc.y == 32 && arrayDesc.z == 32 && arrayDesc.w == 32, "CUDA mapped array data width mismatch"); + NVDR_CHECK(arrayExt.width >= width && arrayExt.height >= height && arrayExt.depth >= depth, "CUDA mapped array extent mismatch"); + cudaMemcpy3DParms p = {0}; + p.srcArray = array; + p.dstPtr.ptr = outputPtr[i]; + p.dstPtr.pitch = width * 4 * sizeof(float); + p.dstPtr.xsize = width; + p.dstPtr.ysize = height; + p.extent.width = width; + p.extent.height = height; + p.extent.depth = depth; + p.kind = cudaMemcpyDeviceToDevice; + NVDR_CHECK_CUDA_ERROR(cudaMemcpy3DAsync(&p, stream)); + } + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(num_outputs, s.cudaColorBuffer, stream)); +} + +void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s) +{ + int num_outputs = s.enableDB ? 
2 : 1; + + if (s.cudaPosBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPosBuffer)); + s.cudaPosBuffer = 0; + } + + if (s.cudaTriBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaTriBuffer)); + s.cudaTriBuffer = 0; + } + + for (int i=0; i < num_outputs; i++) + { + if (s.cudaColorBuffer[i]) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaColorBuffer[i])); + s.cudaColorBuffer[i] = 0; + } + } + + if (s.cudaPrevOutBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPrevOutBuffer)); + s.cudaPrevOutBuffer = 0; + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h new file mode 100644 index 0000000000000000000000000000000000000000..27537c5624286af9c2cba9dc908f400abc9ddfdf --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Do not try to include OpenGL stuff when compiling CUDA kernels for torch. + +#if !(defined(NVDR_TORCH) && defined(__CUDACC__)) +#include "framework.h" +#include "glutil.h" + +//------------------------------------------------------------------------ +// OpenGL-related persistent state for forward op. + +struct RasterizeGLState // Must be initializable by memset to zero. +{ + int width; // Allocated frame buffer width. + int height; // Allocated frame buffer height. + int depth; // Allocated frame buffer depth. + int posCount; // Allocated position buffer in floats. + int triCount; // Allocated triangle buffer in ints. + GLContext glctx; + GLuint glFBO; + GLuint glColorBuffer[2]; + GLuint glPrevOutBuffer; + GLuint glDepthStencilBuffer; + GLuint glVAO; + GLuint glTriBuffer; + GLuint glPosBuffer; + GLuint glProgram; + GLuint glProgramDP; + GLuint glVertexShader; + GLuint glGeometryShader; + GLuint glFragmentShader; + GLuint glFragmentShaderDP; + cudaGraphicsResource_t cudaColorBuffer[2]; + cudaGraphicsResource_t cudaPrevOutBuffer; + cudaGraphicsResource_t cudaPosBuffer; + cudaGraphicsResource_t cudaTriBuffer; + int enableDB; + int enableZModify; // Modify depth in shader, workaround for a rasterization issue on A100. +}; + +//------------------------------------------------------------------------ +// Shared C++ code prototypes. 
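+// Editorial usage sketch: a typical host-side sequence over the API below. The
+// NVDR context/error arguments (elided as ...) and the real calling code live in
+// the framework bindings, so the exact invocation may differ.
+//
+//     RasterizeGLState s;
+//     memset(&s, 0, sizeof(RasterizeGLState));   // must be zero-initializable, see above
+//     s.enableDB = 1;
+//     rasterizeInitGLContext(..., s, cudaDeviceIdx);
+//     bool changes = false;
+//     rasterizeResizeBuffers(..., s, changes, posCount, triCount, width, height, depth);
+//     rasterizeRender(..., s, stream, posPtr, posCount, vtxPerInstance,
+//                     triPtr, triCount, rangesPtr, width, height, depth,
+//                     -1 /* peeling_idx < 1 selects the non-peeling path */);
+//     rasterizeCopyResults(..., s, stream, outputPtrs, width, height, depth);
+//     rasterizeReleaseBuffers(..., s);
+//     destroyGLContext(s.glctx);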
+ +void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceIdx); +void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, bool& changes, int posCount, int triCount, int width, int height, int depth); +void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx); +void rasterizeCopyResults(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, float** outputPtr, int width, int height, int depth); +void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s); + +//------------------------------------------------------------------------ +#endif // !(defined(NVDR_TORCH) && defined(__CUDACC__)) diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture.cpp b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp new file mode 100644 index 0000000000000000000000000000000000000000..51633e10120b4dc465e5283241a38c95db31f8dc --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "framework.h" +#include "texture.h" + +//------------------------------------------------------------------------ +// Mip stack construction and access helpers. + +void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p) +{ + char buf[1024]; + int bufsz = 1024; + + std::string msg = "Mip-map size error - cannot downsample an odd extent greater than 1. Resize the texture so that both spatial extents are powers of two, or limit the number of mip maps using max_mip_level argument.\n"; + + int w = p.texWidth; + int h = p.texHeight; + bool ew = false; + bool eh = false; + + msg += "Attempted mip stack construction:\n"; + msg += "level width height\n"; + msg += "----- ----- ------\n"; + snprintf(buf, bufsz, "base %5d %5d\n", w, h); + msg += buf; + + int mipTotal = 0; + int level = 0; + while ((w|h) > 1 && !(ew || eh)) // Stop at first impossible size. + { + // Current level. + level += 1; + + // Determine if downsampling fails. + ew = ew || (w > 1 && (w & 1)); + eh = eh || (h > 1 && (h & 1)); + + // Downsample. + if (w > 1) w >>= 1; + if (h > 1) h >>= 1; + + // Append level size to error message. + snprintf(buf, bufsz, "mip %-2d ", level); + msg += buf; + if (ew) snprintf(buf, bufsz, " err "); + else snprintf(buf, bufsz, "%5d ", w); + msg += buf; + if (eh) snprintf(buf, bufsz, " err\n"); + else snprintf(buf, bufsz, "%5d\n", h); + msg += buf; + } + + NVDR_CHECK(0, msg); +} + +int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets) +{ + // No levels at all? + if (p.mipLevelLimit == 0) + { + p.mipLevelMax = 0; + return 0; + } + + // Current level size. + int w = p.texWidth; + int h = p.texHeight; + + int mipTotal = 0; + int level = 0; + int c = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE) ? (p.channels * 6) : p.channels; + mipOffsets[0] = 0; + while ((w|h) > 1) + { + // Current level. + level += 1; + + // Quit if cannot downsample. 
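+        // Worked example (editorial): for a 16x8 texture with texDepth = 1 and a
+        // non-cube boundary mode, the loop records mipOffsets[1..4] = 0, 32c, 40c,
+        // 42c (in floats, c = channels) for levels of size 8x4, 4x2, 2x1 and 1x1,
+        // and returns mipTotal = 43c with mipLevelMax = 4. The base level is not
+        // counted in mipTotal, which is why the first mip level starts at offset 0.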
+ if ((w > 1 && (w & 1)) || (h > 1 && (h & 1))) + raiseMipSizeError(NVDR_CTX_PARAMS, p); + + // Downsample. + if (w > 1) w >>= 1; + if (h > 1) h >>= 1; + + mipOffsets[level] = mipTotal; // Store the mip offset (#floats). + mipTotal += w * h * p.texDepth * c; + + // Hit the level limit? + if (p.mipLevelLimit >= 0 && level == p.mipLevelLimit) + break; + } + + p.mipLevelMax = level; + return mipTotal; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture.h b/extensions/nvdiffrast/nvdiffrast/common/texture.h new file mode 100644 index 0000000000000000000000000000000000000000..f79b600fff0256cdadd38e265b49366549434ef8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture.h @@ -0,0 +1,78 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "framework.h" + +//------------------------------------------------------------------------ +// Constants. + +#define TEX_DEBUG_MIP_RETAIN_VARIANCE 0 // For debugging +#define TEX_FWD_MAX_KERNEL_BLOCK_WIDTH 8 +#define TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH 8 +#define TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT 8 +#define TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH 8 +#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT 8 +#define TEX_MAX_MIP_LEVEL 16 // Currently a texture cannot be larger than 2 GB because we use 32-bit indices everywhere. +#define TEX_MODE_NEAREST 0 // Nearest on base level. +#define TEX_MODE_LINEAR 1 // Bilinear on base level. +#define TEX_MODE_LINEAR_MIPMAP_NEAREST 2 // Bilinear on nearest mip level. +#define TEX_MODE_LINEAR_MIPMAP_LINEAR 3 // Trilinear. +#define TEX_MODE_COUNT 4 +#define TEX_BOUNDARY_MODE_CUBE 0 // Cube map mode. +#define TEX_BOUNDARY_MODE_WRAP 1 // Wrap (u, v). +#define TEX_BOUNDARY_MODE_CLAMP 2 // Clamp (u, v). +#define TEX_BOUNDARY_MODE_ZERO 3 // Pad with zeros. +#define TEX_BOUNDARY_MODE_COUNT 4 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct TextureKernelParams +{ + const float* tex[TEX_MAX_MIP_LEVEL]; // Incoming texture buffer with mip levels. + const float* uv; // Incoming texcoord buffer. + const float* uvDA; // Incoming uv pixel diffs or NULL. + const float* mipLevelBias; // Incoming mip level bias or NULL. + const float* dy; // Incoming output gradient. + float* out; // Outgoing texture data. + float* gradTex[TEX_MAX_MIP_LEVEL]; // Outgoing texture gradients with mip levels. + float* gradUV; // Outgoing texcoord gradient. + float* gradUVDA; // Outgoing texcoord pixel differential gradient. + float* gradMipLevelBias; // Outgoing mip level bias gradient. + int enableMip; // If true, we have uv_da and/or mip_level_bias input(s), and a mip tensor. + int filterMode; // One of the TEX_MODE_ constants. + int boundaryMode; // One of the TEX_BOUNDARY_MODE_ contants. + int texConst; // If true, texture is known to be constant. + int mipLevelLimit; // Mip level limit coming from the op. + int channels; // Number of texture channels. 
+ int imgWidth; // Image width. + int imgHeight; // Image height. + int texWidth; // Texture width. + int texHeight; // Texture height. + int texDepth; // Texture depth. + int n; // Minibatch size. + int mipLevelMax; // Maximum mip level index. Zero if mips disabled. + int mipLevelOut; // Mip level being calculated in builder kernel. +}; + +//------------------------------------------------------------------------ +// C++ helper function prototypes. + +void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p); +int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets); + +//------------------------------------------------------------------------ +// Macros. + +#define mipLevelSize(p, i) make_int2(((p).texWidth >> (i)) > 1 ? ((p).texWidth >> (i)) : 1, ((p).texHeight >> (i)) > 1 ? ((p).texHeight >> (i)) : 1) + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture_.cu b/extensions/nvdiffrast/nvdiffrast/common/texture_.cu new file mode 100644 index 0000000000000000000000000000000000000000..490b8d68dd62398e05086843f138bd7f3510f449 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture_.cu @@ -0,0 +1,1156 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "common.h" +#include "texture.h" + +//------------------------------------------------------------------------ +// Memory access and math helpers. + +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float b, float c) { a[0] += b * c; } +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float2 b, float c) { a[0] += b.x * c; a[s] += b.y * c; } +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float4 b, float c) { a[0] += b.x * c; a[s] += b.y * c; a[2*s] += b.z * c; a[3*s] += b.w * c; } +static __device__ __forceinline__ void accum_to_mem(float& a, float* b, int s) { a += b[0]; } +static __device__ __forceinline__ void accum_to_mem(float2& a, float* b, int s) { float2 v = a; v.x += b[0]; v.y += b[s]; a = v; } +static __device__ __forceinline__ void accum_to_mem(float4& a, float* b, int s) { float4 v = a; v.x += b[0]; v.y += b[s]; v.z += b[2*s]; v.w += b[3*s]; a = v; } +static __device__ __forceinline__ bool isfinite_vec3(const float3& a) { return isfinite(a.x) && isfinite(a.y) && isfinite(a.z); } +static __device__ __forceinline__ bool isfinite_vec4(const float4& a) { return isfinite(a.x) && isfinite(a.y) && isfinite(a.z) && isfinite(a.w); } +template static __device__ __forceinline__ T lerp (const T& a, const T& b, float c) { return a + c * (b - a); } +template static __device__ __forceinline__ T bilerp(const T& a, const T& b, const T& c, const T& d, const float2& e) { return lerp(lerp(a, b, e.x), lerp(c, d, e.x), e.y); } + +//------------------------------------------------------------------------ +// Cube map wrapping for smooth filtering across edges and corners. At corners, +// one of the texture coordinates will be negative. For correct interpolation, +// the missing texel must take the average color of the other three. 
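+// Concretely, a bilinear tap at a corner touches only three physical texels across the
+// adjacent faces; fetchQuad() below synthesizes the missing tap as the mean of the other
+// three, and accumQuad() scatters its gradient back the same way, so filtering stays
+// continuous across the corner.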
+ +static __constant__ uint32_t c_cubeWrapMask1[48] = +{ + 0x1530a440, 0x1133a550, 0x6103a110, 0x1515aa44, 0x6161aa11, 0x40154a04, 0x44115a05, 0x04611a01, + 0x2630a440, 0x2233a550, 0x5203a110, 0x2626aa44, 0x5252aa11, 0x40264a04, 0x44225a05, 0x04521a01, + 0x32608064, 0x3366a055, 0x13062091, 0x32328866, 0x13132299, 0x50320846, 0x55330a55, 0x05130219, + 0x42508064, 0x4455a055, 0x14052091, 0x42428866, 0x14142299, 0x60420846, 0x66440a55, 0x06140219, + 0x5230a044, 0x5533a055, 0x1503a011, 0x5252aa44, 0x1515aa11, 0x40520a44, 0x44550a55, 0x04150a11, + 0x6130a044, 0x6633a055, 0x2603a011, 0x6161aa44, 0x2626aa11, 0x40610a44, 0x44660a55, 0x04260a11, +}; + +static __constant__ uint8_t c_cubeWrapMask2[48] = +{ + 0x26, 0x33, 0x11, 0x05, 0x00, 0x09, 0x0c, 0x04, 0x04, 0x00, 0x00, 0x05, 0x00, 0x81, 0xc0, 0x40, + 0x02, 0x03, 0x09, 0x00, 0x0a, 0x00, 0x00, 0x02, 0x64, 0x30, 0x90, 0x55, 0xa0, 0x99, 0xcc, 0x64, + 0x24, 0x30, 0x10, 0x05, 0x00, 0x01, 0x00, 0x00, 0x06, 0x03, 0x01, 0x05, 0x00, 0x89, 0xcc, 0x44, +}; + +static __device__ __forceinline__ int4 wrapCubeMap(int face, int ix0, int ix1, int iy0, int iy1, int w) +{ + // Calculate case number. + int cx = (ix0 < 0) ? 0 : (ix1 >= w) ? 2 : 1; + int cy = (iy0 < 0) ? 0 : (iy1 >= w) ? 6 : 3; + int c = cx + cy; + if (c >= 5) + c--; + c = (face << 3) + c; + + // Compute coordinates and faces. + unsigned int m = c_cubeWrapMask1[c]; + int x0 = (m >> 0) & 3; x0 = (x0 == 0) ? 0 : (x0 == 1) ? ix0 : iy0; + int x1 = (m >> 2) & 3; x1 = (x1 == 0) ? 0 : (x1 == 1) ? ix1 : iy0; + int x2 = (m >> 4) & 3; x2 = (x2 == 0) ? 0 : (x2 == 1) ? ix0 : iy1; + int x3 = (m >> 6) & 3; x3 = (x3 == 0) ? 0 : (x3 == 1) ? ix1 : iy1; + int y0 = (m >> 8) & 3; y0 = (y0 == 0) ? 0 : (y0 == 1) ? ix0 : iy0; + int y1 = (m >> 10) & 3; y1 = (y1 == 0) ? 0 : (y1 == 1) ? ix1 : iy0; + int y2 = (m >> 12) & 3; y2 = (y2 == 0) ? 0 : (y2 == 1) ? ix0 : iy1; + int y3 = (m >> 14) & 3; y3 = (y3 == 0) ? 0 : (y3 == 1) ? ix1 : iy1; + int f0 = ((m >> 16) & 15) - 1; + int f1 = ((m >> 20) & 15) - 1; + int f2 = ((m >> 24) & 15) - 1; + int f3 = ((m >> 28) ) - 1; + + // Flips. + unsigned int f = c_cubeWrapMask2[c]; + int w1 = w - 1; + if (f & 0x01) x0 = w1 - x0; + if (f & 0x02) x1 = w1 - x1; + if (f & 0x04) x2 = w1 - x2; + if (f & 0x08) x3 = w1 - x3; + if (f & 0x10) y0 = w1 - y0; + if (f & 0x20) y1 = w1 - y1; + if (f & 0x40) y2 = w1 - y2; + if (f & 0x80) y3 = w1 - y3; + + // Done. + int4 tcOut; + tcOut.x = x0 + (y0 + f0 * w) * w; + tcOut.y = x1 + (y1 + f1 * w) * w; + tcOut.z = x2 + (y2 + f2 * w) * w; + tcOut.w = x3 + (y3 + f3 * w) * w; + return tcOut; +} + +//------------------------------------------------------------------------ +// Cube map indexing and gradient functions. + +// Map a 3D lookup vector into an (s,t) face coordinates (returned in first . +// two parameters) and face index. +static __device__ __forceinline__ int indexCubeMap(float& x, float& y, float z) +{ + float ax = fabsf(x); + float ay = fabsf(y); + float az = fabsf(z); + int idx; + float c; + if (az > fmaxf(ax, ay)) { idx = 4; c = z; } + else if (ay > ax) { idx = 2; c = y; y = z; } + else { idx = 0; c = x; x = z; } + if (c < 0.f) idx += 1; + float m = __frcp_rz(fabsf(c)) * .5; + float m0 = __uint_as_float(__float_as_uint(m) ^ ((0x21u >> idx) << 31)); + float m1 = (idx != 2) ? -m : m; + x = x * m0 + .5; + y = y * m1 + .5; + if (!isfinite(x) || !isfinite(y)) + return -1; // Invalid uv. + x = fminf(fmaxf(x, 0.f), 1.f); + y = fminf(fmaxf(y, 0.f), 1.f); + return idx; +} + +// Based on dA/d{s,t}, compute dA/d{x,y,z} at a given 3D lookup vector. 
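+// This is the chain rule through the face projection (s,t) = f(x,y,z): the incoming
+// gradients gu = dA/ds and gv = dA/dt are mapped back onto the 3D lookup vector uv.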
+static __device__ __forceinline__ float3 indexCubeMapGrad(float3 uv, float gu, float gv) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c; + float c0 = gu; + float c1 = gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; c0 *= uv.x; c1 *= uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; c0 *= uv.x; c1 *= uv.z; } + else { idx = 0x01; c = uv.x; c0 *= uv.z; c1 *= uv.y; } + if (c < 0.f) idx += idx; + float m = __frcp_rz(fabsf(c)); + c0 = (idx & 0x34) ? -c0 : c0; + c1 = (idx & 0x2e) ? -c1 : c1; + float gl = (c0 + c1) * m; + float gx = (idx & 0x03) ? gl : (idx & 0x20) ? -gu : gu; + float gy = (idx & 0x0c) ? gl : -gv; + float gz = (idx & 0x30) ? gl : (idx & 0x03) ? gu : gv; + gz = (idx & 0x09) ? -gz : gz; + float3 res = make_float3(gx, gy, gz) * (m * .5f); + if (!isfinite_vec3(res)) + return make_float3(0.f, 0.f, 0.f); // Invalid uv. + return res; +} + +// Based on dL/d(d{s,t}/s{X,Y}), compute dL/d(d{x,y,z}/d{X,Y}). This is just two +// indexCubeMapGrad() functions rolled together. +static __device__ __forceinline__ void indexCubeMapGrad4(float3 uv, float4 dw, float3& g0, float3& g1) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, c0, c1; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; c0 = uv.x; c1 = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; c0 = uv.x; c1 = uv.z; } + else { idx = 0x01; c = uv.x; c0 = uv.z; c1 = uv.y; } + if (c < 0.f) idx += idx; + float m = __frcp_rz(fabsf(c)); + c0 = (idx & 0x34) ? -c0 : c0; + c1 = (idx & 0x2e) ? -c1 : c1; + float gl0 = (dw.x * c0 + dw.z * c1) * m; + float gl1 = (dw.y * c0 + dw.w * c1) * m; + float gx0 = (idx & 0x03) ? gl0 : (idx & 0x20) ? -dw.x : dw.x; + float gx1 = (idx & 0x03) ? gl1 : (idx & 0x20) ? -dw.y : dw.y; + float gy0 = (idx & 0x0c) ? gl0 : -dw.z; + float gy1 = (idx & 0x0c) ? gl1 : -dw.w; + float gz0 = (idx & 0x30) ? gl0 : (idx & 0x03) ? dw.x : dw.z; + float gz1 = (idx & 0x30) ? gl1 : (idx & 0x03) ? dw.y : dw.w; + if (idx & 0x09) + { + gz0 = -gz0; + gz1 = -gz1; + } + g0 = make_float3(gx0, gy0, gz0) * (m * .5f); + g1 = make_float3(gx1, gy1, gz1) * (m * .5f); + if (!isfinite_vec3(g0) || !isfinite_vec3(g1)) + { + g0 = make_float3(0.f, 0.f, 0.f); // Invalid uv. + g1 = make_float3(0.f, 0.f, 0.f); + } +} + +// Compute d{s,t}/d{X,Y} based on d{x,y,z}/d{X,Y} at a given 3D lookup vector. +// Result is (ds/dX, ds/dY, dt/dX, dt/dY). +static __device__ __forceinline__ float4 indexCubeMapGradST(float3 uv, float3 dvdX, float3 dvdY) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, gu, gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; gu = uv.x; gv = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; gu = uv.x; gv = uv.z; } + else { idx = 0x01; c = uv.x; gu = uv.z; gv = uv.y; } + if (c < 0.f) idx += idx; + if (idx & 0x09) + { + dvdX.z = -dvdX.z; + dvdY.z = -dvdY.z; + } + float m = __frcp_rz(fabsf(c)); + float dm = m * .5f; + float mm = m * dm; + gu *= (idx & 0x34) ? -mm : mm; + gv *= (idx & 0x2e) ? 
-mm : mm; + + float4 res; + if (idx & 0x03) + { + res = make_float4(gu * dvdX.x + dm * dvdX.z, + gu * dvdY.x + dm * dvdY.z, + gv * dvdX.x - dm * dvdX.y, + gv * dvdY.x - dm * dvdY.y); + } + else if (idx & 0x0c) + { + res = make_float4(gu * dvdX.y + dm * dvdX.x, + gu * dvdY.y + dm * dvdY.x, + gv * dvdX.y + dm * dvdX.z, + gv * dvdY.y + dm * dvdY.z); + } + else // (idx & 0x30) + { + res = make_float4(gu * dvdX.z + copysignf(dm, c) * dvdX.x, + gu * dvdY.z + copysignf(dm, c) * dvdY.x, + gv * dvdX.z - dm * dvdX.y, + gv * dvdY.z - dm * dvdY.y); + } + + if (!isfinite_vec4(res)) + return make_float4(0.f, 0.f, 0.f, 0.f); + + return res; +} + +// Compute d(d{s,t}/d{X,Y})/d{x,y,z}, i.e., how the pixel derivatives of 2D face +// coordinates change w.r.t. 3D texture coordinate vector, returned as follows: +// | d(ds/dX)/dx d(ds/dY)/dx d(dt/dX)/dx d(dt/dY)/dx | +// | d(ds/dX)/dy d(ds/dY)/dy d(dt/dX)/dy d(dt/dY)/dy | +// | d(ds/dX)/dz d(ds/dY)/dz d(dt/dX)/dz d(dt/dY)/dz | +static __device__ __forceinline__ void indexCubeMapGrad2(float3 uv, float3 dvdX, float3 dvdY, float4& dx, float4& dy, float4& dz) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, gu, gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; gu = uv.x; gv = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; gu = uv.x; gv = uv.z; } + else { idx = 0x01; c = uv.x; gu = uv.z; gv = uv.y; } + if (c < 0.f) idx += idx; + + if (idx & 0x09) + { + dvdX.z = -dvdX.z; + dvdY.z = -dvdY.z; + } + + float m = __frcp_rz(c); + float dm = -m * fabsf(m) * .5; + float mm = m * m * .5; + float mu = (idx & 0x34) ? -mm : mm; + float mv = (idx & 0x2e) ? -mm : mm; + gu *= -2.0 * m * mu; + gv *= -2.0 * m * mv; + + if (idx & 0x03) + { + dx.x = gu * dvdX.x + dm * dvdX.z; + dx.y = gu * dvdY.x + dm * dvdY.z; + dx.z = gv * dvdX.x - dm * dvdX.y; + dx.w = gv * dvdY.x - dm * dvdY.y; + dy.x = 0.f; + dy.y = 0.f; + dy.z = mv * dvdX.x; + dy.w = mv * dvdY.x; + dz.x = mu * dvdX.x; + dz.y = mu * dvdY.x; + dz.z = 0.f; + dz.w = 0.f; + } + else if (idx & 0x0c) + { + dx.x = mu * dvdX.y; + dx.y = mu * dvdY.y; + dx.z = 0.f; + dx.w = 0.f; + dy.x = gu * dvdX.y + dm * dvdX.x; + dy.y = gu * dvdY.y + dm * dvdY.x; + dy.z = gv * dvdX.y + dm * dvdX.z; + dy.w = gv * dvdY.y + dm * dvdY.z; + dz.x = 0.f; + dz.y = 0.f; + dz.z = mv * dvdX.y; + dz.w = mv * dvdY.y; + } + else // (idx & 0x30) + { + dx.x = mu * dvdX.z; + dx.y = mu * dvdY.z; + dx.z = 0.f; + dx.w = 0.f; + dy.x = 0.f; + dy.y = 0.f; + dy.z = mv * dvdX.z; + dy.w = mv * dvdY.z; + dz.x = gu * dvdX.z - fabsf(dm) * dvdX.x; + dz.y = gu * dvdY.z - fabsf(dm) * dvdY.x; + dz.z = gv * dvdX.z - dm * dvdX.y; + dz.w = gv * dvdY.z - dm * dvdY.y; + } +} + +//------------------------------------------------------------------------ +// General texture indexing. + +template +static __device__ __forceinline__ int indexTextureNearest(const TextureKernelParams& p, float3 uv, int tz) +{ + int w = p.texWidth; + int h = p.texHeight; + float u = uv.x; + float v = uv.y; + + // Cube map indexing. + if (CUBE_MODE) + { + // No wrap. Fold face index into tz right away. + int idx = indexCubeMap(u, v, uv.z); // Rewrites u, v. + if (idx < 0) + return -1; // Invalid uv. + tz = 6 * tz + idx; + } + else + { + // Handle boundary. + if (p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + u = u - (float)__float2int_rd(u); + v = v - (float)__float2int_rd(v); + } + } + + u = u * (float)w; + v = v * (float)h; + + int iu = __float2int_rd(u); + int iv = __float2int_rd(v); + + // In zero boundary mode, return texture address -1. 
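+    // (Callers treat a negative address as "no texel": the forward kernel writes zeros
+    // and the gradient kernel skips accumulation for that pixel.)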
+ if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_ZERO) + { + if (iu < 0 || iu >= w || iv < 0 || iv >= h) + return -1; + } + + // Otherwise clamp and calculate the coordinate properly. + iu = min(max(iu, 0), w-1); + iv = min(max(iv, 0), h-1); + return iu + w * (iv + tz * h); +} + +template +static __device__ __forceinline__ float2 indexTextureLinear(const TextureKernelParams& p, float3 uv, int tz, int4& tcOut, int level) +{ + // Mip level size. + int2 sz = mipLevelSize(p, level); + int w = sz.x; + int h = sz.y; + + // Compute texture-space u, v. + float u = uv.x; + float v = uv.y; + bool clampU = false; + bool clampV = false; + + // Cube map indexing. + int face = 0; + if (CUBE_MODE) + { + // Neither clamp or wrap. + face = indexCubeMap(u, v, uv.z); // Rewrites u, v. + if (face < 0) + { + tcOut.x = tcOut.y = tcOut.z = tcOut.w = -1; // Invalid uv. + return make_float2(0.f, 0.f); + } + u = u * (float)w - 0.5f; + v = v * (float)h - 0.5f; + } + else + { + if (p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + // Wrap. + u = u - (float)__float2int_rd(u); + v = v - (float)__float2int_rd(v); + } + + // Move to texel space. + u = u * (float)w - 0.5f; + v = v * (float)h - 0.5f; + + if (p.boundaryMode == TEX_BOUNDARY_MODE_CLAMP) + { + // Clamp to center of edge texels. + u = fminf(fmaxf(u, 0.f), w - 1.f); + v = fminf(fmaxf(v, 0.f), h - 1.f); + clampU = (u == 0.f || u == w - 1.f); + clampV = (v == 0.f || v == h - 1.f); + } + } + + // Compute texel coordinates and weights. + int iu0 = __float2int_rd(u); + int iv0 = __float2int_rd(v); + int iu1 = iu0 + (clampU ? 0 : 1); // Ensure zero u/v gradients with clamped. + int iv1 = iv0 + (clampV ? 0 : 1); + u -= (float)iu0; + v -= (float)iv0; + + // Cube map wrapping. + bool cubeWrap = CUBE_MODE && (iu0 < 0 || iv0 < 0 || iu1 >= w || iv1 >= h); + if (cubeWrap) + { + tcOut = wrapCubeMap(face, iu0, iu1, iv0, iv1, w); + tcOut += 6 * tz * w * h; // Bring in tz. + return make_float2(u, v); // Done. + } + + // Fold cube map face into tz. + if (CUBE_MODE) + tz = 6 * tz + face; + + // Wrap overflowing texel indices. + if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + if (iu0 < 0) iu0 += w; + if (iv0 < 0) iv0 += h; + if (iu1 >= w) iu1 -= w; + if (iv1 >= h) iv1 -= h; + } + + // Coordinates with tz folded in. + int iu0z = iu0 + tz * w * h; + int iu1z = iu1 + tz * w * h; + tcOut.x = iu0z + w * iv0; + tcOut.y = iu1z + w * iv0; + tcOut.z = iu0z + w * iv1; + tcOut.w = iu1z + w * iv1; + + // Invalidate texture addresses outside unit square if we are in zero mode. + if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_ZERO) + { + bool iu0_out = (iu0 < 0 || iu0 >= w); + bool iu1_out = (iu1 < 0 || iu1 >= w); + bool iv0_out = (iv0 < 0 || iv0 >= h); + bool iv1_out = (iv1 < 0 || iv1 >= h); + if (iu0_out || iv0_out) tcOut.x = -1; + if (iu1_out || iv0_out) tcOut.y = -1; + if (iu0_out || iv1_out) tcOut.z = -1; + if (iu1_out || iv1_out) tcOut.w = -1; + } + + // All done. + return make_float2(u, v); +} + +//------------------------------------------------------------------------ +// Mip level calculation. + +template +static __device__ __forceinline__ void calculateMipLevel(int& level0, int& level1, float& flevel, const TextureKernelParams& p, int pidx, float3 uv, float4* pdw, float3* pdfdv) +{ + // Do nothing if mips not in use. + if (FILTER_MODE == TEX_MODE_NEAREST || FILTER_MODE == TEX_MODE_LINEAR) + return; + + // Determine mip level based on UV pixel derivatives. If no derivatives are given (mip level bias only), leave as zero. 
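+    // The derivatives define the texel-space footprint ellipse of the pixel; the squared
+    // major-axis length computed below sets the level via flevel = 0.5 * log2(lenMajorSqr).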
+ if (!BIAS_ONLY) + { + // Get pixel derivatives of texture coordinates. + float4 uvDA; + float3 dvdX, dvdY; // Gradients use these later. + if (CUBE_MODE) + { + // Fetch. + float2 d0 = ((const float2*)p.uvDA)[3 * pidx + 0]; + float2 d1 = ((const float2*)p.uvDA)[3 * pidx + 1]; + float2 d2 = ((const float2*)p.uvDA)[3 * pidx + 2]; + + // Map d{x,y,z}/d{X,Y} into d{s,t}/d{X,Y}. + dvdX = make_float3(d0.x, d1.x, d2.x); // d{x,y,z}/dX + dvdY = make_float3(d0.y, d1.y, d2.y); // d{x,y,z}/dY + uvDA = indexCubeMapGradST(uv, dvdX, dvdY); // d{s,t}/d{X,Y} + } + else + { + // Fetch. + uvDA = ((const float4*)p.uvDA)[pidx]; + } + + // Scaling factors. + float uscl = p.texWidth; + float vscl = p.texHeight; + + // d[s,t]/d[X,Y]. + float dsdx = uvDA.x * uscl; + float dsdy = uvDA.y * uscl; + float dtdx = uvDA.z * vscl; + float dtdy = uvDA.w * vscl; + + // Calculate footprint axis lengths. + float A = dsdx*dsdx + dtdx*dtdx; + float B = dsdy*dsdy + dtdy*dtdy; + float C = dsdx*dsdy + dtdx*dtdy; + float l2b = 0.5 * (A + B); + float l2n = 0.25 * (A-B)*(A-B) + C*C; + float l2a = sqrt(l2n); + float lenMinorSqr = fmaxf(0.0, l2b - l2a); + float lenMajorSqr = l2b + l2a; + + // Footprint vs. mip level gradient. + if (pdw && FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + float dw = 0.72134752f / (l2n + l2a * l2b); // Constant is 0.5/ln(2). + float AB = dw * .5f * (A - B); + float Cw = dw * C; + float l2aw = dw * l2a; + float d_f_ddsdX = uscl * (dsdx * (l2aw + AB) + dsdy * Cw); + float d_f_ddsdY = uscl * (dsdy * (l2aw - AB) + dsdx * Cw); + float d_f_ddtdX = vscl * (dtdx * (l2aw + AB) + dtdy * Cw); + float d_f_ddtdY = vscl * (dtdy * (l2aw - AB) + dtdx * Cw); + + float4 d_f_dw = make_float4(d_f_ddsdX, d_f_ddsdY, d_f_ddtdX, d_f_ddtdY); + if (!CUBE_MODE) + *pdw = isfinite_vec4(d_f_dw) ? d_f_dw : make_float4(0.f, 0.f, 0.f, 0.f); + + // In cube maps, there is also a texture coordinate vs. mip level gradient. + // Only output nonzero vectors if both are free of inf/Nan garbage. + if (CUBE_MODE) + { + float4 dx, dy, dz; + indexCubeMapGrad2(uv, dvdX, dvdY, dx, dy, dz); + float3 d_dsdX_dv = make_float3(dx.x, dy.x, dz.x); + float3 d_dsdY_dv = make_float3(dx.y, dy.y, dz.y); + float3 d_dtdX_dv = make_float3(dx.z, dy.z, dz.z); + float3 d_dtdY_dv = make_float3(dx.w, dy.w, dz.w); + + float3 d_f_dv = make_float3(0.f, 0.f, 0.f); + d_f_dv += d_dsdX_dv * d_f_ddsdX; + d_f_dv += d_dsdY_dv * d_f_ddsdY; + d_f_dv += d_dtdX_dv * d_f_ddtdX; + d_f_dv += d_dtdY_dv * d_f_ddtdY; + + bool finite = isfinite_vec4(d_f_dw) && isfinite_vec3(d_f_dv); + *pdw = finite ? d_f_dw : make_float4(0.f, 0.f, 0.f, 0.f); + *pdfdv = finite ? d_f_dv : make_float3(0.f, 0.f, 0.f); + } + } + + // Finally, calculate mip level. + flevel = .5f * __log2f(lenMajorSqr); // May be inf/NaN, but clamp fixes it. + } + + // Bias the mip level and clamp. + if (p.mipLevelBias) + flevel += p.mipLevelBias[pidx]; + flevel = fminf(fmaxf(flevel, 0.f), (float)p.mipLevelMax); + + // Calculate levels depending on filter mode. + level0 = __float2int_rd(flevel); + + // Leave everything else at zero if flevel == 0 (magnification) or when in linear-mipmap-nearest mode. + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR && flevel > 0.f) + { + level1 = min(level0 + 1, p.mipLevelMax); + flevel -= level0; // Fractional part. Zero if clamped on last level. + } +} + +//------------------------------------------------------------------------ +// Texel fetch and accumulator helpers that understand cube map corners. 
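+// fetchQuad() gathers the four bilinear taps for one channel group, patching a missing
+// cube-corner tap with the average of the other three; accumQuad() scatters gradients
+// back with the same corner handling.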
+ +template +static __device__ __forceinline__ void fetchQuad(T& a00, T& a10, T& a01, T& a11, const float* pIn, int4 tc, bool corner) +{ + // For invalid cube map uv, tc will be all negative, and all texel values will be zero. + if (corner) + { + T avg = zero_value(); + if (tc.x >= 0) avg += (a00 = *((const T*)&pIn[tc.x])); + if (tc.y >= 0) avg += (a10 = *((const T*)&pIn[tc.y])); + if (tc.z >= 0) avg += (a01 = *((const T*)&pIn[tc.z])); + if (tc.w >= 0) avg += (a11 = *((const T*)&pIn[tc.w])); + avg *= 0.33333333f; + if (tc.x < 0) a00 = avg; + if (tc.y < 0) a10 = avg; + if (tc.z < 0) a01 = avg; + if (tc.w < 0) a11 = avg; + } + else + { + a00 = (tc.x >= 0) ? *((const T*)&pIn[tc.x]) : zero_value(); + a10 = (tc.y >= 0) ? *((const T*)&pIn[tc.y]) : zero_value(); + a01 = (tc.z >= 0) ? *((const T*)&pIn[tc.z]) : zero_value(); + a11 = (tc.w >= 0) ? *((const T*)&pIn[tc.w]) : zero_value(); + } +} + +static __device__ __forceinline__ void accumQuad(float4 c, float* pOut, int level, int4 tc, bool corner, CA_TEMP_PARAM) +{ + // For invalid cube map uv, tc will be all negative, and no accumulation will take place. + if (corner) + { + float cb; + if (tc.x < 0) cb = c.x; + if (tc.y < 0) cb = c.y; + if (tc.z < 0) cb = c.z; + if (tc.w < 0) cb = c.w; + cb *= 0.33333333f; + if (tc.x >= 0) caAtomicAddTexture(pOut, level, tc.x, c.x + cb); + if (tc.y >= 0) caAtomicAddTexture(pOut, level, tc.y, c.y + cb); + if (tc.z >= 0) caAtomicAddTexture(pOut, level, tc.z, c.z + cb); + if (tc.w >= 0) caAtomicAddTexture(pOut, level, tc.w, c.w + cb); + } + else + { + if (tc.x >= 0) caAtomicAddTexture(pOut, level, tc.x, c.x); + if (tc.y >= 0) caAtomicAddTexture(pOut, level, tc.y, c.y); + if (tc.z >= 0) caAtomicAddTexture(pOut, level, tc.z, c.z); + if (tc.w >= 0) caAtomicAddTexture(pOut, level, tc.w, c.w); + } +} + +//------------------------------------------------------------------------ +// Mip builder kernel. + +template +static __forceinline__ __device__ void MipBuildKernelTemplate(const TextureKernelParams p) +{ + // Sizes. + int2 sz_in = mipLevelSize(p, p.mipLevelOut - 1); + int2 sz_out = mipLevelSize(p, p.mipLevelOut); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= sz_out.x || py >= sz_out.y) + return; + + // Pixel indices. + int pidx_in0 = p.channels * (((px + sz_in.x * py) << 1) + (pz * sz_in.x * sz_in.y)); + int pidx_in1 = pidx_in0 + p.channels * sz_in.x; // Next pixel down. + int pidx_out = p.channels * (px + sz_out.x * (py + sz_out.y * pz)); + + // Input and output pointers. + const float* pin = p.tex[p.mipLevelOut - 1]; + float* pout = (float*)p.tex[p.mipLevelOut]; + + // Special case: Input texture height or width is 1. + if (sz_in.x == 1 || sz_in.y == 1) + { + if (sz_in.y == 1) + pidx_in1 = pidx_in0 + p.channels; // Next pixel on the right. 
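+        // Only two source texels contribute per output texel in the 1-D case, hence the
+        // 0.5 weight below instead of the 0.25 used for the full 2x2 reduction.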
+ + for (int i=0; i < p.channels; i += C) + { + T v0 = *((const T*)&pin[pidx_in0 + i]); + T v1 = *((const T*)&pin[pidx_in1 + i]); + T avg = .5f * (v0 + v1); +#if TEX_DEBUG_MIP_RETAIN_VARIANCE + avg = (avg - .5f) * 1.41421356f + .5f; +#endif + *((T*)&pout[pidx_out + i]) = avg; + } + + return; + } + + for (int i=0; i < p.channels; i += C) + { + T v0 = *((const T*)&pin[pidx_in0 + i]); + T v1 = *((const T*)&pin[pidx_in0 + i + p.channels]); + T v2 = *((const T*)&pin[pidx_in1 + i]); + T v3 = *((const T*)&pin[pidx_in1 + i + p.channels]); + T avg = .25f * (v0 + v1 + v2 + v3); +#if TEX_DEBUG_MIP_RETAIN_VARIANCE + avg = (avg - .5f) * 2.f + .5f; +#endif + *((T*)&pout[pidx_out + i]) = avg; + } +} + +// Template specializations. +__global__ void MipBuildKernel1(const TextureKernelParams p) { MipBuildKernelTemplate(p); } +__global__ void MipBuildKernel2(const TextureKernelParams p) { MipBuildKernelTemplate(p); } +__global__ void MipBuildKernel4(const TextureKernelParams p) { MipBuildKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Forward kernel. + +template +static __forceinline__ __device__ void TextureFwdKernelTemplate(const TextureKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + int tz = (p.texDepth == 1) ? 0 : pz; + if (px >= p.imgWidth || py >= p.imgHeight || pz >= p.n) + return; + + // Pixel index. + int pidx = px + p.imgWidth * (py + p.imgHeight * pz); + + // Output ptr. + float* pOut = p.out + pidx * p.channels; + + // Get UV. + float3 uv; + if (CUBE_MODE) + uv = ((const float3*)p.uv)[pidx]; + else + uv = make_float3(((const float2*)p.uv)[pidx], 0.f); + + // Nearest mode. + if (FILTER_MODE == TEX_MODE_NEAREST) + { + int tc = indexTextureNearest(p, uv, tz); + tc *= p.channels; + const float* pIn = p.tex[0]; + + // Copy if valid tc, otherwise output zero. + for (int i=0; i < p.channels; i += C) + *((T*)&pOut[i]) = (tc >= 0) ? *((const T*)&pIn[tc + i]) : zero_value(); + + return; // Exit. + } + + // Calculate mip level. In 'linear' mode these will all stay zero. + float flevel = 0.f; // Fractional level. + int level0 = 0; // Discrete level 0. + int level1 = 0; // Discrete level 1. + calculateMipLevel(level0, level1, flevel, p, pidx, uv, 0, 0); + + // Get texel indices and pointer for level 0. + int4 tc0 = make_int4(0, 0, 0, 0); + float2 uv0 = indexTextureLinear(p, uv, tz, tc0, level0); + const float* pIn0 = p.tex[level0]; + bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0); + tc0 *= p.channels; + + // Bilinear fetch. + if (FILTER_MODE == TEX_MODE_LINEAR || FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_NEAREST) + { + // Interpolate. + for (int i=0; i < p.channels; i += C, tc0 += C) + { + T a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + *((T*)&pOut[i]) = bilerp(a00, a10, a01, a11, uv0); + } + return; // Exit. + } + + // Get texel indices and pointer for level 1. + int4 tc1 = make_int4(0, 0, 0, 0); + float2 uv1 = indexTextureLinear(p, uv, tz, tc1, level1); + const float* pIn1 = p.tex[level1]; + bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0); + tc1 *= p.channels; + + // Trilinear fetch. + for (int i=0; i < p.channels; i += C, tc0 += C, tc1 += C) + { + // First level. + T a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + T a = bilerp(a00, a10, a01, a11, uv0); + + // Second level unless in magnification mode. 
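+        // flevel holds the fractional mip level; it is zero when magnifying, so the
+        // second fetch and the level interpolation below are skipped.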
+ if (flevel > 0.f) + { + T b00, b10, b01, b11; + fetchQuad(b00, b10, b01, b11, pIn1, tc1, corner1); + T b = bilerp(b00, b10, b01, b11, uv1); + a = lerp(a, b, flevel); // Interpolate between levels. + } + + // Write. + *((T*)&pOut[i]) = a; + } +} + +// Template specializations. +__global__ void TextureFwdKernelNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO2 (const TextureKernelParams p) { 
TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient mip puller kernel. + +template +static __forceinline__ __device__ void MipGradKernelTemplate(const TextureKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.texWidth || py >= p.texHeight) + return; + + // Number of wide elements. + int c = p.channels; + if (C == 2) c >>= 1; + if (C == 4) c >>= 2; + + // Dynamically allocated shared memory for holding a texel. + extern __shared__ float s_texelAccum[]; + int sharedOfs = threadIdx.x + threadIdx.y * blockDim.x; + int sharedStride = blockDim.x * blockDim.y; +# define TEXEL_ACCUM(_i) (s_texelAccum + (sharedOfs + (_i) * sharedStride)) + + // Clear the texel. + for (int i=0; i < p.channels; i++) + *TEXEL_ACCUM(i) = 0.f; + + // Track texel position and accumulation weight over the mip stack. + int x = px; + int y = py; + float w = 1.f; + + // Pull gradients from all levels. + int2 sz = mipLevelSize(p, 0); // Previous level size. + for (int level=1; level <= p.mipLevelMax; level++) + { + // Weight decay depends on previous level size. + if (sz.x > 1) w *= .5f; + if (sz.y > 1) w *= .5f; + + // Current level size and coordinates. + sz = mipLevelSize(p, level); + x >>= 1; + y >>= 1; + + T* pIn = (T*)(p.gradTex[level] + (x + sz.x * (y + sz.y * pz)) * p.channels); + for (int i=0; i < c; i++) + accum_from_mem(TEXEL_ACCUM(i * C), sharedStride, pIn[i], w); + } + + // Add to main texture gradients. + T* pOut = (T*)(p.gradTex[0] + (px + p.texWidth * (py + p.texHeight * pz)) * p.channels); + for (int i=0; i < c; i++) + accum_to_mem(pOut[i], TEXEL_ACCUM(i * C), sharedStride); +} + +// Template specializations. +__global__ void MipGradKernel1(const TextureKernelParams p) { MipGradKernelTemplate(p); } +__global__ void MipGradKernel2(const TextureKernelParams p) { MipGradKernelTemplate(p); } +__global__ void MipGradKernel4(const TextureKernelParams p) { MipGradKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient kernel. + +template +static __forceinline__ __device__ void TextureGradKernelTemplate(const TextureKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH * TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + int tz = (p.texDepth == 1) ? 
0 : pz; + if (px >= p.imgWidth || py >= p.imgHeight || pz >= p.n) + return; + + // Pixel index. + int pidx = px + p.imgWidth * (py + p.imgHeight * pz); + + // Early exit if output gradients are zero. + const float* pDy = p.dy + pidx * p.channels; + unsigned int dmax = 0u; + if ((p.channels & 3) == 0) + { + for (int i=0; i < p.channels; i += 4) + { + uint4 dy = *((const uint4*)&pDy[i]); + dmax |= (dy.x | dy.y | dy.z | dy.w); + } + } + else + { + for (int i=0; i < p.channels; i++) + dmax |= __float_as_uint(pDy[i]); + } + + // Store zeros and exit. + if (__uint_as_float(dmax) == 0.f) + { + if (CUBE_MODE) + { + if (FILTER_MODE != TEX_MODE_NEAREST) + ((float3*)p.gradUV)[pidx] = make_float3(0.f, 0.f, 0.f); + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + if (p.gradUVDA) + { + ((float2*)p.gradUVDA)[3 * pidx + 0] = make_float2(0.f, 0.f); + ((float2*)p.gradUVDA)[3 * pidx + 1] = make_float2(0.f, 0.f); + ((float2*)p.gradUVDA)[3 * pidx + 2] = make_float2(0.f, 0.f); + } + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = 0.f; + } + } + else + { + if (FILTER_MODE != TEX_MODE_NEAREST) + ((float2*)p.gradUV)[pidx] = make_float2(0.f, 0.f); + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + if (p.gradUVDA) + ((float4*)p.gradUVDA)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = 0.f; + } + } + return; + } + + // Get UV. + float3 uv; + if (CUBE_MODE) + uv = ((const float3*)p.uv)[pidx]; + else + uv = make_float3(((const float2*)p.uv)[pidx], 0.f); + + // Nearest mode - texture gradients only. + if (FILTER_MODE == TEX_MODE_NEAREST) + { + int tc = indexTextureNearest(p, uv, tz); + if (tc < 0) + return; // Outside texture. + + tc *= p.channels; + float* pOut = p.gradTex[0]; + + // Accumulate texture gradients. + for (int i=0; i < p.channels; i++) + caAtomicAddTexture(pOut, 0, tc + i, pDy[i]); + + return; // Exit. + } + + // Calculate mip level. In 'linear' mode these will all stay zero. + float4 dw = make_float4(0.f, 0.f, 0.f, 0.f); + float3 dfdv = make_float3(0.f, 0.f, 0.f); + float flevel = 0.f; // Fractional level. + int level0 = 0; // Discrete level 0. + int level1 = 0; // Discrete level 1. + calculateMipLevel(level0, level1, flevel, p, pidx, uv, &dw, &dfdv); + + // UV gradient accumulators. + float gu = 0.f; + float gv = 0.f; + + // Get texel indices and pointers for level 0. + int4 tc0 = make_int4(0, 0, 0, 0); + float2 uv0 = indexTextureLinear(p, uv, tz, tc0, level0); + const float* pIn0 = p.tex[level0]; + float* pOut0 = p.gradTex[level0]; + bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0); + tc0 *= p.channels; + + // Texel weights. + float uv011 = uv0.x * uv0.y; + float uv010 = uv0.x - uv011; + float uv001 = uv0.y - uv011; + float uv000 = 1.f - uv0.x - uv001; + float4 tw0 = make_float4(uv000, uv010, uv001, uv011); + + // Attribute weights. + int2 sz0 = mipLevelSize(p, level0); + float sclu0 = (float)sz0.x; + float sclv0 = (float)sz0.y; + + // Bilinear mode - texture and uv gradients. + if (FILTER_MODE == TEX_MODE_LINEAR || FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_NEAREST) + { + for (int i=0; i < p.channels; i++, tc0 += 1) + { + float dy = pDy[i]; + accumQuad(tw0 * dy, pOut0, level0, tc0, corner0, CA_TEMP); + + float a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + float ad = (a11 + a00 - a10 - a01); + gu += dy * ((a10 - a00) + uv0.y * ad) * sclu0; + gv += dy * ((a01 - a00) + uv0.x * ad) * sclv0; + } + + // Store UV gradients and exit. 
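+        // In cube mode the accumulated (s,t) gradients must first be mapped back onto
+        // the 3D lookup vector via indexCubeMapGrad().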
+ if (CUBE_MODE) + ((float3*)p.gradUV)[pidx] = indexCubeMapGrad(uv, gu, gv); + else + ((float2*)p.gradUV)[pidx] = make_float2(gu, gv); + + return; + } + + // Accumulate fractional mip level gradient. + float df = 0; // dL/df. + + // Get texel indices and pointers for level 1. + int4 tc1 = make_int4(0, 0, 0, 0); + float2 uv1 = indexTextureLinear(p, uv, tz, tc1, level1); + const float* pIn1 = p.tex[level1]; + float* pOut1 = p.gradTex[level1]; + bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0); + tc1 *= p.channels; + + // Texel weights. + float uv111 = uv1.x * uv1.y; + float uv110 = uv1.x - uv111; + float uv101 = uv1.y - uv111; + float uv100 = 1.f - uv1.x - uv101; + float4 tw1 = make_float4(uv100, uv110, uv101, uv111); + + // Attribute weights. + int2 sz1 = mipLevelSize(p, level1); + float sclu1 = (float)sz1.x; + float sclv1 = (float)sz1.y; + + // Trilinear mode. + for (int i=0; i < p.channels; i++, tc0 += 1, tc1 += 1) + { + float dy = pDy[i]; + float dy0 = (1.f - flevel) * dy; + accumQuad(tw0 * dy0, pOut0, level0, tc0, corner0, CA_TEMP); + + // UV gradients for first level. + float a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + float ad = (a11 + a00 - a10 - a01); + gu += dy0 * ((a10 - a00) + uv0.y * ad) * sclu0; + gv += dy0 * ((a01 - a00) + uv0.x * ad) * sclv0; + + // Second level unless in magnification mode. + if (flevel > 0.f) + { + // Texture gradients for second level. + float dy1 = flevel * dy; + accumQuad(tw1 * dy1, pOut1, level1, tc1, corner1, CA_TEMP); + + // UV gradients for second level. + float b00, b10, b01, b11; + fetchQuad(b00, b10, b01, b11, pIn1, tc1, corner1); + float bd = (b11 + b00 - b10 - b01); + gu += dy1 * ((b10 - b00) + uv1.y * bd) * sclu1; + gv += dy1 * ((b01 - b00) + uv1.x * bd) * sclv1; + + // Mip level gradient. + float a = bilerp(a00, a10, a01, a11, uv0); + float b = bilerp(b00, b10, b01, b11, uv1); + df += (b-a) * dy; + } + } + + // Store UV gradients. + if (CUBE_MODE) + ((float3*)p.gradUV)[pidx] = indexCubeMapGrad(uv, gu, gv) + (dfdv * df); + else + ((float2*)p.gradUV)[pidx] = make_float2(gu, gv); + + // Store mip level bias gradient. + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = df; + + // Store UV pixel differential gradients. + if (!BIAS_ONLY) + { + // Final gradients. + dw *= df; // dL/(d{s,y}/d{X,Y}) = df/(d{s,y}/d{X,Y}) * dL/df. + + // Store them. + if (CUBE_MODE) + { + // Remap from dL/(d{s,t}/s{X,Y}) to dL/(d{x,y,z}/d{X,Y}). + float3 g0, g1; + indexCubeMapGrad4(uv, dw, g0, g1); + ((float2*)p.gradUVDA)[3 * pidx + 0] = make_float2(g0.x, g1.x); + ((float2*)p.gradUVDA)[3 * pidx + 1] = make_float2(g0.y, g1.y); + ((float2*)p.gradUVDA)[3 * pidx + 2] = make_float2(g0.z, g1.z); + } + else + ((float4*)p.gradUVDA)[pidx] = dw; + } +} + +// Template specializations. 
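+// One gradient kernel covers each cube / filter / bias-only combination handled by the
+// forward kernels above.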
+__global__ void TextureGradKernelNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapNearestBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapLinearBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapNearestBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapLinearBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib b/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib new file mode 100644 index 0000000000000000000000000000000000000000..add9a0c4f631cb56dbee31a05ed97339930301e2 Binary files /dev/null and b/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib differ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf62df8782d730f072ca5f4e4862a44dc8c3a086 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .ops import rasterize, interpolate, texture, antialias +from .plugin_loader import set_cache_dir + +__all__ = ["rasterize", "interpolate", "texture", "antialias", "set_cache_dir"] diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..be51deef13e0ecfbd5bfe8bc376af24a18db7224 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py @@ -0,0 +1,303 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import tensorflow as tf +import numpy as np +import os +from . 
import plugin_loader + +#---------------------------------------------------------------------------- +# Helpers. +#---------------------------------------------------------------------------- + +# OpenGL-related linker options depending on platform. +def _get_gl_opts(): + libs = { + 'posix': ['GL', 'EGL'], + 'nt': ['gdi32', 'opengl32', 'user32', 'setgpu'], + } + return ['-l' + x for x in libs[os.name]] + +# Load the cpp plugin. +def _get_plugin(): + fn = os.path.join(os.path.dirname(__file__), 'tf_all.cu') + return plugin_loader.get_plugin(fn, extra_nvcc_options=_get_gl_opts() + ['-DNVDR_TENSORFLOW']) + +# Convert parameter to a numpy array if possible. +def _get_constant(x, dtype): + try: + return np.asarray(x, dtype=dtype) + except (TypeError, ValueError): + return None + +# Tests for a construction-time constantness instead of tf.constant node because +# the latter can be overridden in Session.run() feed_dict at evaluation time. +def _is_constant(x, dtype): + if isinstance(x, np.ndarray): + return np.can_cast(x.dtype, dtype, 'unsafe') + else: + return _get_constant(x, dtype) is not None + +#---------------------------------------------------------------------------- +# Rasterize. +#---------------------------------------------------------------------------- + +def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True): + assert tri_const is True or tri_const is False + assert output_db is True or output_db is False + + # Known constant resolution? + resolution_c = _get_constant(resolution, np.int32) + + # Known constant triangles? + tri_const = tri_const or _is_constant(tri, np.int32) + + # Convert all inputs to tensors / base types. + tri_const = 1 if tri_const else 0 + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + pos = tf.convert_to_tensor(pos, dtype=tf.float32) + resolution = tf.convert_to_tensor(resolution, dtype=tf.int32) + if ranges is None: + ranges = tf.convert_to_tensor(np.zeros(shape=[0, 2], dtype=np.int32)) # Empty tensor. + else: + ranges = tf.convert_to_tensor(ranges, dtype=tf.int32) # Convert input to tensor. + + # Infer as much about the output shape as possible. + out_shape = [None, None, None, 4] + if pos.shape.rank == 3: # Instanced mode. + out_shape[0] = pos.shape[0].value + elif pos.shape.rank == 2: # Range mode. + if ranges.shape.rank not in [None, 0]: + out_shape[0] = ranges.shape[0].value + if resolution_c is not None: + assert resolution_c.shape == (2,) + out_shape[1], out_shape[2] = resolution_c + + # Output pixel differentials. + @tf.custom_gradient + def func_db(pos): + out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 1, tri_const) + out.set_shape(out_shape) + out_db.set_shape(out_shape) + def grad(dy, ddb): + if grad_db: + return _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb) + else: + return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad + + # Do not output pixel differentials. + @tf.custom_gradient + def func(pos): + out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 0, tri_const) + out.set_shape(out_shape) + out_db.set_shape(out_shape[:-1] + [0]) # Zero channels in out_db. + def grad(dy, _): + return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad + + # Choose stub. + if output_db: + return func_db(pos) + else: + return func(pos) + +#---------------------------------------------------------------------------- +# Interpolate. 
+#---------------------------------------------------------------------------- + +def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): + # Sanitize the list of pixel differential attributes. + if diff_attrs is None: + diff_attrs = [] + elif diff_attrs != 'all': + diff_attrs = _get_constant(diff_attrs, np.int32) + assert (diff_attrs is not None) and len(diff_attrs.shape) == 1 + diff_attrs = diff_attrs.tolist() + + # Convert all inputs to tensors. + attr = tf.convert_to_tensor(attr, dtype=tf.float32) + rast = tf.convert_to_tensor(rast, dtype=tf.float32) + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + if diff_attrs: + rast_db = tf.convert_to_tensor(rast_db, dtype=tf.float32) + + # Infer output shape. + out_shape = [None, None, None, None] + if rast.shape.rank is not None: + out_shape = [rast.shape[0].value, rast.shape[1].value, rast.shape[2].value, None] + if attr.shape.rank in [2, 3]: + out_shape[3] = attr.shape[-1].value + + # Output pixel differentials for at least some attributes. + @tf.custom_gradient + def func_da(attr, rast, rast_db): + diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_list = [] if diff_attrs_all else diff_attrs + out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + + # Infer number of channels in out_da. + if not diff_attrs_all: + da_channels = 2 * len(diff_attrs) + if (attr.shape.rank in [2, 3]) and (attr.shape[-1].value is not None): + da_channels = 2 * attr.shape[-1].value + else: + da_channels = None + + # Set output shapes. + out.set_shape(out_shape) + out_da.set_shape([out_shape[0], out_shape[1], out_shape[2], da_channels]) + + def grad(dy, dda): + return _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + return (out, out_da), grad + + # No pixel differentials for any attribute. + @tf.custom_gradient + def func(attr, rast): + out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri) + out.set_shape(out_shape) + out_da.set_shape(out_shape[:-1] + [0]) # Zero channels in out_da. + def grad(dy, _): + return _get_plugin().interpolate_grad(attr, rast, tri, dy) + return (out, out_da), grad + + # Choose stub. + if diff_attrs: + return func_da(attr, rast, rast_db) + else: + return func(attr, rast) + +#---------------------------------------------------------------------------- +# Texture. +#---------------------------------------------------------------------------- + +def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_const=False, max_mip_level=None): + assert tex_const is True or tex_const is False + + # Default filter mode. + if filter_mode == 'auto': + filter_mode = 'linear-mipmap-linear' if (uv_da is not None) else 'linear' + + # Known constant texture? + tex_const = tex_const or _is_constant(tex, np.float32) + + # Sanitize inputs. + tex_const = 1 if tex_const else 0 + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + + # Convert inputs to tensors. + tex = tf.convert_to_tensor(tex, dtype=tf.float32) + uv = tf.convert_to_tensor(uv, dtype=tf.float32) + if 'mipmap' in filter_mode: + uv_da = tf.convert_to_tensor(uv_da, dtype=tf.float32) + + # Infer output shape. 
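+    # uv is expected to be rank 4 (minibatch, height, width, coords); tex is rank 4, or
+    # rank 5 when boundary_mode == 'cube'. The asserts below enforce this.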
+ out_shape = [None, None, None, None] + if uv.shape.rank is not None: + assert uv.shape.rank == 4 + out_shape = [uv.shape[0].value, uv.shape[1].value, uv.shape[2].value, None] + if tex.shape.rank is not None: + assert tex.shape.rank == (5 if boundary_mode == 'cube' else 4) + out_shape[-1] = tex.shape[-1].value + + # If mipping disabled via max level=0, we may as well use simpler filtering internally. + if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: + filter_mode = 'linear' + + # Convert filter mode to internal enumeration. + filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_enum = filter_mode_dict[filter_mode] + + # Convert boundary mode to internal enumeration. + boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_enum = boundary_mode_dict[boundary_mode] + + # Linear-mipmap-linear: Mipmaps enabled, all gradients active. + @tf.custom_gradient + def func_linear_mipmap_linear(tex, uv, uv_da): + out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return out, grad + + # Linear-mipmap-nearest: Mipmaps enabled, no gradients to uv_da. + @tf.custom_gradient + def func_linear_mipmap_nearest(tex, uv): + out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return out, grad + + # Linear: Mipmaps disabled, no uv_da, no gradients to uv_da. + @tf.custom_gradient + def func_linear(tex, uv): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return out, grad + + # Nearest: Mipmaps disabled, no uv_da, no gradients to uv_da or uv. + @tf.custom_gradient + def func_nearest(tex): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return out, grad + + # Choose stub. + if filter_mode == 'linear-mipmap-linear': + return func_linear_mipmap_linear(tex, uv, uv_da) + elif filter_mode == 'linear-mipmap-nearest': + return func_linear_mipmap_nearest(tex, uv) + elif filter_mode == 'linear': + return func_linear(tex, uv) + elif filter_mode == 'nearest': + return func_nearest(tex) + +#---------------------------------------------------------------------------- +# Antialias. +#---------------------------------------------------------------------------- + +def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0): + assert tri_const is True or tri_const is False + + # Known constant triangles? + tri_const = tri_const or _is_constant(tri, np.int32) + + # Convert inputs to tensors. + color = tf.convert_to_tensor(color, dtype=tf.float32) + rast = tf.convert_to_tensor(rast, dtype=tf.float32) + pos = tf.convert_to_tensor(pos, dtype=tf.float32) + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + + # Sanitize inputs. 
+ tri_const = 1 if tri_const else 0 + + @tf.custom_gradient + def func(color, pos): + color_out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, tri_const) + color_out.set_shape(color.shape) + def grad(dy): + grad_color, grad_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + if pos_gradient_boost != 1.0: + grad_pos = grad_pos * pos_gradient_boost + return grad_color, grad_pos + return color_out, grad + + return func(color, pos) + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..3918aecdab6bb4192e8810bd872abf9a1fc30971 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py @@ -0,0 +1,219 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import glob +import os +import re +import uuid +import hashlib +import tempfile +import shutil +import tensorflow as tf +from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module + +#---------------------------------------------------------------------------- +# Global options. + +_nvdiffrast_cache_dir = None + +def set_cache_dir(path: str) -> None: + '''Set CUDA kernel compilation temp dir. + + If `set_cache_dir` is not called, the cache directory will default to + one of the below: + + - Value of NVDIFFRAST_CACHE_DIR env var, if set + - $HOME/.cache/nvdiffrast if HOME env var is set + - $USERPROFILE/.cache/nvdiffrast if USERPROFILE is set. + + Args: + path: Where to save CUDA kernel build temporaries + ''' + global _nvdiffrast_cache_dir + _nvdiffrast_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _nvdiffrast_cache_dir is not None: + return os.path.join(_nvdiffrast_cache_dir, *paths) + if 'NVDIFFRAST_CACHE_DIR' in os.environ: + return os.path.join(os.environ['NVDIFFRAST_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'nvdiffrast', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'nvdiffrast', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'nvdiffrast', *paths) + +cuda_cache_version_tag = 'v1' +do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! +verbose = True # Print status messages to stdout. + +#---------------------------------------------------------------------------- +# Internal helper funcs. 
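+# _find_compiler_bindir() locates an MSVC host compiler on Windows, and
+# _get_cuda_gpu_arch_string() queries the compute capability of the first visible GPU so
+# the kernels can be built with a matching --gpu-architecture flag.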
+ +def _find_compiler_bindir(): + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' + if os.path.isdir(vc_bin_dir): + return vc_bin_dir + return None + +def _get_compute_cap(device): + caps_str = device.physical_device_desc + m = re.search('compute capability: (\\d+).(\\d+)', caps_str) + major = m.group(1) + minor = m.group(2) + return (major, minor) + +def _get_cuda_gpu_arch_string(): + gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] + if len(gpus) == 0: + raise RuntimeError('No GPU devices found') + (major, minor) = _get_compute_cap(gpus[0]) + return 'sm_%s%s' % (major, minor) + +def _run_cmd(cmd): + with os.popen(cmd) as pipe: + output = pipe.read() + status = pipe.close() + if status is not None: + raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) + +def _prepare_nvcc_cli(opts): + cmd = 'nvcc ' + opts.strip() + cmd += ' --disable-warnings' + cmd += ' --include-path "%s"' % tf.sysconfig.get_include() + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') + + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + # Require that _find_compiler_bindir succeeds on Windows. Allow + # nvcc to use whatever is the default on Linux. + if os.name == 'nt': + raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) + else: + cmd += ' --compiler-bindir "%s"' % compiler_bindir + cmd += ' 2>&1' + return cmd + +#---------------------------------------------------------------------------- +# Main entry point. 
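+# [Editor's note] get_plugin() below hashes the CUDA source, its preprocessed includes, and the
+# full nvcc command line, recompiles only on a cache miss, and loads the result via
+# tf.load_op_library(). Rough sketch of how the op wrappers obtain the library (the actual call
+# site lives in ops.py and is simplified here):
+#
+#   from nvdiffrast.tensorflow import plugin_loader
+#   lib = plugin_loader.get_plugin('tf_all.cu')          # path shortened for illustration
+#   out = lib.texture_fwd(tex, uv, filter_mode, boundary_mode)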
+ +_plugin_cache = dict() + +def get_plugin(cuda_file, extra_nvcc_options=[]): + cuda_file_base = os.path.basename(cuda_file) + cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) + + # Already in cache? + if cuda_file in _plugin_cache: + return _plugin_cache[cuda_file] + + # Setup plugin. + if verbose: + print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) + try: + # Hash CUDA source. + md5 = hashlib.md5() + with open(cuda_file, 'rb') as f: + md5.update(f.read()) + md5.update(b'\n') + + # Hash headers included by the CUDA code by running it through the preprocessor. + if not do_not_hash_included_headers: + if verbose: + print('Preprocessing... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) + _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) + with open(tmp_file, 'rb') as f: + bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros + good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') + for ln in f: + if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas + ln = ln.replace(bad_file_str, good_file_str) + md5.update(ln) + md5.update(b'\n') + + # Select compiler options. + compile_opts = '' + if os.name == 'nt': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') + compile_opts += ' --library-path="%s"' % (os.path.dirname(__file__) + r"\..\lib") # Find libraries during compilation. + elif os.name == 'posix': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') + compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' + else: + assert False # not Windows or Linux, w00t? + compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() + compile_opts += ' --use_fast_math' + for opt in extra_nvcc_options: + compile_opts += ' ' + opt + nvcc_cmd = _prepare_nvcc_cli(compile_opts) + + # Hash build configuration. + md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') + md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') + md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') + + # Compile if not already compiled. + bin_file_ext = '.dll' if os.name == 'nt' else '.so' + cuda_cache_path = make_cache_dir_path() + bin_file = os.path.join(make_cache_dir_path(), cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) + if not os.path.isfile(bin_file): + if verbose: + print('Compiling... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) + _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) + os.makedirs(cuda_cache_path, exist_ok=True) + intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) + shutil.copyfile(tmp_file, intermediate_file) + os.rename(intermediate_file, bin_file) # atomic + + # Load. + if verbose: + print('Loading... ', end='', flush=True) + plugin = tf.load_op_library(bin_file) + + # Add to cache. 
+ _plugin_cache[cuda_file] = plugin + if verbose: + print('Done.', flush=True) + return plugin + + except: + if verbose: + print('Failed!', flush=True) + raise + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu new file mode 100644 index 0000000000000000000000000000000000000000..8eefcfbd35d837b9ec595100f57f0bdb6d072349 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu @@ -0,0 +1,36 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +// TF-specific helpers. + +#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal("Cuda error: ", cudaGetErrorName(err), "[", #CUDA_CALL, ";]")); } while (0) +#define OP_CHECK_GL_ERROR(CTX, GL_CALL) do { GL_CALL; GLenum err = glGetError(); OP_REQUIRES(CTX, err == GL_NO_ERROR, errors::Internal("OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]")); } while (0) + +// Cuda kernels and CPP all together. What an absolute compilation unit. + +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "../common/framework.h" +#include "../common/glutil.cpp" + +#include "../common/common.h" +#include "../common/common.cpp" + +#include "../common/rasterize.h" +#include "../common/rasterize_gl.cpp" +#include "../common/rasterize.cu" +#include "tf_rasterize.cu" + +#include "../common/interpolate.cu" +#include "tf_interpolate.cu" + +#include "../common/texture.cpp" +#include "../common/texture.cu" +#include "tf_texture.cu" + +#include "../common/antialias.cu" +#include "tf_antialias.cu" diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu new file mode 100644 index 0000000000000000000000000000000000000000..9b14962a8b40e12bfab1ca3a7107d5f5e943a125 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu @@ -0,0 +1,278 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct AntialiasFwdOp : public OpKernel +{ + AntialiasKernelParams m_attribs; + + AntialiasFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tri_const", &m_attribs.tri_const)); + } + + void Compute(OpKernelContext* ctx) + { + AntialiasKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. 
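+ // [Editor's note] Inputs arrive in the order declared by REGISTER_OP("AntialiasFwd") further
+ // below: color [N, H, W, C], raster_out [N, H, W, 4], clip-space pos ([N, V, 4] or [V, 4]),
+ // and int32 tri [T, 3].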
+ const Tensor& color = ctx->input(0); + const Tensor& rasterOut = ctx->input(1); + const Tensor& pos = ctx->input(2); + const Tensor& tri = ctx->input(3); + + // Instance rendering mode? + p.instance_mode = pos.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + p.numVertices = (pos.dims() > 1) ? pos.dim_size(1) : 0; + else + p.numVertices = (pos.dims() > 0) ? pos.dim_size(0) : 0; + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.n = (color.dims() > 0) ? color.dim_size(0) : 0; + p.height = (color.dims() > 1) ? color.dim_size(1) : 0; + p.width = (color.dims() > 2) ? color.dim_size(2) : 0; + p.channels = (color.dims() > 3) ? color.dim_size(3) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, color.dims() == 4 && color.dim_size(0) > 0 && color.dim_size(1) > 0 && color.dim_size(2) > 0 && color.dim_size(3) > 0, errors::InvalidArgument("color must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, rasterOut.dims() == 4 && rasterOut.dim_size(0) > 0 && rasterOut.dim_size(1) > 0 && rasterOut.dim_size(2) > 0 && rasterOut.dim_size(3) == 4, errors::InvalidArgument("raster_out must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, color.dim_size(1) == rasterOut.dim_size(1) && color.dim_size(2) == rasterOut.dim_size(2), errors::InvalidArgument("color and raster_out inputs must have same spatial dimensions")); + if (p.instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out, pos")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out")); + } + + // Get input pointers. + p.color = color.flat().data(); + p.rasterOut = rasterOut.flat().data(); + p.tri = tri.flat().data(); + p.pos = pos.flat().data(); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate output tensor. + Tensor* outputTensor = NULL; + TensorShape outputShape; + outputShape.AddDim(p.n); + outputShape.AddDim(p.height); + outputShape.AddDim(p.width); + outputShape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, outputShape, &outputTensor)); + p.output = outputTensor->flat().data(); + + // Allocate work buffer. One extra int4 for storing counters. + Tensor* workTensor = NULL; + TensorShape workShape; + workShape.AddDim(p.n * p.width * p.height * 8 + 4); // 8 int for a maximum of two work items per pixel. + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, workShape, &workTensor)); + p.workBuffer = (int4*)(workTensor->flat().data()); + + // Clear the work counters. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.workBuffer, 0, sizeof(int4), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. 
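+ // [Editor's note] The CUDA kernels load pos as float4 and raster_out as float2, hence the
+ // base-address alignment checks below; the work buffer allocated above holds up to two int4
+ // work items per pixel plus one leading int4 of counters.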
+ OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rasterOut & 7), errors::Internal("raster_out input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.workBuffer & 15), errors::Internal("work_buffer internal tensor not aligned to int4")); + + // Kernel parameters. + void* args[] = {&p}; + + // (Re-)calculate opposite vertex hash. + if (!p.evHash || !p.tri_const) + { + if (p.allocTriangles < p.numTriangles) + { + p.allocTriangles = max(p.allocTriangles, 64); + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // (Re-)allocate memory for the hash. + OP_CHECK_CUDA_ERROR(ctx, cudaFree(p.evHash)); + OP_CHECK_CUDA_ERROR(ctx, cudaMalloc(&p.evHash, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4))); + LOG(INFO) << "Increasing topology hash size to accommodate " << p.allocTriangles << " triangles"; + } + + // Clear the hash and launch the mesh kernel to populate it. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.evHash, 0, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4), stream)); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdMeshKernel, (p.numTriangles - 1) / AA_MESH_KERNEL_THREADS_PER_BLOCK + 1, AA_MESH_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } + + // Copy input to output as a baseline. + OP_CHECK_CUDA_ERROR(ctx, cudaMemcpyAsync(p.output, p.color, p.n * p.height * p.width * p.channels * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + + // Choose launch parameters for the discontinuity finder kernel and launch. + dim3 blockSize(AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH, AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT, 1); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.n); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdDiscontinuityKernel, gridSize, blockSize, args, 0, stream)); + + // Determine optimum block size for the persistent analysis kernel. + int device = 0; + int numCTA = 0; + int numSM = 0; + OP_CHECK_CUDA_ERROR(ctx, cudaGetDevice(&device)); + OP_CHECK_CUDA_ERROR(ctx, cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasFwdAnalysisKernel, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, 0)); + OP_CHECK_CUDA_ERROR(ctx, cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + + // Launch analysis kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdAnalysisKernel, numCTA * numSM, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } +}; + +REGISTER_OP("AntialiasFwd") + .Input ("color: float") + .Input ("raster_out: float") + .Input ("pos: float") + .Input ("tri: int32") + .Output ("output: float") + .Output ("work_buffer: int32") + .Attr ("tri_const: int"); + +REGISTER_KERNEL_BUILDER(Name("AntialiasFwd").Device(DEVICE_GPU), AntialiasFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +struct AntialiasGradOp : public OpKernel +{ + AntialiasKernelParams m_attribs; + + AntialiasGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + } + + void Compute(OpKernelContext* ctx) + { + AntialiasKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. 
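+ // [Editor's note] Besides the forward inputs, the gradient op consumes dy (upstream gradient
+ // of the antialiased color) and the work_buffer emitted by AntialiasFwd, and produces
+ // grad_color and grad_pos, matching the tf.custom_gradient wrapper in ops.py.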
+ const Tensor& color = ctx->input(0); + const Tensor& rasterOut = ctx->input(1); + const Tensor& pos = ctx->input(2); + const Tensor& tri = ctx->input(3); + const Tensor& dy = ctx->input(4); + const Tensor& workBuffer = ctx->input(5); + + // Instance rendering mode? + p.instance_mode = pos.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + p.numVertices = (pos.dims() > 1) ? pos.dim_size(1) : 0; + else + p.numVertices = (pos.dims() > 0) ? pos.dim_size(0) : 0; + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.n = (color.dims() > 0) ? color.dim_size(0) : 0; + p.height = (color.dims() > 1) ? color.dim_size(1) : 0; + p.width = (color.dims() > 2) ? color.dim_size(2) : 0; + p.channels = (color.dims() > 3) ? color.dim_size(3) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) > 0 && dy.dim_size(1) > 0 && dy.dim_size(2) > 0 && dy.dim_size(3) > 0, errors::InvalidArgument("dy must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, color.dims() == 4 && color.dim_size(0) > 0 && color.dim_size(1) > 0 && color.dim_size(2) > 0 && color.dim_size(3) > 0, errors::InvalidArgument("color must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, rasterOut.dims() == 4 && rasterOut.dim_size(0) > 0 && rasterOut.dim_size(1) > 0 && rasterOut.dim_size(2) > 0 && rasterOut.dim_size(3) == 4, errors::InvalidArgument("raster_out must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, color.dim_size(1) == rasterOut.dim_size(1) && color.dim_size(2) == rasterOut.dim_size(2), errors::InvalidArgument("color and raster_out inputs must have same spatial dimensions")); + OP_REQUIRES(ctx, color.dim_size(1) == dy.dim_size(1) && color.dim_size(2) == dy.dim_size(2) && color.dim_size(3) == dy.dim_size(3), errors::InvalidArgument("color and dy inputs must have same dimensions")); + if (p.instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out, pos")); + OP_REQUIRES(ctx, dy.dim_size(0) == p.n && rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs dy, color, raster_out, pos")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out")); + OP_REQUIRES(ctx, dy.dim_size(0) == p.n && rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs dy, color, raster_out")); + } + + // Get input pointers. + p.dy = dy.flat().data(); + p.color = color.flat().data(); + p.rasterOut = rasterOut.flat().data(); + p.tri = tri.flat().data(); + p.pos = pos.flat().data(); + p.workBuffer = (int4*)(workBuffer.flat().data()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate color gradient output tensor. 
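+ // [Editor's note] grad_color shares the [N, H, W, C] shape of the color input and is seeded
+ // from dy below; the gradient kernel then adjusts it only at the discontinuity pixels recorded
+ // in the work buffer.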
+ Tensor* gradColor = NULL; + TensorShape gradColorShape; + gradColorShape.AddDim(p.n); + gradColorShape.AddDim(p.height); + gradColorShape.AddDim(p.width); + gradColorShape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, gradColorShape, &gradColor)); + p.gradColor = gradColor->flat().data(); + + // Allocate position gradient output tensor. + Tensor* gradPos = NULL; + TensorShape gradPosShape; + if (p.instance_mode) + gradPosShape.AddDim(p.n); + gradPosShape.AddDim(p.numVertices); + gradPosShape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, gradPosShape, &gradPos)); + p.gradPos = gradPos->flat().data(); + + // Initialize all the stuff. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(&p.workBuffer[0].y, 0, sizeof(int), stream)); // Gradient kernel work counter. + OP_CHECK_CUDA_ERROR(ctx, cudaMemcpyAsync(p.gradColor, p.dy, p.n * p.height * p.width * p.channels * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradPos, 0, (p.instance_mode ? p.n : 1) * p.numVertices * 4 * sizeof(float), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.workBuffer & 15), errors::Internal("work_buffer internal tensor not aligned to int4")); + + // Launch the gradient kernel. + void* args[] = {&p}; + + int device = 0; + int numCTA = 0; + int numSM = 0; + OP_CHECK_CUDA_ERROR(ctx, cudaGetDevice(&device)); + OP_CHECK_CUDA_ERROR(ctx, cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasGradKernel, AA_GRAD_KERNEL_THREADS_PER_BLOCK, 0)); + OP_CHECK_CUDA_ERROR(ctx, cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasGradKernel, numCTA * numSM, AA_GRAD_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } +}; + +REGISTER_OP("AntialiasGrad") + .Input ("color: float") + .Input ("raster_out: float") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("dy: float") + .Input ("work_buffer: int32") + .Output ("grad_color: float") + .Output ("grad_pos: float"); + +REGISTER_KERNEL_BUILDER(Name("AntialiasGrad").Device(DEVICE_GPU), AntialiasGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu new file mode 100644 index 0000000000000000000000000000000000000000..612ce1afc5ce41a25496523b193725c1edac64de --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu @@ -0,0 +1,301 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common op attribute parser. 
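+ // [Editor's note] The parser below reads the diff_attrs_all / diff_attrs attributes that select
+ // which attribute channels receive image-space derivatives. Rough frontend sketch (the Python
+ // wrapper lives in ops.py and is not part of this hunk, so names are indicative only):
+ //   out, out_da = interpolate(attr, rast, tri, rast_db=rast_db, diff_attrs='all')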
+ +static __host__ void interpolateParseOpAttributes(OpKernelConstruction* ctx, InterpolateKernelParams& p, bool enableDA) +{ + if (enableDA) + { + OP_REQUIRES_OK(ctx, ctx->GetAttr("diff_attrs_all", &p.diff_attrs_all)); + if (!p.diff_attrs_all) + { + std::vector diff_attrs_vec; + OP_REQUIRES_OK(ctx, ctx->GetAttr("diff_attrs", &diff_attrs_vec)); + OP_REQUIRES(ctx, diff_attrs_vec.size() > 0, errors::InvalidArgument("differentiation enabled with empty diff_attrs list")); + OP_REQUIRES(ctx, diff_attrs_vec.size() <= IP_MAX_DIFF_ATTRS, errors::InvalidArgument("too many entries in diff_attrs list (increase IP_MAX_DIFF_ATTRS)")); + p.numDiffAttr = diff_attrs_vec.size(); + memcpy(p.diffAttrs, &diff_attrs_vec[0], diff_attrs_vec.size()*sizeof(int)); + } + } +} + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +template +struct InterpolateFwdOp : public OpKernel +{ + InterpolateKernelParams m_attribs; + + InterpolateFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); + } + + void Compute(OpKernelContext* ctx) + { + InterpolateKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& attr = ctx->input(0); + const Tensor& rast = ctx->input(1); + const Tensor& tri = ctx->input(2); + const Tensor& rast_db = ctx->input(ENABLE_DA ? 3 : 2); + + // Instance rendering mode? + p.instance_mode = attr.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + { + p.numVertices = (attr.dims() > 1) ? attr.dim_size(1) : 0; + p.numAttr = (attr.dims() > 2) ? attr.dim_size(2) : 0; + } + else + { + p.numVertices = (attr.dims() > 0) ? attr.dim_size(0) : 0; + p.numAttr = (attr.dims() > 1) ? attr.dim_size(1) : 0; + } + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.height = (rast.dims() > 1) ? rast.dim_size(1) : 0; + p.width = (rast.dims() > 2) ? rast.dim_size(2) : 0; + p.depth = (rast.dims() > 0) ? rast.dim_size(0) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, rast.dims() == 4 && rast.dim_size(0) > 0 && rast.dim_size(1) > 0 && rast.dim_size(2) > 0 && rast.dim_size(3) == 4, errors::InvalidArgument("rast must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, (attr.dims() == 2 || attr.dims() == 3) && attr.dim_size(0) > 0 && attr.dim_size(1) > 0 && (attr.dims() == 2 || attr.dim_size(2) > 0), errors::InvalidArgument("attr must have shape [>0, >0, >0] or [>0, >0]")); + if (p.instance_mode) + OP_REQUIRES(ctx, attr.dim_size(0) == p.depth || attr.dim_size(0) == 1, errors::InvalidArgument("minibatch size mismatch between inputs rast, attr")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, rast_db.dims() == 4 && rast_db.dim_size(0) > 0 && rast_db.dim_size(1) > 0 && rast_db.dim_size(2) > 0 && rast_db.dim_size(3) == 4, errors::InvalidArgument("rast_db must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, rast_db.dim_size(1) == rast.dim_size(1) && rast_db.dim_size(2) == rast.dim_size(2), errors::InvalidArgument("spatial size mismatch between inputs rast and rast_db")); + OP_REQUIRES(ctx, rast_db.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between inputs rast, rast_db")); + } + + // All diff attrs mode. + if (p.diff_attrs_all) + p.numDiffAttr = p.numAttr; + + // Get input pointers. 
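+ // [Editor's note] attrBC, set below, flags attribute broadcasting: in instance mode an
+ // attribute tensor with leading dimension 1 is shared across every minibatch entry of rast.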
+ p.attr = attr.flat().data(); + p.rast = rast.flat().data(); + p.tri = tri.flat().data(); + p.attrBC = (p.instance_mode && attr.dim_size(0) == 1) ? 1 : 0; + p.rastDB = ENABLE_DA ? rast_db.flat().data() : 0; + + // Allocate main output tensor. + Tensor* out_tensor = NULL; + TensorShape out_shape; + out_shape.AddDim(p.depth); + out_shape.AddDim(p.height); + out_shape.AddDim(p.width); + out_shape.AddDim(p.numAttr); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out_tensor)); + p.out = out_tensor->flat().data(); + + // Allocate pixel differential output tensor. + Tensor* out_da_tensor = NULL; + out_shape.set_dim(3, p.numDiffAttr * 2); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, out_shape, &out_da_tensor)); + p.outDA = ENABLE_DA ? out_da_tensor->flat().data() : 0; + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + if (ENABLE_DA) + OP_REQUIRES(ctx, !((uintptr_t)p.outDA & 7), errors::Internal("out_da output tensor not aligned to float2")); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_FWD_MAX_KERNEL_BLOCK_WIDTH, IP_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DA ? (void*)InterpolateFwdKernelDa : (void*)InterpolateFwdKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("InterpolateFwd") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Output ("out: float") + .Output ("out_da: float"); + +REGISTER_OP("InterpolateFwdDa") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("rast_db: float") + .Output ("out: float") + .Output ("out_da: float") + .Attr ("diff_attrs_all: int") + .Attr ("diff_attrs: list(int)"); + +REGISTER_KERNEL_BUILDER(Name("InterpolateFwd") .Device(DEVICE_GPU), InterpolateFwdOp); +REGISTER_KERNEL_BUILDER(Name("InterpolateFwdDa").Device(DEVICE_GPU), InterpolateFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +template +struct InterpolateGradOp : public OpKernel +{ + InterpolateKernelParams m_attribs; + + InterpolateGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); + } + + void Compute(OpKernelContext* ctx) + { + InterpolateKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& attr = ctx->input(0); + const Tensor& rast = ctx->input(1); + const Tensor& tri = ctx->input(2); + const Tensor& dy = ctx->input(3); + const Tensor& rast_db = ctx->input(ENABLE_DA ? 4 : 3); + const Tensor& dda = ctx->input(ENABLE_DA ? 5 : 3); + + // Instance rendering mode? + p.instance_mode = attr.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + { + p.numVertices = (attr.dims() > 1) ? attr.dim_size(1) : 0; + p.numAttr = (attr.dims() > 2) ? attr.dim_size(2) : 0; + } + else + { + p.numVertices = (attr.dims() > 0) ? attr.dim_size(0) : 0; + p.numAttr = (attr.dims() > 1) ? attr.dim_size(1) : 0; + } + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.depth = (rast.dims() > 0) ? 
rast.dim_size(0) : 0; + p.height = (rast.dims() > 1) ? rast.dim_size(1) : 0; + p.width = (rast.dims() > 2) ? rast.dim_size(2) : 0; + int attr_depth = p.instance_mode ? (attr.dims() > 1 ? attr.dim_size(0) : 0) : 1; + + // Sanity checks. + OP_REQUIRES(ctx, rast.dims() == 4 && rast.dim_size(0) > 0 && rast.dim_size(1) > 0 && rast.dim_size(2) > 0 && rast.dim_size(3) == 4, errors::InvalidArgument("rast must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, (attr.dims() == 2 || attr.dims() == 3) && attr.dim_size(0) > 0 && attr.dim_size(1) > 0 && (attr.dims() == 2 || attr.dim_size(2) > 0), errors::InvalidArgument("attr must have shape [>0, >0, >0] or [>0, >0]")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) > 0 && dy.dim_size(1) == p.height && dy.dim_size(2) == p.width && dy.dim_size(3) > 0, errors::InvalidArgument("dy must have shape [>0, height, width, >0]")); + OP_REQUIRES(ctx, dy.dim_size(3) == p.numAttr, errors::InvalidArgument("argument count mismatch between inputs dy, attr")); + OP_REQUIRES(ctx, (attr_depth == p.depth || attr_depth == 1) && dy.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between inputs rast, dy, attr")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, dda.dims() == 4 && dda.dim_size(0) > 0 && dda.dim_size(1) == p.height && dda.dim_size(2) == p.width, errors::InvalidArgument("dda must have shape [>0, height, width, ?]")); + OP_REQUIRES(ctx, dda.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between rast, dda")); + } + + // All diff attrs mode. + if (p.diff_attrs_all) + p.numDiffAttr = p.numAttr; + + // Get input pointers. + p.attr = attr.flat().data(); + p.rast = rast.flat().data(); + p.tri = tri.flat().data(); + p.dy = dy.flat().data(); + p.rastDB = ENABLE_DA ? rast_db.flat().data() : 0; + p.dda = ENABLE_DA ? dda.flat().data() : 0; + p.attrBC = (p.instance_mode && attr_depth < p.depth) ? 1 : 0; + + // Allocate attribute gradient output tensor. + Tensor* grad_attr_tensor = NULL; + TensorShape grad_attr_shape; + if (p.instance_mode) + grad_attr_shape.AddDim(attr_depth); + grad_attr_shape.AddDim(p.numVertices); + grad_attr_shape.AddDim(p.numAttr); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_attr_shape, &grad_attr_tensor)); + p.gradAttr = grad_attr_tensor->flat().data(); + + // Allocate bary gradient output tensor. + Tensor* grad_rast_tensor = NULL; + TensorShape grad_rast_shape; + grad_rast_shape.AddDim(p.depth); + grad_rast_shape.AddDim(p.height); + grad_rast_shape.AddDim(p.width); + grad_rast_shape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, grad_rast_shape, &grad_rast_tensor)); + p.gradRaster = grad_rast_tensor->flat().data(); + + // Allocate bary pixel diff gradient output tensor. + if (ENABLE_DA) + { + Tensor* grad_rast_db_tensor = NULL; + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_rast_shape, &grad_rast_db_tensor)); + p.gradRasterDB = grad_rast_db_tensor->flat().data(); + } + + // Clear attribute gradients. + cudaMemsetAsync(p.gradAttr, 0, attr_depth * p.numVertices * p.numAttr * sizeof(float), stream); + + // Verify that buffers are aligned to allow float2/float4 operations. 
+ OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradRaster & 15), errors::Internal("grad_rast output tensor not aligned to float4")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, !((uintptr_t)p.dda & 7), errors::Internal("dda input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradRasterDB & 15), errors::Internal("grad_rast_db output tensor not aligned to float4")); + } + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DA ? (void*)InterpolateGradKernelDa : (void*)InterpolateGradKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("InterpolateGrad") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("dy: float") + .Output ("grad_attr: float") + .Output ("grad_rast: float") + ; + +REGISTER_OP("InterpolateGradDa") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("dy: float") + .Input ("rast_db: float") + .Input ("dda: float") + .Output ("grad_attr: float") + .Output ("grad_rast: float") + .Output ("grad_rast_db: float") + .Attr ("diff_attrs_all: int") + .Attr ("diff_attrs: list(int)"); + ; + +REGISTER_KERNEL_BUILDER(Name("InterpolateGrad") .Device(DEVICE_GPU), InterpolateGradOp); +REGISTER_KERNEL_BUILDER(Name("InterpolateGradDa").Device(DEVICE_GPU), InterpolateGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu new file mode 100644 index 0000000000000000000000000000000000000000..4d0a2616d3b74a4d0e76ccfefb6552d4a7f2a65f --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu @@ -0,0 +1,242 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct RasterizeFwdOp : public OpKernel +{ + RasterizeGLState m_glState; // OpenGL-related persistent state. + int m_tri_const; // 1 if triangle array is known to be constant. + + RasterizeFwdOp(OpKernelConstruction* ctx): + OpKernel(ctx) + { + memset(&m_glState, 0, sizeof(RasterizeGLState)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_db", &m_glState.enableDB)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tri_const", &m_tri_const)); + } + + void Compute(OpKernelContext* ctx) + { + cudaStream_t stream = ctx->eigen_device().stream(); + + // Check that input shapes are correct. 
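+ // [Editor's note] Two layouts are accepted: instance mode with pos of shape [N, V, 4] (ranges
+ // ignored), and range mode with pos [V, 4] plus int32 ranges [>0, 2] selecting a triangle
+ // range per minibatch entry; resolution is a host-side int32 tensor [height, width].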
+ const Tensor& pos = ctx->input(0); + const Tensor& tri = ctx->input(1); + const Tensor& resolution = ctx->input(2); + const Tensor& ranges = ctx->input(3); + + // Determine number of outputs + int num_outputs = m_glState.enableDB ? 2 : 1; + + // Determine instance mode and check input dimensions. + bool instance_mode = pos.dims() > 2; + if (instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("instance mode - pos must have shape [>0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, resolution.dims() == 1 && resolution.dim_size(0) == 2, errors::InvalidArgument("resolution must have shape [2]")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("range mode - pos must have shape [>0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, resolution.dims() == 1 && resolution.dim_size(0) == 2, errors::InvalidArgument("resolution must have shape [2]")); + OP_REQUIRES(ctx, ranges.dims() == 2 && ranges.dim_size(0) > 0 && ranges.dim_size(1) == 2, errors::InvalidArgument("range mode - ranges must have shape [>0, 2]")); + } + + // Get output shape. + const int32_t* res_in = resolution.flat().data(); // This is in CPU memory. + int height = res_in[0]; + int width = res_in[1]; + int depth = instance_mode ? pos.dim_size(0) : ranges.dim_size(0); + OP_REQUIRES(ctx, height > 0 && width > 0, errors::InvalidArgument("resolution must be [>0, >0]")); + + // Get position and triangle buffer sizes in int32/float32. + int posCount = 4 * pos.dim_size(0) * (instance_mode ? pos.dim_size(1) : 1); + int triCount = 3 * tri.dim_size(0); + + // Init context and GL? + bool initCtx = !m_glState.glFBO; + if (initCtx) + { + const DeviceBase::GpuDeviceInfo* g = ctx->device()->tensorflow_gpu_device_info(); + int cudaDeviceIdx = g ? g->gpu_id : -1; + rasterizeInitGLContext(ctx, m_glState, cudaDeviceIdx); // In common/rasterize.cpp + } + else + setGLContext(m_glState.glctx); // (Re-)Activate GL context. + + // Resize all buffers. + bool changes = false; + rasterizeResizeBuffers(ctx, m_glState, changes, posCount, triCount, width, height, depth); // In common/rasterize_gl.cpp + if (changes) + { +#ifdef _WIN32 + // Workaround for occasional blank first frame on Windows. + releaseGLContext(); + setGLContext(m_glState.glctx); +#endif + } + + // Copy input data to GL and render. + const float* posPtr = pos.flat().data(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.flat().data(); // This is in CPU memory. + const int32_t* triPtr = (initCtx || !m_tri_const) ? tri.flat().data() : NULL; // Copy triangles only if needed. + int vtxPerInstance = instance_mode ? pos.dim_size(1) : 0; + rasterizeRender(ctx, m_glState, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, -1); + + // Allocate output tensors. + TensorShape output_shape; + output_shape.AddDim(depth); + output_shape.AddDim(height); + output_shape.AddDim(width); + output_shape.AddDim(4); + float* outputPtr[2]; + for (int i=0; i < 2; i++) + { + if (i >= num_outputs) + output_shape.set_dim(3, 0); // Zero channels for unwanted out_db tensor. 
+ Tensor* output_tensor = NULL; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, output_shape, &output_tensor)); + if (i < num_outputs) + outputPtr[i] = output_tensor->flat().data(); + } + + // Copy rasterized results into CUDA buffers. + rasterizeCopyResults(ctx, m_glState, stream, outputPtr, width, height, depth); + + // Done. Release GL context. + releaseGLContext(); + } +}; + +REGISTER_OP("RasterizeFwd") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("resolution: int32") + .Input ("ranges: int32") + .Output ("out: float") + .Output ("out_db: float") + .Attr ("enable_db: int") + .Attr ("tri_const: int"); + +REGISTER_KERNEL_BUILDER(Name("RasterizeFwd").Device(DEVICE_GPU).HostMemory("resolution").HostMemory("ranges"), RasterizeFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +template +struct RasterizeGradOp : public OpKernel +{ + RasterizeGradParams m_attribs; + + RasterizeGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + } + + void Compute(OpKernelContext* ctx) + { + RasterizeGradParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Input tensors. + const Tensor& pos = ctx->input(0); + const Tensor& tri = ctx->input(1); + const Tensor& out = ctx->input(2); + const Tensor& dy = ctx->input(3); + const Tensor& ddb = ctx->input(ENABLE_DB ? 4 : 3); + + // Determine instance mode. + p.instance_mode = (pos.dims() > 2) ? 1 : 0; + + // Shape is taken from the rasterizer output tensor. + OP_REQUIRES(ctx, out.dims() == 4, errors::InvalidArgument("out must be rank-4")); + p.depth = out.dim_size(0); + p.height = out.dim_size(1); + p.width = out.dim_size(2); + OP_REQUIRES(ctx, p.depth > 0 && p.height > 0 && p.width > 0, errors::InvalidArgument("resolution must be [>0, >0, >0]")); + + // Check other shapes. + if (p.instance_mode) + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) == p.depth && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [depth, >0, 4]")); + else + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, out.dims() == 4 && out.dim_size(0) == p.depth && out.dim_size(1) == p.height && out.dim_size(2) == p.width && out.dim_size(3) == 4, errors::InvalidArgument("out must have shape [depth, height, width, 4]")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) == p.depth && dy.dim_size(1) == p.height && dy.dim_size(2) == p.width && dy.dim_size(3) == 4, errors::InvalidArgument("dy must have shape [depth, height, width, 4]")); + if (ENABLE_DB) + OP_REQUIRES(ctx, ddb.dims() == 4 && ddb.dim_size(0) == p.depth && ddb.dim_size(1) == p.height && ddb.dim_size(2) == p.width && ddb.dim_size(3) == 4, errors::InvalidArgument("ddb must have shape [depth, height, width, 4]")); + + // Populate parameters. + p.numTriangles = tri.dim_size(0); + p.numVertices = p.instance_mode ? pos.dim_size(1) : pos.dim_size(0); + p.pos = pos.flat().data(); + p.tri = tri.flat().data(); + p.out = out.flat().data(); + p.dy = dy.flat().data(); + p.ddb = ENABLE_DB ? ddb.flat().data() : 0; + + // Set up pixel position to clip space x, y transform. 
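+ // [Editor's note] With the scales and offsets set below, xs * px + xo = 2 * (px + 0.5) / width - 1,
+ // and analogously for y, i.e. integer pixel centers are mapped into the [-1, 1] clip-space range.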
+ p.xs = 2.f / (float)p.width; + p.xo = 1.f / (float)p.width - 1.f; + p.ys = 2.f / (float)p.height; + p.yo = 1.f / (float)p.height - 1.f; + + // Allocate output tensor for position gradients. + Tensor* grad_tensor = NULL; + TensorShape grad_shape; + if (p.instance_mode) + grad_shape.AddDim(p.depth); + grad_shape.AddDim(p.numVertices); + grad_shape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_shape, &grad_tensor)); + p.grad = grad_tensor->flat().data(); + + // Clear the output buffers. + size_t gradBytes = (p.instance_mode ? p.depth : 1) * p.numVertices * 4 * sizeof(float); + cudaMemsetAsync(p.grad, 0, gradBytes, stream); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 7), errors::Internal("dy input tensor not aligned to float2")); + if (ENABLE_DB) + OP_REQUIRES(ctx, !((uintptr_t)p.ddb & 15), errors::Internal("ddb input tensor not aligned to float4")); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH, RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DB ? (void*)RasterizeGradKernelDb : (void*)RasterizeGradKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("RasterizeGrad") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("out: float") + .Input ("dy: float") + .Output ("grad: float"); + +REGISTER_OP("RasterizeGradDb") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("out: float") + .Input ("dy: float") + .Input ("ddb: float") + .Output ("grad: float"); + +REGISTER_KERNEL_BUILDER(Name("RasterizeGrad") .Device(DEVICE_GPU), RasterizeGradOp); +REGISTER_KERNEL_BUILDER(Name("RasterizeGradDb").Device(DEVICE_GPU), RasterizeGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu new file mode 100644 index 0000000000000000000000000000000000000000..c5382fed28236da09d20a04c0524a937383daf5a --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu @@ -0,0 +1,525 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common op attribute parser. + +static __host__ void parseOpAttributes(OpKernelConstruction* ctx, TextureKernelParams& p) +{ + // Mip and filter modes. + OP_REQUIRES_OK(ctx, ctx->GetAttr("filter_mode", &p.filterMode)); + OP_REQUIRES(ctx, p.filterMode >= 0 && p.filterMode < TEX_MODE_COUNT, errors::InvalidArgument("filter_mode unsupported")); + p.enableMip = (p.filterMode == TEX_MODE_LINEAR_MIPMAP_NEAREST || p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR); + + // Mip level clamp. 
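+ // [Editor's note] max_mip_level is only read when a mipmapped filter mode is active; -1 leaves
+ // the mip pyramid unclamped, and tex_const is fetched without OP_REQUIRES_OK because the
+ // attribute exists only on the forward op.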
+ if (p.enableMip) + { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_mip_level", &p.mipLevelLimit)); + OP_REQUIRES(ctx, p.mipLevelLimit >= -1, errors::InvalidArgument("invalid max_mip_level")); + ctx->GetAttr("tex_const", &p.texConst); // Only available in forward op. + } + + // Boundary mode. + OP_REQUIRES_OK(ctx, ctx->GetAttr("boundary_mode", &p.boundaryMode)); + OP_REQUIRES(ctx, p.boundaryMode >= 0 && p.boundaryMode < TEX_BOUNDARY_MODE_COUNT, errors::InvalidArgument("boundary_mode unsupported")); +} + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct TextureFwdOp : public OpKernel +{ + TextureKernelParams m_attribs; + PersistentTensor m_persistentMipTensor; // Used if texture is constant and mips are enabled. + bool m_persistentMipTensorInitialized; + + TextureFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + m_persistentMipTensorInitialized = false; + parseOpAttributes(ctx, m_attribs); + } + + void Compute(OpKernelContext* ctx) + { + TextureKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + bool cube_mode = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE); + + // Get input. + const Tensor& tex = ctx->input(0); + const Tensor& uv = ctx->input(1); + const Tensor& uv_da = ctx->input(p.enableMip ? 2 : 1); + + // Extract input dimensions. + p.n = (uv.dims() > 0) ? uv.dim_size(0) : 0; + p.imgHeight = (uv.dims() > 1) ? uv.dim_size(1) : 0; + p.imgWidth = (uv.dims() > 2) ? uv.dim_size(2) : 0; + p.texDepth = (tex.dims() > 0) ? tex.dim_size(0) : 0; + if (!cube_mode) + { + p.texHeight = (tex.dims() > 1) ? tex.dim_size(1) : 0; + p.texWidth = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.channels = (tex.dims() > 3) ? tex.dim_size(3) : 0; + } + else + { + p.texHeight = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.texWidth = (tex.dims() > 3) ? tex.dim_size(3) : 0; + p.channels = (tex.dims() > 4) ? tex.dim_size(4) : 0; + } + + // Sanity checks. 
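+ // [Editor's note] Cube-map boundary mode expects tex of shape [N, 6, H, W, C] with square faces
+ // and uv holding 3-D direction vectors [N, H, W, 3]; all other boundary modes use tex
+ // [N, H, W, C] and uv [N, H, W, 2].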
+ if (!cube_mode) + { + OP_REQUIRES(ctx, tex.dims() == 4 && tex.dim_size(0) > 0 && tex.dim_size(1) > 0 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0, errors::InvalidArgument("tex must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 2, errors::InvalidArgument("uv must have shape [>0, >0, >0, 2]")); + } + else + { + OP_REQUIRES(ctx, tex.dims() == 5 && tex.dim_size(0) > 0 && tex.dim_size(1) == 6 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0 && tex.dim_size(4) > 0, errors::InvalidArgument("tex must have shape[>0, 6, >0, >0, >0] in cube map mode")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 3, errors::InvalidArgument("uv must have shape [>0, >0, >0, 3] in cube map mode")); + OP_REQUIRES(ctx, tex.dim_size(2) == tex.dim_size(3), errors::InvalidArgument("texture shape must be square in cube map mode")); + } + OP_REQUIRES(ctx, tex.dim_size(0) == 1 || tex.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs tex, uv")); + OP_REQUIRES(ctx, p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), errors::InvalidArgument("texture size too large")); + if (p.enableMip) + { + if (!cube_mode) + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 4, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 4]")); + else + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 6, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 6] in cube map mode")); + } + + // Get input pointers. + p.tex[0] = tex.flat().data(); + p.uv = uv.flat().data(); + p.uvDA = p.enableMip ? uv_da.flat().data() : 0; + + // Allocate output tensor. + Tensor* out_tensor = NULL; + TensorShape out_shape; + out_shape.AddDim(p.n); + out_shape.AddDim(p.imgHeight); + out_shape.AddDim(p.imgWidth); + out_shape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out_tensor)); + p.out = out_tensor->flat().data(); + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + float* pmip = 0; + if (p.enableMip) + { + // Generate mip offsets. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(ctx, p, mipOffsets); + + // Mip output tensor. + Tensor* mip_tensor = NULL; + TensorShape mip_shape; + mip_shape.AddDim(mipTotal); + + // If texture is constant, calculate mip stack only once. + bool computeMip = true; + if (p.texConst) + { + // First execution? + if (!m_persistentMipTensorInitialized) + { + // Allocate a persistent mip tensor. + OP_REQUIRES_OK(ctx, ctx->allocate_persistent(DT_FLOAT, mip_shape, &m_persistentMipTensor, &mip_tensor)); + m_persistentMipTensorInitialized = true; + } + else + { + // Reuse the persistent tensor, do not recompute mip levels. + mip_tensor = m_persistentMipTensor.AccessTensor(ctx); + computeMip = false; + } + + // Set as output tensor as well. + ctx->set_output(1, *mip_tensor); + } + else + { + // Allocate an output tensor as usual. 
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(1, mip_shape, &mip_tensor)); + } + + pmip = mip_tensor->flat().data(); // Pointer to data. + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + + // Build mip levels if needed. + if (computeMip) + { + for (int i=1; i <= p.mipLevelMax; i++) + { + int2 ms = mipLevelSize(p, i); + int3 sz = make_int3(ms.x, ms.y, p.texDepth); + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT, sz.x, sz.y); + dim3 gridSize = getLaunchGridSize(blockSize, sz.x, sz.y, sz.z * (cube_mode ? 6 : 1)); + p.mipLevelOut = i; + + void* build_func_tbl[3] = { (void*)MipBuildKernel1, (void*)MipBuildKernel2, (void*)MipBuildKernel4 }; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(build_func_tbl[channel_div_idx], gridSize, blockSize, args, 0, stream)); + } + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + OP_REQUIRES(ctx, !((uintptr_t)p.uv & 7), errors::Internal("uv input tensor not aligned to float2")); + if ((p.channels & 3) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.out & 15), errors::Internal("out output tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip output tensor not aligned to float4")); + } + if ((p.channels & 1) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.out & 7), errors::Internal("out output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip output tensor not aligned to float2")); + } + if (!cube_mode) + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 15), errors::Internal("uv_da input tensor not aligned to float4")); + else + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 7), errors::Internal("uv_da input tensor not aligned to float2")); + + // Choose launch parameters for texture lookup kernel. + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + // Choose kernel based on filter mode, cube mode, and datatype. + void* func_tbl[TEX_MODE_COUNT * 3 * 2] = { + (void*)TextureFwdKernelNearest1, + (void*)TextureFwdKernelNearest2, + (void*)TextureFwdKernelNearest4, + (void*)TextureFwdKernelLinear1, + (void*)TextureFwdKernelLinear2, + (void*)TextureFwdKernelLinear4, + (void*)TextureFwdKernelLinearMipmapNearest1, + (void*)TextureFwdKernelLinearMipmapNearest2, + (void*)TextureFwdKernelLinearMipmapNearest4, + (void*)TextureFwdKernelLinearMipmapLinear1, + (void*)TextureFwdKernelLinearMipmapLinear2, + (void*)TextureFwdKernelLinearMipmapLinear4, + (void*)TextureFwdKernelCubeNearest1, + (void*)TextureFwdKernelCubeNearest2, + (void*)TextureFwdKernelCubeNearest4, + (void*)TextureFwdKernelCubeLinear1, + (void*)TextureFwdKernelCubeLinear2, + (void*)TextureFwdKernelCubeLinear4, + (void*)TextureFwdKernelCubeLinearMipmapNearest1, + (void*)TextureFwdKernelCubeLinearMipmapNearest2, + (void*)TextureFwdKernelCubeLinearMipmapNearest4, + (void*)TextureFwdKernelCubeLinearMipmapLinear1, + (void*)TextureFwdKernelCubeLinearMipmapLinear2, + (void*)TextureFwdKernelCubeLinearMipmapLinear4, + }; + + // Function index. 
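+ // [Editor's note] The table above is indexed as (filter_mode + cube_offset) * 3 + channel_div_idx:
+ // four filter modes x {2D, cube-map} variants, each in 1-, 2- and 4-channel vectorized flavors.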
+ int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; + func_idx = func_idx * 3 + channel_div_idx; + + // Launch kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("TextureFwd") + .Input ("tex: float") + .Input ("uv: float") + .Output ("out: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureFwdMip") + .Input ("tex: float") + .Input ("uv: float") + .Input ("uv_da: float") + .Output ("out: float") + .Output ("mip: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("tex_const: int") + .Attr ("max_mip_level: int"); + +REGISTER_KERNEL_BUILDER(Name("TextureFwd") .Device(DEVICE_GPU), TextureFwdOp); +REGISTER_KERNEL_BUILDER(Name("TextureFwdMip").Device(DEVICE_GPU), TextureFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +struct TextureGradOp : public OpKernel +{ + TextureKernelParams m_attribs; + + TextureGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + parseOpAttributes(ctx, m_attribs); + } + + void Compute(OpKernelContext* ctx) + { + TextureKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + bool cube_mode = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE); + + // Get input. + const Tensor& tex = ctx->input(0); + const Tensor& uv = ctx->input(1); + const Tensor& dy = ctx->input(2); + const Tensor& uv_da = ctx->input(p.enableMip ? 3 : 2); + const Tensor& mip = ctx->input(p.enableMip ? 4 : 2); + + // Extract input dimensions. + p.n = (uv.dims() > 0) ? uv.dim_size(0) : 0; + p.imgHeight = (uv.dims() > 1) ? uv.dim_size(1) : 0; + p.imgWidth = (uv.dims() > 2) ? uv.dim_size(2) : 0; + p.texDepth = (tex.dims() > 0) ? tex.dim_size(0) : 0; + if (!cube_mode) + { + p.texHeight = (tex.dims() > 1) ? tex.dim_size(1) : 0; + p.texWidth = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.channels = (tex.dims() > 3) ? tex.dim_size(3) : 0; + } + else + { + p.texHeight = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.texWidth = (tex.dims() > 3) ? tex.dim_size(3) : 0; + p.channels = (tex.dims() > 4) ? tex.dim_size(4) : 0; + } + + // Sanity checks. 
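+ // [Editor's note] The same shape rules as in TextureFwd apply here, plus dy must match the
+ // forward output shape [minibatch, height, width, channels].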
+ if (!cube_mode) + { + OP_REQUIRES(ctx, tex.dims() == 4 && tex.dim_size(0) > 0 && tex.dim_size(1) > 0 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0, errors::InvalidArgument("tex must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 2, errors::InvalidArgument("uv must have shape [>0, >0, >0, 2]")); + } + else + { + OP_REQUIRES(ctx, tex.dims() == 5 && tex.dim_size(0) > 0 && tex.dim_size(1) == 6 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0 && tex.dim_size(4) > 0, errors::InvalidArgument("tex must have shape[>0, 6, >0, >0, >0] in cube map mode")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 3, errors::InvalidArgument("uv must have shape [>0, >0, >0, 3] in cube map mode")); + OP_REQUIRES(ctx, tex.dim_size(2) == tex.dim_size(3), errors::InvalidArgument("texture shape must be square in cube map mode")); + } + OP_REQUIRES(ctx, tex.dim_size(0) == 1 || tex.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs tex, uv")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) == p.n && dy.dim_size(1) == p.imgHeight && dy.dim_size(2) == p.imgWidth && dy.dim_size(3) == p.channels, errors::InvalidArgument("dy must have shape [minibatch_size, height, width, channels]")); + if (p.enableMip) + { + if (!cube_mode) + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 4, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 4]")); + else + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 6, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 6] in cube map mode")); + } + + // Get input pointers. + p.tex[0] = tex.flat().data(); + p.uv = uv.flat().data(); + p.dy = dy.flat().data(); + p.uvDA = p.enableMip ? uv_da.flat().data() : 0; + float* pmip = p.enableMip ? (float*)mip.flat().data() : 0; + + // Allocate output tensor for tex gradient. + Tensor* grad_tex_tensor = NULL; + TensorShape grad_tex_shape; + grad_tex_shape.AddDim(p.texDepth); + if (cube_mode) + grad_tex_shape.AddDim(6); + grad_tex_shape.AddDim(p.texHeight); + grad_tex_shape.AddDim(p.texWidth); + grad_tex_shape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_tex_shape, &grad_tex_tensor)); + p.gradTex[0] = grad_tex_tensor->flat().data(); + + // Allocate output tensor for uv gradient. + if (p.filterMode != TEX_MODE_NEAREST) + { + TensorShape grad_uv_shape; + Tensor* grad_uv_tensor = NULL; + grad_uv_shape.AddDim(p.n); + grad_uv_shape.AddDim(p.imgHeight); + grad_uv_shape.AddDim(p.imgWidth); + grad_uv_shape.AddDim(uv.dim_size(3)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, grad_uv_shape, &grad_uv_tensor)); + p.gradUV = grad_uv_tensor->flat().data(); + + // Allocate output tensor for uv_da gradient. + if (p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + Tensor* grad_uv_da_tensor = NULL; + grad_uv_shape.set_dim(3, uv_da.dim_size(3)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_uv_shape, &grad_uv_da_tensor)); + p.gradUVDA = grad_uv_da_tensor->flat().data(); + } + } + + // Choose kernel variants based on channel count. + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. 
+ else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + Tensor grad_mip_tensor; + float* pgradMip = 0; + if (p.enableMip) + { + // Generate mip offsets. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(ctx, p, mipOffsets); + + // Get space for temporary mip gradients. + TensorShape grad_mip_shape; + grad_mip_shape.AddDim(mipTotal); + ctx->allocate_temp(DT_FLOAT, grad_mip_shape, &grad_mip_tensor); + pgradMip = grad_mip_tensor.flat().data(); + for (int i=1; i <= p.mipLevelMax; i++) + { + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients. + } + + // Clear mip gradients. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(pgradMip, 0, mipTotal * sizeof(float), stream)); + } + + // Initialize texture gradients to zero. + int texBytes = p.texHeight * p.texWidth * p.texDepth * p.channels * sizeof(float); + if (cube_mode) + texBytes *= 6; + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradTex[0], 0, texBytes, stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + { + OP_REQUIRES(ctx, !((uintptr_t)p.uv & 7), errors::Internal("uv input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUV & 7), errors::Internal("grad_uv output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 15), errors::Internal("uv_da input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUVDA & 15), errors::Internal("grad_uv_da output tensor not aligned to float4")); + } + else + { + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 7), errors::Internal("uv_da input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUVDA & 7), errors::Internal("grad_uv_da output tensor not aligned to float2")); + } + if ((p.channels & 3) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 15), errors::Internal("grad_tex output tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 15), errors::Internal("dy input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 15), errors::Internal("internal mip gradient tensor not aligned to float4")); + } + if ((p.channels & 1) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 7), errors::Internal("grad_tex output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 7), errors::Internal("dy output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 7), errors::Internal("internal mip gradient tensor not aligned to float2")); + } + + // Choose launch parameters for main gradient kernel. 
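+        // Block size is capped at TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH/HEIGHT and the grid covers
+        // the full image with one layer per minibatch item, mirroring the forward lookup launch.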
+ void* args[] = {&p}; + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + void* func_tbl[TEX_MODE_COUNT * 2] = { + (void*)TextureGradKernelNearest, + (void*)TextureGradKernelLinear, + (void*)TextureGradKernelLinearMipmapNearest, + (void*)TextureGradKernelLinearMipmapLinear, + (void*)TextureGradKernelCubeNearest, + (void*)TextureGradKernelCubeLinear, + (void*)TextureGradKernelCubeLinearMipmapNearest, + (void*)TextureGradKernelCubeLinearMipmapLinear, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; + + // Launch main gradient kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Launch kernel to pull gradients from mip levels. + if (p.enableMip) + { + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT, p.texWidth, p.texHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.texWidth, p.texHeight, p.texDepth * (cube_mode ? 6 : 1)); + int sharedBytes = blockSize.x * blockSize.y * p.channels * sizeof(float); + + void* mip_grad_func_tbl[3] = { (void*)MipGradKernel1, (void*)MipGradKernel2, (void*)MipGradKernel4 }; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(mip_grad_func_tbl[channel_div_idx], gridSize, blockSize, args, sharedBytes, stream)); + } + } +}; + +REGISTER_OP("TextureGradNearest") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Output ("grad_tex: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureGradLinear") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureGradLinearMipmapNearest") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Input ("uv_da: float") + .Input ("mip: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("max_mip_level: int"); + +REGISTER_OP("TextureGradLinearMipmapLinear") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Input ("uv_da: float") + .Input ("mip: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Output ("grad_uv_da: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("max_mip_level: int"); + +REGISTER_KERNEL_BUILDER(Name("TextureGradNearest") .Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinear") .Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapNearest").Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapLinear") .Device(DEVICE_GPU), TextureGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/__init__.py b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d28f95e7a9e423b5efb322c39e343a069caf0fe8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .ops import RasterizeCudaContext, RasterizeGLContext, get_log_level, set_log_level, rasterize, DepthPeeler, interpolate, texture, texture_construct_mip, antialias, antialias_construct_topology_hash +__all__ = ["RasterizeCudaContext", "RasterizeGLContext", "get_log_level", "set_log_level", "rasterize", "DepthPeeler", "interpolate", "texture", "texture_construct_mip", "antialias", "antialias_construct_topology_hash"] diff --git a/extensions/nvdiffrast/nvdiffrast/torch/ops.py b/extensions/nvdiffrast/nvdiffrast/torch/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..edf8540fda5aed6736a72b44b993031157a9cf4b --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/ops.py @@ -0,0 +1,734 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import importlib +import logging +import numpy as np +import os +import torch +import torch.utils.cpp_extension +from . import _C + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = {} +def _get_plugin(gl=False): + assert isinstance(gl, bool) + + # Modified with precompiled torch CUDA extension + if not gl: + return _C + + # Return cached plugin if already loaded. + if _cached_plugin.get(gl, None) is not None: + return _cached_plugin[gl] + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): + import glob + def get_sort_key(x): + # Primary criterion is VS version, secondary is edition, third is internal MSVC version. + x = x.split('\\')[3:] + x[1] = {'BuildTools': '~0', 'Community': '~1', 'Pro': '~2', 'Professional': '~3', 'Enterprise': '~4'}.get(x[1], x[1]) + return x + vs_relative_path = r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64" + paths = glob.glob(r"C:\Program Files" + vs_relative_path) + paths += glob.glob(r"C:\Program Files (x86)" + vs_relative_path) + if paths: + return sorted(paths, key=get_sort_key)[-1] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + common_opts = ['-DNVDR_TORCH'] + cc_opts = [] + if os.name == 'nt': + cc_opts += ['/wd4067', '/wd4624'] # Disable warnings in torch headers. + + # Linker options for the GL-interfacing plugin. + ldflags = [] + if gl: + if os.name == 'posix': + ldflags = ['-lGL', '-lEGL'] + elif os.name == 'nt': + libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] + ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + + # List of source files. 
+ if gl: + source_files = [ + '../common/common.cpp', + '../common/glutil.cpp', + '../common/rasterize_gl.cpp', + 'torch_bindings_gl.cpp', + 'torch_rasterize_gl.cpp', + ] + else: + source_files = [ + '../common/cudaraster/impl/Buffer.cpp', + '../common/cudaraster/impl/CudaRaster.cpp', + '../common/cudaraster/impl/RasterImpl.cu', + '../common/cudaraster/impl/RasterImpl.cpp', + '../common/common.cpp', + '../common/rasterize.cu', + '../common/interpolate.cu', + '../common/texture.cu', + '../common/texture.cpp', + '../common/antialias.cu', + 'torch_bindings.cpp', + 'torch_rasterize.cpp', + 'torch_interpolate.cpp', + 'torch_texture.cpp', + 'torch_antialias.cpp', + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. + if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): + logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + + # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. + plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '') + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + if os.path.exists(lock_fn): + logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Speed up compilation on Windows. + if os.name == 'nt': + # Skip telemetry sending step in vcvarsall.bat + os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + + # Opportunistically patch distutils to cache MSVC environments. + try: + import distutils._msvccompiler + import functools + if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): + distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) + except: + pass + + # Compile and load. + source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=common_opts+cc_opts, extra_cuda_cflags=common_opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + + # Import, cache, and return the compiled module. + _cached_plugin[gl] = importlib.import_module(plugin_name) + return _cached_plugin[gl] + +#---------------------------------------------------------------------------- +# Log level. +#---------------------------------------------------------------------------- + +def get_log_level(): + '''Get current log level. + + Returns: + Current log level in nvdiffrast. See `set_log_level()` for possible values. + ''' + return _get_plugin().get_log_level() + +def set_log_level(level): + '''Set log level. + + Log levels follow the convention on the C++ side of Torch: + 0 = Info, + 1 = Warning, + 2 = Error, + 3 = Fatal. + The default log level is 1. + + Args: + level: New log level as integer. Internal nvdiffrast messages of this + severity or higher will be printed, while messages of lower + severity will be silent. + ''' + _get_plugin().set_log_level(level) + +#---------------------------------------------------------------------------- +# CudaRaster state wrapper. 
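+#
+# Usage sketch (device selection is optional and may be a torch.device, string, or int):
+#
+#   ctx = RasterizeCudaContext(device='cuda:0')
+#   # `ctx` is then passed as `glctx` to rasterize() or DepthPeeler further below.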
+#---------------------------------------------------------------------------- + +class RasterizeCudaContext: + def __init__(self, device=None): + '''Create a new Cuda rasterizer context. + + The context is deleted and internal storage is released when the object is + destroyed. + + Args: + device (Optional): Cuda device on which the context is created. Type can be + `torch.device`, string (e.g., `'cuda:1'`), or int. If not + specified, context will be created on currently active Cuda + device. + Returns: + The newly created Cuda rasterizer context. + ''' + if device is None: + cuda_device_idx = torch.cuda.current_device() + else: + with torch.cuda.device(device): + cuda_device_idx = torch.cuda.current_device() + self.cpp_wrapper = _get_plugin().RasterizeCRStateWrapper(cuda_device_idx) + self.output_db = True + self.active_depth_peeler = None + +#---------------------------------------------------------------------------- +# GL state wrapper. +#---------------------------------------------------------------------------- + +class RasterizeGLContext: + def __init__(self, output_db=True, mode='automatic', device=None): + '''Create a new OpenGL rasterizer context. + + Creating an OpenGL context is a slow operation so you should usually reuse the same + context in all calls to `rasterize()` on the same CPU thread. The OpenGL context + is deleted when the object is destroyed. + + Side note: When using the OpenGL context in a rasterization operation, the + context's internal framebuffer object is automatically enlarged to accommodate the + rasterization operation's output shape, but it is never shrunk in size until the + context is destroyed. Thus, if you need to rasterize, say, deep low-resolution + tensors and also shallow high-resolution tensors, you can conserve GPU memory by + creating two separate OpenGL contexts for these tasks. In this scenario, using the + same OpenGL context for both tasks would end up reserving GPU memory for a deep, + high-resolution output tensor. + + Args: + output_db (bool): Compute and output image-space derivates of barycentrics. + mode: OpenGL context handling mode. Valid values are 'manual' and 'automatic'. + device (Optional): Cuda device on which the context is created. Type can be + `torch.device`, string (e.g., `'cuda:1'`), or int. If not + specified, context will be created on currently active Cuda + device. + Returns: + The newly created OpenGL rasterizer context. + ''' + assert output_db is True or output_db is False + assert mode in ['automatic', 'manual'] + self.output_db = output_db + self.mode = mode + if device is None: + cuda_device_idx = torch.cuda.current_device() + else: + with torch.cuda.device(device): + cuda_device_idx = torch.cuda.current_device() + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) + self.active_depth_peeler = None # For error checking only. + + def set_context(self): + '''Set (activate) OpenGL context in the current CPU thread. + Only available if context was created in manual mode. + ''' + assert self.mode == 'manual' + self.cpp_wrapper.set_context() + + def release_context(self): + '''Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. + ''' + assert self.mode == 'manual' + self.cpp_wrapper.release_context() + +#---------------------------------------------------------------------------- +# Rasterize. 
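+#
+# Usage sketch (hypothetical tensors, instanced mode; rasterize() is defined below):
+#
+#   glctx = RasterizeCudaContext()
+#   # pos: float32 [minibatch, num_vertices, 4] vertex positions, tri: int32 [num_triangles, 3]
+#   rast, rast_db = rasterize(glctx, pos, tri, resolution=(512, 512))
+#   # rast is [minibatch, 512, 512, 4] holding (u, v, z/w, triangle_id) per pixel.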
+#----------------------------------------------------------------------------
+
+class _rasterize_func(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, raster_ctx, pos, tri, resolution, ranges, grad_db, peeling_idx):
+        if isinstance(raster_ctx, RasterizeGLContext):
+            out, out_db = _get_plugin(gl=True).rasterize_fwd_gl(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx)
+        else:
+            out, out_db = _get_plugin().rasterize_fwd_cuda(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx)
+        ctx.save_for_backward(pos, tri, out)
+        ctx.saved_grad_db = grad_db
+        return out, out_db
+
+    @staticmethod
+    def backward(ctx, dy, ddb):
+        pos, tri, out = ctx.saved_tensors
+        if ctx.saved_grad_db:
+            g_pos = _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb)
+        else:
+            g_pos = _get_plugin().rasterize_grad(pos, tri, out, dy)
+        return None, g_pos, None, None, None, None, None
+
+# Op wrapper.
+def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
+    '''Rasterize triangles.
+
+    All input tensors must be contiguous and reside in GPU memory except for
+    the `ranges` tensor that, if specified, has to reside in CPU memory. The
+    output tensors will be contiguous and reside in GPU memory.
+
+    Args:
+        glctx: Rasterizer context of type `RasterizeGLContext` or `RasterizeCudaContext`.
+        pos: Vertex position tensor with dtype `torch.float32`. To enable range
+             mode, this tensor should have a 2D shape [num_vertices, 4]. To enable
+             instanced mode, use a 3D shape [minibatch_size, num_vertices, 4].
+        tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`.
+        resolution: Output resolution as integer tuple (height, width).
+        ranges: In range mode, tensor with shape [minibatch_size, 2] and dtype
+                `torch.int32`, specifying start indices and counts into `tri`.
+                Ignored in instanced mode.
+        grad_db: Propagate gradients of image-space derivatives of barycentrics
+                 into `pos` in backward pass. Ignored if using an OpenGL context that
+                 was not configured to output image-space derivatives.
+
+    Returns:
+        A tuple of two tensors. The first output tensor has shape [minibatch_size,
+        height, width, 4] and contains the main rasterizer output in order (u, v, z/w,
+        triangle_id). If the OpenGL context was configured to output image-space
+        derivatives of barycentrics, the second output tensor will also have shape
+        [minibatch_size, height, width, 4] and contain said derivatives in order
+        (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape
+        [minibatch_size, height, width, 0].
+    '''
+    assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext))
+    assert grad_db is True or grad_db is False
+    grad_db = grad_db and glctx.output_db
+
+    # Sanitize inputs.
+    assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor)
+    resolution = tuple(resolution)
+    if ranges is None:
+        ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu')
+    else:
+        assert isinstance(ranges, torch.Tensor)
+
+    # Check that context is not currently reserved for depth peeling.
+    if glctx.active_depth_peeler is not None:
+        raise RuntimeError("Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead")
+
+    # Instantiate the function.
+    return _rasterize_func.apply(glctx, pos, tri, resolution, ranges, grad_db, -1)
+
+#----------------------------------------------------------------------------
+# Depth peeler context manager for rasterizing multiple depth layers.
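+#
+# Usage sketch (arguments mirror rasterize(); `num_layers` is hypothetical):
+#
+#   with DepthPeeler(glctx, pos, tri, (512, 512)) as peeler:
+#       for _ in range(num_layers):
+#           rast, rast_db = peeler.rasterize_next_layer()
+#           ...  # shade/accumulate this depth layer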
+#---------------------------------------------------------------------------- + +class DepthPeeler: + def __init__(self, glctx, pos, tri, resolution, ranges=None, grad_db=True): + '''Create a depth peeler object for rasterizing multiple depth layers. + + Arguments are the same as in `rasterize()`. + + Returns: + The newly created depth peeler. + ''' + assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext)) + assert grad_db is True or grad_db is False + grad_db = grad_db and glctx.output_db + + # Sanitize inputs as usual. + assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor) + resolution = tuple(resolution) + if ranges is None: + ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu') + else: + assert isinstance(ranges, torch.Tensor) + + # Store all the parameters. + self.raster_ctx = glctx + self.pos = pos + self.tri = tri + self.resolution = resolution + self.ranges = ranges + self.grad_db = grad_db + self.peeling_idx = None + + def __enter__(self): + if self.raster_ctx is None: + raise RuntimeError("Cannot re-enter a terminated depth peeling operation") + if self.raster_ctx.active_depth_peeler is not None: + raise RuntimeError("Cannot have multiple depth peelers active simultaneously in a rasterization context") + self.raster_ctx.active_depth_peeler = self + self.peeling_idx = 0 + return self + + def __exit__(self, *args): + assert self.raster_ctx.active_depth_peeler is self + self.raster_ctx.active_depth_peeler = None + self.raster_ctx = None # Remove all references to input tensor so they're not left dangling. + self.pos = None + self.tri = None + self.resolution = None + self.ranges = None + self.grad_db = None + self.peeling_idx = None + return None + + def rasterize_next_layer(self): + '''Rasterize next depth layer. + + Operation is equivalent to `rasterize()` except that previously reported + surface points are culled away. + + Returns: + A tuple of two tensors as in `rasterize()`. + ''' + assert self.raster_ctx.active_depth_peeler is self + assert self.peeling_idx >= 0 + result = _rasterize_func.apply(self.raster_ctx, self.pos, self.tri, self.resolution, self.ranges, self.grad_db, self.peeling_idx) + self.peeling_idx += 1 + return result + +#---------------------------------------------------------------------------- +# Interpolate. +#---------------------------------------------------------------------------- + +# Output pixel differentials for at least some attributes. +class _interpolate_func_da(torch.autograd.Function): + @staticmethod + def forward(ctx, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list): + out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + ctx.save_for_backward(attr, rast, tri, rast_db) + ctx.saved_misc = diff_attrs_all, diff_attrs_list + return out, out_da + + @staticmethod + def backward(ctx, dy, dda): + attr, rast, tri, rast_db = ctx.saved_tensors + diff_attrs_all, diff_attrs_list = ctx.saved_misc + g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + return g_attr, g_rast, None, g_rast_db, None, None + +# No pixel differential for any attribute. 
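+# (This is the path taken by interpolate() below when `diff_attrs` is not given:
+#  e.g. out, _ = interpolate(attr, rast, tri) returns the interpolated attributes
+#  and an empty derivative tensor.)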
+class _interpolate_func(torch.autograd.Function): + @staticmethod + def forward(ctx, attr, rast, tri): + out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri) + ctx.save_for_backward(attr, rast, tri) + return out, out_da + + @staticmethod + def backward(ctx, dy, _): + attr, rast, tri = ctx.saved_tensors + g_attr, g_rast = _get_plugin().interpolate_grad(attr, rast, tri, dy) + return g_attr, g_rast, None + +# Op wrapper. +def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): + """Interpolate vertex attributes. + + All input tensors must be contiguous and reside in GPU memory. The output tensors + will be contiguous and reside in GPU memory. + + Args: + attr: Attribute tensor with dtype `torch.float32`. + Shape is [num_vertices, num_attributes] in range mode, or + [minibatch_size, num_vertices, num_attributes] in instanced mode. + Broadcasting is supported along the minibatch axis. + rast: Main output tensor from `rasterize()`. + tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`. + rast_db: (Optional) Tensor containing image-space derivatives of barycentrics, + i.e., the second output tensor from `rasterize()`. Enables computing + image-space derivatives of attributes. + diff_attrs: (Optional) List of attribute indices for which image-space + derivatives are to be computed. Special value 'all' is equivalent + to list [0, 1, ..., num_attributes - 1]. + + Returns: + A tuple of two tensors. The first output tensor contains interpolated + attributes and has shape [minibatch_size, height, width, num_attributes]. + If `rast_db` and `diff_attrs` were specified, the second output tensor contains + the image-space derivatives of the selected attributes and has shape + [minibatch_size, height, width, 2 * len(diff_attrs)]. The derivatives of the + first selected attribute A will be on channels 0 and 1 as (dA/dX, dA/dY), etc. + Otherwise, the second output tensor will be an empty tensor with shape + [minibatch_size, height, width, 0]. + """ + # Sanitize the list of pixel differential attributes. + if diff_attrs is None: + diff_attrs = [] + elif diff_attrs != 'all': + diff_attrs = np.asarray(diff_attrs, np.int32) + assert len(diff_attrs.shape) == 1 + diff_attrs = diff_attrs.tolist() + + diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_list = [] if diff_attrs_all else diff_attrs + + # Check inputs. + assert all(isinstance(x, torch.Tensor) for x in (attr, rast, tri)) + if diff_attrs: + assert isinstance(rast_db, torch.Tensor) + + # Choose stub. + if diff_attrs: + return _interpolate_func_da.apply(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + else: + return _interpolate_func.apply(attr, rast, tri) + +#---------------------------------------------------------------------------- +# Texture +#---------------------------------------------------------------------------- + +# Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled. 
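+# (The tensors of a custom mip stack are forwarded as *mip_stack varargs so that
+#  autograd can hand back a matching tuple of per-level gradients in backward();
+#  texture() below decides whether mip_wrapper or mip_stack is populated.)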
+class _texture_func_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack): + empty = torch.tensor([]) + if uv_da is None: + uv_da = empty + if mip_level_bias is None: + mip_level_bias = empty + if mip_wrapper is None: + mip_wrapper = _get_plugin().TextureMipWrapper() + out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack) + ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum + return out + + @staticmethod + def backward(ctx, dy): + tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_tensors + filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc + if filter_mode == 'linear-mipmap-linear': + g_tex, g_uv, g_uv_da, g_mip_level_bias, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + return (None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None) + tuple(g_mip_stack) + else: # linear-mipmap-nearest + g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + return (None, g_tex, g_uv, None, None, None, None, None) + tuple(g_mip_stack) + +# Linear and nearest: Mipmaps disabled. +class _texture_func(torch.autograd.Function): + @staticmethod + def forward(ctx, filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + ctx.save_for_backward(tex, uv) + ctx.saved_misc = filter_mode, filter_mode_enum, boundary_mode_enum + return out + + @staticmethod + def backward(ctx, dy): + tex, uv = ctx.saved_tensors + filter_mode, filter_mode_enum, boundary_mode_enum = ctx.saved_misc + if filter_mode == 'linear': + g_tex, g_uv = _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return None, g_tex, g_uv, None, None + else: # nearest + g_tex = _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return None, g_tex, None, None, None + +# Op wrapper. +def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto', boundary_mode='wrap', max_mip_level=None): + """Perform texture sampling. + + All input tensors must be contiguous and reside in GPU memory. The output tensor + will be contiguous and reside in GPU memory. + + Args: + tex: Texture tensor with dtype `torch.float32`. For 2D textures, must have shape + [minibatch_size, tex_height, tex_width, tex_channels]. For cube map textures, + must have shape [minibatch_size, 6, tex_height, tex_width, tex_channels] where + tex_width and tex_height are equal. Note that `boundary_mode` must also be set + to 'cube' to enable cube map mode. Broadcasting is supported along the minibatch axis. + uv: Tensor containing per-pixel texture coordinates. When sampling a 2D texture, + must have shape [minibatch_size, height, width, 2]. When sampling a cube map + texture, must have shape [minibatch_size, height, width, 3]. + uv_da: (Optional) Tensor containing image-space derivatives of texture coordinates. + Must have same shape as `uv` except for the last dimension that is to be twice + as long. + mip_level_bias: (Optional) Per-pixel bias for mip level selection. 
If `uv_da` is omitted, + determines mip level directly. Must have shape [minibatch_size, height, width]. + mip: (Optional) Preconstructed mipmap stack from a `texture_construct_mip()` call, or a list + of tensors specifying a custom mipmap stack. When specifying a custom mipmap stack, + the tensors in the list must follow the same format as `tex` except for width and + height that must follow the usual rules for mipmap sizes. The base level texture + is still supplied in `tex` and must not be included in the list. Gradients of a + custom mipmap stack are not automatically propagated to base texture but the mipmap + tensors will receive gradients of their own. If a mipmap stack is not specified + but the chosen filter mode requires it, the mipmap stack is constructed internally + and discarded afterwards. + filter_mode: Texture filtering mode to be used. Valid values are 'auto', 'nearest', + 'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto' + selects 'linear' if neither `uv_da` or `mip_level_bias` is specified, and + 'linear-mipmap-linear' when at least one of them is specified, these being + the highest-quality modes possible depending on the availability of the + image-space derivatives of the texture coordinates or direct mip level information. + boundary_mode: Valid values are 'wrap', 'clamp', 'zero', and 'cube'. If `tex` defines a + cube map, this must be set to 'cube'. The default mode 'wrap' takes fractional + part of texture coordinates. Mode 'clamp' clamps texture coordinates to the + centers of the boundary texels. Mode 'zero' virtually extends the texture with + all-zero values in all directions. + max_mip_level: If specified, limits the number of mipmaps constructed and used in mipmap-based + filter modes. + + Returns: + A tensor containing the results of the texture sampling with shape + [minibatch_size, height, width, tex_channels]. Cube map fetches with invalid uv coordinates + (e.g., zero vectors) output all zeros and do not propagate gradients. + """ + + # Default filter mode. + if filter_mode == 'auto': + filter_mode = 'linear-mipmap-linear' if (uv_da is not None or mip_level_bias is not None) else 'linear' + + # Sanitize inputs. + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + + # Check inputs. + assert isinstance(tex, torch.Tensor) and isinstance(uv, torch.Tensor) + if 'mipmap' in filter_mode: + assert isinstance(uv_da, torch.Tensor) or isinstance(mip_level_bias, torch.Tensor) + + # If mipping disabled via max level=0, we may as well use simpler filtering internally. + if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: + filter_mode = 'linear' + + # Convert filter mode to internal enumeration. + filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_enum = filter_mode_dict[filter_mode] + + # Convert boundary mode to internal enumeration. + boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_enum = boundary_mode_dict[boundary_mode] + + # Construct a mipmap if necessary. 
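+    # `mip` may be either a TextureMipWrapper returned by texture_construct_mip()
+    # (passed on via mip_wrapper) or a user-provided list of tensors (passed on via
+    # mip_stack); when omitted, a temporary stack is built from `tex` below and
+    # discarded after the call.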
+ if 'mipmap' in filter_mode: + mip_wrapper, mip_stack = None, [] + if mip is not None: + assert isinstance(mip, (_get_plugin().TextureMipWrapper, list)) + if isinstance(mip, list): + assert all(isinstance(x, torch.Tensor) for x in mip) + mip_stack = mip + else: + mip_wrapper = mip + else: + mip_wrapper = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube') + + # Choose stub. + if filter_mode == 'linear-mipmap-linear' or filter_mode == 'linear-mipmap-nearest': + return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack) + else: + return _texture_func.apply(filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum) + +# Mipmap precalculation for cases where the texture stays constant. +def texture_construct_mip(tex, max_mip_level=None, cube_mode=False): + """Construct a mipmap stack for a texture. + + This function can be used for constructing a mipmap stack for a texture that is known to remain + constant. This avoids reconstructing it every time `texture()` is called. + + Args: + tex: Texture tensor with the same constraints as in `texture()`. + max_mip_level: If specified, limits the number of mipmaps constructed. + cube_mode: Must be set to True if `tex` specifies a cube map texture. + + Returns: + An opaque object containing the mipmap stack. This can be supplied in a call to `texture()` + in the `mip` argument. + """ + + assert isinstance(tex, torch.Tensor) + assert cube_mode is True or cube_mode is False + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + return _get_plugin().texture_construct_mip(tex, max_mip_level, cube_mode) + +#---------------------------------------------------------------------------- +# Antialias. +#---------------------------------------------------------------------------- + +class _antialias_func(torch.autograd.Function): + @staticmethod + def forward(ctx, color, rast, pos, tri, topology_hash, pos_gradient_boost): + out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, topology_hash) + ctx.save_for_backward(color, rast, pos, tri) + ctx.saved_misc = pos_gradient_boost, work_buffer + return out + + @staticmethod + def backward(ctx, dy): + color, rast, pos, tri = ctx.saved_tensors + pos_gradient_boost, work_buffer = ctx.saved_misc + g_color, g_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + if pos_gradient_boost != 1.0: + g_pos = g_pos * pos_gradient_boost + return g_color, None, g_pos, None, None, None + +# Op wrapper. +def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0): + """Perform antialiasing. + + All input tensors must be contiguous and reside in GPU memory. The output tensor + will be contiguous and reside in GPU memory. + + Note that silhouette edge determination is based on vertex indices in the triangle + tensor. For it to work properly, a vertex belonging to multiple triangles must be + referred to using the same vertex index in each triangle. Otherwise, nvdiffrast will always + classify the adjacent edges as silhouette edges, which leads to bad performance and + potentially incorrect gradients. If you are unsure whether your data is good, check + which pixels are modified by the antialias operation and compare to the example in the + documentation. + + Args: + color: Input image to antialias with shape [minibatch_size, height, width, num_channels]. + rast: Main output tensor from `rasterize()`. 
+ pos: Vertex position tensor used in the rasterization operation. + tri: Triangle tensor used in the rasterization operation. + topology_hash: (Optional) Preconstructed topology hash for the triangle tensor. If not + specified, the topology hash is constructed internally and discarded afterwards. + pos_gradient_boost: (Optional) Multiplier for gradients propagated to `pos`. + + Returns: + A tensor containing the antialiased image with the same shape as `color` input tensor. + """ + + # Check inputs. + assert all(isinstance(x, torch.Tensor) for x in (color, rast, pos, tri)) + + # Construct topology hash unless provided by user. + if topology_hash is not None: + assert isinstance(topology_hash, _get_plugin().TopologyHashWrapper) + else: + topology_hash = _get_plugin().antialias_construct_topology_hash(tri) + + # Instantiate the function. + return _antialias_func.apply(color, rast, pos, tri, topology_hash, pos_gradient_boost) + +# Topology hash precalculation for cases where the triangle array stays constant. +def antialias_construct_topology_hash(tri): + """Construct a topology hash for a triangle tensor. + + This function can be used for constructing a topology hash for a triangle tensor that is + known to remain constant. This avoids reconstructing it every time `antialias()` is called. + + Args: + tri: Triangle tensor with shape [num_triangles, 3]. Must be contiguous and reside in + GPU memory. + + Returns: + An opaque object containing the topology hash. This can be supplied in a call to + `antialias()` in the `topology_hash` argument. + """ + assert isinstance(tri, torch.Tensor) + return _get_plugin().antialias_construct_topology_hash(tri) + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp new file mode 100644 index 0000000000000000000000000000000000000000..730a200e4b8ab29ffe73c7cca493d4b2f0c80f92 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp @@ -0,0 +1,243 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/antialias.h" + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void AntialiasFwdMeshKernel (const AntialiasKernelParams p); +void AntialiasFwdDiscontinuityKernel(const AntialiasKernelParams p); +void AntialiasFwdAnalysisKernel (const AntialiasKernelParams p); +void AntialiasGradKernel (const AntialiasKernelParams p); + +//------------------------------------------------------------------------ +// Topology hash construction. + +TopologyHashWrapper antialias_construct_topology_hash(torch::Tensor tri) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tri)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + + // Check inputs. 
+ NVDR_CHECK_DEVICE(tri); + NVDR_CHECK_CONTIGUOUS(tri); + NVDR_CHECK_I32(tri); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Fill in kernel parameters. + p.numTriangles = tri.size(0); + p.numVertices = 0x7fffffff; // Let's not require vertex positions just to enable an error check. + p.tri = tri.data_ptr(); + + // Kernel parameters. + p.allocTriangles = 64; + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // Construct the hash tensor and get pointer. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + torch::Tensor ev_hash = torch::zeros({(uint64_t)p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * 4}, opts); + p.evHash = (uint4*)(ev_hash.data_ptr()); + + // Check alignment. + NVDR_CHECK(!((uintptr_t)p.evHash & 15), "ev_hash internal tensor not aligned to int4"); + + // Populate the hash. + void* args[] = {&p}; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdMeshKernel, (p.numTriangles - 1) / AA_MESH_KERNEL_THREADS_PER_BLOCK + 1, AA_MESH_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + + // Return. + TopologyHashWrapper hash_wrap; + hash_wrap.ev_hash = ev_hash; + return hash_wrap; +} + +//------------------------------------------------------------------------ +// Forward op. + +std::tuple antialias_fwd(torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, TopologyHashWrapper topology_hash_wrap) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(color)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + torch::Tensor& topology_hash = topology_hash_wrap.ev_hash; // Unwrap. + + // Check inputs. + NVDR_CHECK_DEVICE(color, rast, pos, tri, topology_hash); + NVDR_CHECK_CONTIGUOUS(color, rast, pos, tri, topology_hash); + NVDR_CHECK_F32(color, rast, pos); + NVDR_CHECK_I32(tri, topology_hash); + + // Sanity checks. + NVDR_CHECK(color.sizes().size() == 4 && color.size(0) > 0 && color.size(1) > 0 && color.size(2) > 0 && color.size(3) > 0, "color must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(color.size(1) == rast.size(1) && color.size(2) == rast.size(2), "color and rast inputs must have same spatial dimensions"); + if (p.instance_mode) + { + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0) && pos.size(0) == color.size(0), "minibatch size mismatch between inputs color, rast, pos"); + } + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0), "minibatch size mismatch between inputs color, rast"); + } + + // Extract input dimensions. + p.numVertices = pos.size(p.instance_mode ? 1 : 0); + p.numTriangles = tri.size(0); + p.n = color.size(0); + p.height = color.size(1); + p.width = color.size(2); + p.channels = color.size(3); + + // Get input pointers. 
+ p.color = color.data_ptr(); + p.rasterOut = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.pos = pos.data_ptr(); + p.evHash = (uint4*)(topology_hash.data_ptr()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Determine hash allocation size. + p.allocTriangles = 64; + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // Allocate output tensors. + torch::Tensor out = color.detach().clone(); // Use color as base. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor work_buffer = torch::empty({p.n * p.width * p.height * 8 + 4}, opts); // 8 int for a maximum of two work items per pixel. + p.output = out.data_ptr(); + p.workBuffer = (int4*)(work_buffer.data_ptr()); + + // Clear the work counters. + NVDR_CHECK_CUDA_ERROR(cudaMemsetAsync(p.workBuffer, 0, sizeof(int4), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rasterOut & 7), "raster_out input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.workBuffer & 15), "work_buffer internal tensor not aligned to int4"); + NVDR_CHECK(!((uintptr_t)p.evHash & 15), "topology_hash internal tensor not aligned to int4"); + + // Choose launch parameters for the discontinuity finder kernel and launch. + void* args[] = {&p}; + dim3 blockSize(AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH, AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT, 1); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.n); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdDiscontinuityKernel, gridSize, blockSize, args, 0, stream)); + + // Determine optimum block size for the persistent analysis kernel and launch. + int device = 0; + int numCTA = 0; + int numSM = 0; + NVDR_CHECK_CUDA_ERROR(cudaGetDevice(&device)); + NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasFwdAnalysisKernel, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, 0)); + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdAnalysisKernel, numCTA * numSM, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + + // Return results. + return std::tuple(out, work_buffer); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple antialias_grad(torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, torch::Tensor dy, torch::Tensor work_buffer) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(color)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + + // Check inputs. + NVDR_CHECK_DEVICE(color, rast, pos, tri, dy, work_buffer); + NVDR_CHECK_CONTIGUOUS(color, rast, pos, tri, work_buffer); + NVDR_CHECK_F32(color, rast, pos, dy, work_buffer); + NVDR_CHECK_I32(tri); + + // Sanity checks. 
+ NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) > 0 && dy.size(1) > 0 && dy.size(2) > 0 && dy.size(3) > 0, "dy must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(color.sizes().size() == 4 && color.size(0) > 0 && color.size(1) > 0 && color.size(2) > 0 && color.size(3) > 0, "color must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "raster_out must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(color.size(1) == rast.size(1) && color.size(2) == rast.size(2), "color and raster_out inputs must have same spatial dimensions"); + NVDR_CHECK(color.size(1) == dy.size(1) && color.size(2) == dy.size(2) && color.size(3) == dy.size(3), "color and dy inputs must have same dimensions"); + if (p.instance_mode) + { + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0) && pos.size(0) == color.size(0), "minibatch size mismatch between inputs color, raster_out, pos"); + NVDR_CHECK(dy.size(0) == color.size(0) && rast.size(0) == color.size(0) && pos.size(0) ==color.size(0), "minibatch size mismatch between inputs dy, color, raster_out, pos"); + } + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0), "minibatch size mismatch between inputs color, raster_out"); + NVDR_CHECK(dy.size(0) == color.size(0) && rast.size(0) == color.size(0), "minibatch size mismatch between inputs dy, color, raster_out"); + } + + // Extract input dimensions. + p.numVertices = pos.size(p.instance_mode ? 1 : 0); + p.numTriangles = tri.size(0); + p.n = color.size(0); + p.height = color.size(1); + p.width = color.size(2); + p.channels = color.size(3); + + // Ensure dy is contiguous. + torch::Tensor dy_ = dy.contiguous(); + + // Get input pointers. + p.color = color.data_ptr(); + p.rasterOut = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.pos = pos.data_ptr(); + p.dy = dy_.data_ptr(); + p.workBuffer = (int4*)(work_buffer.data_ptr()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate output tensors. + torch::Tensor grad_color = dy_.detach().clone(); // Use dy as base. + torch::Tensor grad_pos = torch::zeros_like(pos); + p.gradColor = grad_color.data_ptr(); + p.gradPos = grad_pos.data_ptr(); + + // Clear gradient kernel work counter. + NVDR_CHECK_CUDA_ERROR(cudaMemsetAsync(&p.workBuffer[0].y, 0, sizeof(int), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.workBuffer & 15), "work_buffer internal tensor not aligned to int4"); + + // Determine optimum block size for the gradient kernel and launch. 
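+    // As in antialias_fwd(), the kernel is launched as a persistent grid sized to the
+    // GPU (max resident blocks per SM times the SM count) rather than one block per
+    // pixel, and it reads back the work items that the forward pass recorded in work_buffer.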
+    void* args[] = {&p};
+    int device = 0;
+    int numCTA = 0;
+    int numSM = 0;
+    NVDR_CHECK_CUDA_ERROR(cudaGetDevice(&device));
+    NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasGradKernel, AA_GRAD_KERNEL_THREADS_PER_BLOCK, 0));
+    NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device));
+    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasGradKernel, numCTA * numSM, AA_GRAD_KERNEL_THREADS_PER_BLOCK, args, 0, stream));
+
+    // Return results.
+    return std::tuple<torch::Tensor, torch::Tensor>(grad_color, grad_pos);
+}
+
+//------------------------------------------------------------------------
diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..898e17e37b5ac559362732b1eaa118a64240dadb
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp
@@ -0,0 +1,73 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "torch_common.inl"
+#include "torch_types.h"
+#include <tuple>
+
+//------------------------------------------------------------------------
+// Op prototypes. Return type macros for readability.
+
+#define OP_RETURN_T     torch::Tensor
+#define OP_RETURN_TT    std::tuple<torch::Tensor, torch::Tensor>
+#define OP_RETURN_TTT   std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+#define OP_RETURN_TTTT  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+#define OP_RETURN_TTV   std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >
+#define OP_RETURN_TTTTV std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >
+
+OP_RETURN_TT        rasterize_fwd_cuda                  (RasterizeCRStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple<int, int> resolution, torch::Tensor ranges, int peeling_idx);
+OP_RETURN_T         rasterize_grad                      (torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy);
+OP_RETURN_T         rasterize_grad_db                   (torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy, torch::Tensor ddb);
+OP_RETURN_TT        interpolate_fwd                     (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri);
+OP_RETURN_TT        interpolate_fwd_da                  (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor rast_db, bool diff_attrs_all, std::vector<int>& diff_attrs_vec);
+OP_RETURN_TT        interpolate_grad                    (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy);
+OP_RETURN_TTT       interpolate_grad_da                 (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy, torch::Tensor rast_db, torch::Tensor dda, bool diff_attrs_all, std::vector<int>& diff_attrs_vec);
+TextureMipWrapper   texture_construct_mip               (torch::Tensor tex, int max_mip_level, bool cube_mode);
+OP_RETURN_T         texture_fwd                         (torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode);
+OP_RETURN_T         texture_fwd_mip                     (torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
+OP_RETURN_T         texture_grad_nearest                (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode);
+OP_RETURN_TT        texture_grad_linear                 (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode);
+OP_RETURN_TTV       texture_grad_linear_mipmap_nearest  (torch::Tensor tex,
torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
+OP_RETURN_TTTTV texture_grad_linear_mipmap_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
+TopologyHashWrapper antialias_construct_topology_hash (torch::Tensor tri);
+OP_RETURN_TT antialias_fwd (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, TopologyHashWrapper topology_hash);
+OP_RETURN_TT antialias_grad (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, torch::Tensor dy, torch::Tensor work_buffer);
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    // State classes.
+    pybind11::class_<RasterizeCRStateWrapper>(m, "RasterizeCRStateWrapper").def(pybind11::init<int>());
+    pybind11::class_<TextureMipWrapper>(m, "TextureMipWrapper").def(pybind11::init<>());
+    pybind11::class_<TopologyHashWrapper>(m, "TopologyHashWrapper");
+
+    // Plumbing to torch/c10 logging system.
+    m.def("get_log_level", [](void) { return FLAGS_caffe2_log_level; }, "get log level");
+    m.def("set_log_level", [](int level){ FLAGS_caffe2_log_level = level; }, "set log level");
+
+    // Ops.
+    m.def("rasterize_fwd_cuda", &rasterize_fwd_cuda, "rasterize forward op (cuda)");
+    m.def("rasterize_grad", &rasterize_grad, "rasterize gradient op ignoring db gradients");
+    m.def("rasterize_grad_db", &rasterize_grad_db, "rasterize gradient op with db gradients");
+    m.def("interpolate_fwd", &interpolate_fwd, "interpolate forward op without attribute derivatives");
+    m.def("interpolate_fwd_da", &interpolate_fwd_da, "interpolate forward op with attribute derivatives");
+    m.def("interpolate_grad", &interpolate_grad, "interpolate gradient op without attribute derivatives");
+    m.def("interpolate_grad_da", &interpolate_grad_da, "interpolate gradient op with attribute derivatives");
+    m.def("texture_construct_mip", &texture_construct_mip, "texture mipmap construction");
+    m.def("texture_fwd", &texture_fwd, "texture forward op without mipmapping");
+    m.def("texture_fwd_mip", &texture_fwd_mip, "texture forward op with mipmapping");
+    m.def("texture_grad_nearest", &texture_grad_nearest, "texture gradient op in nearest mode");
+    m.def("texture_grad_linear", &texture_grad_linear, "texture gradient op in linear mode");
+    m.def("texture_grad_linear_mipmap_nearest", &texture_grad_linear_mipmap_nearest, "texture gradient op in linear-mipmap-nearest mode");
+    m.def("texture_grad_linear_mipmap_linear", &texture_grad_linear_mipmap_linear, "texture gradient op in linear-mipmap-linear mode");
+    m.def("antialias_construct_topology_hash", &antialias_construct_topology_hash, "antialias topology hash construction");
+    m.def("antialias_fwd", &antialias_fwd, "antialias forward op");
+    m.def("antialias_grad", &antialias_grad, "antialias gradient op");
+}
+
+//------------------------------------------------------------------------
diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5363e80297b9f9d5d212c890c8a455e60122366f
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp
@@ -0,0 +1,30 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "torch_common.inl"
+#include "torch_types.h"
+#include <tuple>
+
+//------------------------------------------------------------------------
+// Op prototypes.
+
+std::tuple<torch::Tensor, torch::Tensor> rasterize_fwd_gl(RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple<int, int> resolution, torch::Tensor ranges, int peeling_idx);
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    // State classes.
+    pybind11::class_<RasterizeGLStateWrapper>(m, "RasterizeGLStateWrapper").def(pybind11::init<bool, bool, int>())
+        .def("set_context",     &RasterizeGLStateWrapper::setContext)
+        .def("release_context", &RasterizeGLStateWrapper::releaseContext);
+
+    // Ops.
+    m.def("rasterize_fwd_gl", &rasterize_fwd_gl, "rasterize forward op (opengl)");
+}
+
+//------------------------------------------------------------------------
diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl b/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl
new file mode 100644
index 0000000000000000000000000000000000000000..74dea41528822294878d9ee5d36d1230d1df7ae6
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl
@@ -0,0 +1,29 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#pragma once
+#include "../common/framework.h"
+
+//------------------------------------------------------------------------
+// Input check helpers.
+//------------------------------------------------------------------------
+
+#ifdef _MSC_VER
+#define __func__ __FUNCTION__
+#endif
+
+#define NVDR_CHECK_DEVICE(...) do { TORCH_CHECK(at::cuda::check_device({__VA_ARGS__}), __func__, "(): Inputs " #__VA_ARGS__ " must reside on the same GPU device") } while(0)
+#define NVDR_CHECK_CPU(...) do { nvdr_check_cpu({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must reside on CPU"); } while(0)
+#define NVDR_CHECK_CONTIGUOUS(...) do { nvdr_check_contiguous({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be contiguous tensors"); } while(0)
+#define NVDR_CHECK_F32(...) do { nvdr_check_f32({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be float32 tensors"); } while(0)
+#define NVDR_CHECK_I32(...)
do { nvdr_check_i32({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be int32 tensors"); } while(0) +inline void nvdr_check_cpu(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.device().type() == c10::DeviceType::CPU, func, err_msg); } +inline void nvdr_check_contiguous(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.is_contiguous(), func, err_msg); } +inline void nvdr_check_f32(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.dtype() == torch::kFloat32, func, err_msg); } +inline void nvdr_check_i32(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.dtype() == torch::kInt32, func, err_msg); } +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b2c99fccfe0b11b71018e2c0ddcf637a337522b8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp @@ -0,0 +1,250 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "../common/common.h" +#include "../common/interpolate.h" + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void InterpolateFwdKernel (const InterpolateKernelParams p); +void InterpolateFwdKernelDa (const InterpolateKernelParams p); +void InterpolateGradKernel (const InterpolateKernelParams p); +void InterpolateGradKernelDa(const InterpolateKernelParams p); + +//------------------------------------------------------------------------ +// Helper + +static void set_diff_attrs(InterpolateKernelParams& p, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + if (diff_attrs_all) + { + p.numDiffAttr = p.numAttr; + p.diff_attrs_all = 1; + } + else + { + NVDR_CHECK(diff_attrs_vec.size() <= IP_MAX_DIFF_ATTRS, "too many entries in diff_attrs list (increase IP_MAX_DIFF_ATTRS)"); + p.numDiffAttr = diff_attrs_vec.size(); + memcpy(p.diffAttrs, &diff_attrs_vec[0], diff_attrs_vec.size()*sizeof(int)); + } +} + +//------------------------------------------------------------------------ +// Forward op. + +std::tuple interpolate_fwd_da(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor rast_db, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(attr)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + InterpolateKernelParams p = {}; // Initialize all fields to zero. + bool enable_da = (rast_db.defined()) && (diff_attrs_all || !diff_attrs_vec.empty()); + p.instance_mode = (attr.sizes().size() > 2) ? 1 : 0; + + // Check inputs. 
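+    // rast_db and the attribute-derivative outputs participate only when enable_da is set, i.e. when
+    // rast_db was supplied and at least one attribute is flagged for differentiation (diff_attrs_all
+    // or a non-empty diff_attrs_vec). instance_mode distinguishes per-minibatch attributes of shape
+    // [minibatch, num_vertices, num_attrs] from a single shared [num_vertices, num_attrs] tensor.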
+ if (enable_da) + { + NVDR_CHECK_DEVICE(attr, rast, tri, rast_db); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri, rast_db); + NVDR_CHECK_F32(attr, rast, rast_db); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(attr, rast, tri); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri); + NVDR_CHECK_F32(attr, rast); + NVDR_CHECK_I32(tri); + } + + // Sanity checks. + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK( tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK((attr.sizes().size() == 2 || attr.sizes().size() == 3) && attr.size(0) > 0 && attr.size(1) > 0 && (attr.sizes().size() == 2 || attr.size(2) > 0), "attr must have shape [>0, >0, >0] or [>0, >0]"); + if (p.instance_mode) + NVDR_CHECK(attr.size(0) == rast.size(0) || attr.size(0) == 1, "minibatch size mismatch between inputs rast, attr"); + if (enable_da) + { + NVDR_CHECK(rast_db.sizes().size() == 4 && rast_db.size(0) > 0 && rast_db.size(1) > 0 && rast_db.size(2) > 0 && rast_db.size(3) == 4, "rast_db must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(rast_db.size(1) == rast.size(1) && rast_db.size(2) == rast.size(2), "spatial size mismatch between inputs rast and rast_db"); + NVDR_CHECK(rast_db.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, rast_db"); + } + + // Extract input dimensions. + p.numVertices = attr.size(p.instance_mode ? 1 : 0); + p.numAttr = attr.size(p.instance_mode ? 2 : 1); + p.numTriangles = tri.size(0); + p.height = rast.size(1); + p.width = rast.size(2); + p.depth = rast.size(0); + + // Set attribute pixel differential info if enabled, otherwise leave as zero. + if (enable_da) + set_diff_attrs(p, diff_attrs_all, diff_attrs_vec); + else + p.numDiffAttr = 0; + + // Get input pointers. + p.attr = attr.data_ptr(); + p.rast = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.rastDB = enable_da ? rast_db.data_ptr() : NULL; + p.attrBC = (p.instance_mode && attr.size(0) == 1) ? 1 : 0; + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({p.depth, p.height, p.width, p.numAttr}, opts); + torch::Tensor out_da = torch::empty({p.depth, p.height, p.width, p.numDiffAttr * 2}, opts); + + p.out = out.data_ptr(); + p.outDA = enable_da ? out_da.data_ptr() : NULL; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.rast & 15), "rast input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rastDB & 15), "rast_db input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.outDA & 7), "out_da output tensor not aligned to float2"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_FWD_MAX_KERNEL_BLOCK_WIDTH, IP_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_da ? (void*)InterpolateFwdKernelDa : (void*)InterpolateFwdKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return results. + return std::tuple(out, out_da); +} + +// Version without derivatives. 
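+// The derivative-free entry point below simply forwards to interpolate_fwd_da with an undefined
+// rast_db tensor and an empty attribute list, which keeps enable_da false inside it; both
+// Python-facing ops therefore share a single implementation.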
+std::tuple interpolate_fwd(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri) +{ + std::vector empty_vec; + torch::Tensor empty_tensor; + return interpolate_fwd_da(attr, rast, tri, empty_tensor, false, empty_vec); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple interpolate_grad_da(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy, torch::Tensor rast_db, torch::Tensor dda, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(attr)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + InterpolateKernelParams p = {}; // Initialize all fields to zero. + bool enable_da = (rast_db.defined()) && (diff_attrs_all || !diff_attrs_vec.empty()); + p.instance_mode = (attr.sizes().size() > 2) ? 1 : 0; + + // Check inputs. + if (enable_da) + { + NVDR_CHECK_DEVICE(attr, rast, tri, dy, rast_db, dda); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri, rast_db); + NVDR_CHECK_F32(attr, rast, dy, rast_db, dda); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(attr, rast, tri, dy); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri); + NVDR_CHECK_F32(attr, rast, dy); + NVDR_CHECK_I32(tri); + } + + // Depth of attributes. + int attr_depth = p.instance_mode ? (attr.sizes().size() > 1 ? attr.size(0) : 0) : 1; + + // Sanity checks. + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK((attr.sizes().size() == 2 || attr.sizes().size() == 3) && attr.size(0) > 0 && attr.size(1) > 0 && (attr.sizes().size() == 2 || attr.size(2) > 0), "attr must have shape [>0, >0, >0] or [>0, >0]"); + NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) > 0 && dy.size(1) == rast.size(1) && dy.size(2) == rast.size(2) && dy.size(3) > 0, "dy must have shape [>0, height, width, >0]"); + NVDR_CHECK(dy.size(3) == attr.size(attr.sizes().size() - 1), "argument count mismatch between inputs dy, attr"); + NVDR_CHECK((attr_depth == rast.size(0) || attr_depth == 1) && dy.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, dy, attr"); + if (enable_da) + { + NVDR_CHECK(dda.sizes().size() == 4 && dda.size(0) > 0 && dda.size(1) == rast.size(1) && dda.size(2) == rast.size(2), "dda must have shape [>0, height, width, ?]"); + NVDR_CHECK(dda.size(0) == rast.size(0), "minibatch size mismatch between rast, dda"); + NVDR_CHECK(rast_db.sizes().size() == 4 && rast_db.size(0) > 0 && rast_db.size(1) > 0 && rast_db.size(2) > 0 && rast_db.size(3) == 4, "rast_db must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(rast_db.size(1) == rast.size(1) && rast_db.size(2) == rast.size(2), "spatial size mismatch between inputs rast and rast_db"); + NVDR_CHECK(rast_db.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, rast_db"); + } + + // Extract input dimensions. + p.numVertices = attr.size(p.instance_mode ? 1 : 0); + p.numAttr = attr.size(p.instance_mode ? 2 : 1); + p.numTriangles = tri.size(0); + p.height = rast.size(1); + p.width = rast.size(2); + p.depth = rast.size(0); + + // Ensure gradients are contiguous. + torch::Tensor dy_ = dy.contiguous(); + torch::Tensor dda_; + if (enable_da) + dda_ = dda.contiguous(); + + // Set attribute pixel differential info if enabled, otherwise leave as zero. 
+ if (enable_da) + set_diff_attrs(p, diff_attrs_all, diff_attrs_vec); + else + p.numDiffAttr = 0; + + // Get input pointers. + p.attr = attr.data_ptr(); + p.rast = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.dy = dy_.data_ptr(); + p.rastDB = enable_da ? rast_db.data_ptr() : NULL; + p.dda = enable_da ? dda_.data_ptr() : NULL; + p.attrBC = (p.instance_mode && attr_depth < p.depth) ? 1 : 0; + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor gradAttr = torch::zeros_like(attr); + torch::Tensor gradRaster = torch::empty_like(rast); + torch::Tensor gradRasterDB; + if (enable_da) + gradRasterDB = torch::empty_like(rast_db); + + p.gradAttr = gradAttr.data_ptr(); + p.gradRaster = gradRaster.data_ptr(); + p.gradRasterDB = enable_da ? gradRasterDB.data_ptr() : NULL; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.rast & 15), "rast input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rastDB & 15), "rast_db input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.dda & 7), "dda input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradRaster & 15), "grad_rast output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradRasterDB & 15), "grad_rast_db output tensor not aligned to float4"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_da ? (void*)InterpolateGradKernelDa : (void*)InterpolateGradKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return results. + return std::tuple(gradAttr, gradRaster, gradRasterDB); +} + +// Version without derivatives. +std::tuple interpolate_grad(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy) +{ + std::vector empty_vec; + torch::Tensor empty_tensor; + std::tuple result = interpolate_grad_da(attr, rast, tri, dy, empty_tensor, empty_tensor, false, empty_vec); + return std::tuple(std::get<0>(result), std::get<1>(result)); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..589e227ac0a8dc9735e32a3b77e38a5d1e11c882 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp @@ -0,0 +1,265 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/rasterize.h" +#include "../common/cudaraster/CudaRaster.hpp" +#include "../common/cudaraster/impl/Constants.hpp" +#include + +//------------------------------------------------------------------------ +// Kernel prototypes. 
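+// RasterizeCudaFwdShaderKernel turns the per-pixel triangle IDs produced by CudaRaster into the
+// final (u, v, z/w, triangle_id) image and the barycentric-derivative image; the two gradient
+// kernels backpropagate dy (and, in the Db variant, ddb) into vertex positions.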
+ +void RasterizeCudaFwdShaderKernel(const RasterizeCudaFwdShaderParams p); +void RasterizeGradKernel(const RasterizeGradParams p); +void RasterizeGradKernelDb(const RasterizeGradParams p); + +//------------------------------------------------------------------------ +// Python CudaRaster state wrapper methods. + +RasterizeCRStateWrapper::RasterizeCRStateWrapper(int cudaDeviceIdx_) +{ + const at::cuda::OptionalCUDAGuard device_guard(cudaDeviceIdx_); + cudaDeviceIdx = cudaDeviceIdx_; + cr = new CR::CudaRaster(); +} + +RasterizeCRStateWrapper::~RasterizeCRStateWrapper(void) +{ + const at::cuda::OptionalCUDAGuard device_guard(cudaDeviceIdx); + delete cr; +} + +//------------------------------------------------------------------------ +// Forward op (Cuda). + +std::tuple rasterize_fwd_cuda(RasterizeCRStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + CR::CudaRaster* cr = stateWrapper.cr; + + // Check inputs. + NVDR_CHECK_DEVICE(pos, tri); + NVDR_CHECK_CPU(ranges); + NVDR_CHECK_CONTIGUOUS(pos, tri, ranges); + NVDR_CHECK_F32(pos); + NVDR_CHECK_I32(tri, ranges); + + // Check that CudaRaster context was created for the correct GPU. + NVDR_CHECK(pos.get_device() == stateWrapper.cudaDeviceIdx, "CudaRaster context must must reside on the same device as input tensors"); + + // Determine instance mode and check input dimensions. + bool instance_mode = pos.sizes().size() > 2; + if (instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "instance mode - pos must have shape [>0, >0, 4]"); + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "range mode - pos must have shape [>0, 4]"); + NVDR_CHECK(ranges.sizes().size() == 2 && ranges.size(0) > 0 && ranges.size(1) == 2, "range mode - ranges must have shape [>0, 2]"); + } + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Get output shape. + int height_out = std::get<0>(resolution); + int width_out = std::get<1>(resolution); + int depth = instance_mode ? pos.size(0) : ranges.size(0); // Depth of tensor, not related to depth buffering. + NVDR_CHECK(height_out > 0 && width_out > 0, "resolution must be [>0, >0]"); + + // Round internal resolution up to tile size. + int height = (height_out + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int width = (width_out + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + // Get position and triangle buffer sizes in vertices / triangles. + int posCount = instance_mode ? pos.size(1) : pos.size(0); + int triCount = tri.size(0); + + // Set up CudaRaster buffers. + const float* posPtr = pos.data_ptr(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.data_ptr(); // This is in CPU memory. + const int32_t* triPtr = tri.data_ptr(); + cr->setVertexBuffer((void*)posPtr, posCount); + cr->setIndexBuffer((void*)triPtr, triCount); + cr->setBufferSize(width_out, height_out, depth); + + // Enable depth peeling? + bool enablePeel = (peeling_idx > 0); + cr->setRenderModeFlags(enablePeel ? CR::CudaRaster::RenderModeFlag_EnableDepthPeeling : 0); // No backface culling. + if (enablePeel) + cr->swapDepthAndPeel(); // Use previous depth buffer as peeling depth input. + + // Determine viewport tiling. 
+ int tileCountX = (width + CR_MAXVIEWPORT_SIZE - 1) / CR_MAXVIEWPORT_SIZE; + int tileCountY = (height + CR_MAXVIEWPORT_SIZE - 1) / CR_MAXVIEWPORT_SIZE; + int tileSizeX = ((width + tileCountX - 1) / tileCountX + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int tileSizeY = ((height + tileCountY - 1) / tileCountY + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + TORCH_CHECK(tileCountX > 0 && tileCountY > 0 && tileSizeX > 0 && tileSizeY > 0, "internal error in tile size calculation: count or size is zero"); + TORCH_CHECK(tileSizeX <= CR_MAXVIEWPORT_SIZE && tileSizeY <= CR_MAXVIEWPORT_SIZE, "internal error in tile size calculation: tile larger than allowed"); + TORCH_CHECK((tileSizeX & (CR_TILE_SIZE - 1)) == 0 && (tileSizeY & (CR_TILE_SIZE - 1)) == 0, "internal error in tile size calculation: tile not divisible by ", CR_TILE_SIZE); + TORCH_CHECK(tileCountX * tileSizeX >= width && tileCountY * tileSizeY >= height, "internal error in tile size calculation: tiles do not cover viewport"); + + // Rasterize in tiles. + for (int tileY = 0; tileY < tileCountY; tileY++) + for (int tileX = 0; tileX < tileCountX; tileX++) + { + // Set CudaRaster viewport according to tile. + int offsetX = tileX * tileSizeX; + int offsetY = tileY * tileSizeY; + int sizeX = (width_out - offsetX) < tileSizeX ? (width_out - offsetX) : tileSizeX; + int sizeY = (height_out - offsetY) < tileSizeY ? (height_out - offsetY) : tileSizeY; + cr->setViewport(sizeX, sizeY, offsetX, offsetY); + + // Run all triangles in one batch. In case of error, the workload could be split into smaller batches - maybe do that in the future. + // Only enable peeling-specific optimizations to skip first stages when image fits in one tile. Those are not valid otherwise. + cr->deferredClear(0u); + bool success = cr->drawTriangles(rangesPtr, enablePeel && (tileCountX == 1 && tileCountY == 1), stream); + NVDR_CHECK(success, "subtriangle count overflow"); + } + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({depth, height_out, width_out, 4}, opts); + torch::Tensor out_db = torch::empty({depth, height_out, width_out, 4}, opts); + + // Populate pixel shader kernel parameters. + RasterizeCudaFwdShaderParams p; + p.pos = posPtr; + p.tri = triPtr; + p.in_idx = (const int*)cr->getColorBuffer(); + p.out = out.data_ptr(); + p.out_db = out_db.data_ptr(); + p.numTriangles = triCount; + p.numVertices = posCount; + p.width_in = width; + p.height_in = height; + p.width_out = width_out; + p.height_out = height_out; + p.depth = depth; + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + p.xs = 2.f / (float)width_out; + p.xo = 1.f / (float)width_out - 1.f; + p.ys = 2.f / (float)height_out; + p.yo = 1.f / (float)height_out - 1.f; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out_db & 15), "out_db output tensor not aligned to float4"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_WIDTH, RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_HEIGHT, p.width_out, p.height_out); + dim3 gridSize = getLaunchGridSize(blockSize, p.width_out, p.height_out, p.depth); + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)RasterizeCudaFwdShaderKernel, gridSize, blockSize, args, 0, stream)); + + // Return. + return std::tuple(out, out_db); +} + +//------------------------------------------------------------------------ +// Gradient op. + +torch::Tensor rasterize_grad_db(torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy, torch::Tensor ddb) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + RasterizeGradParams p; + bool enable_db = ddb.defined(); + + // Check inputs. + if (enable_db) + { + NVDR_CHECK_DEVICE(pos, tri, out, dy, ddb); + NVDR_CHECK_CONTIGUOUS(pos, tri, out); + NVDR_CHECK_F32(pos, out, dy, ddb); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(pos, tri, out, dy); + NVDR_CHECK_CONTIGUOUS(pos, tri, out); + NVDR_CHECK_F32(pos, out, dy); + NVDR_CHECK_I32(tri); + } + + // Determine instance mode. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + + // Shape is taken from the rasterizer output tensor. + NVDR_CHECK(out.sizes().size() == 4, "tensor out must be rank-4"); + p.depth = out.size(0); + p.height = out.size(1); + p.width = out.size(2); + NVDR_CHECK(p.depth > 0 && p.height > 0 && p.width > 0, "resolution must be [>0, >0, >0]"); + + // Check other shapes. + if (p.instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) == p.depth && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [depth, >0, 4]"); + else + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(out.sizes().size() == 4 && out.size(0) == p.depth && out.size(1) == p.height && out.size(2) == p.width && out.size(3) == 4, "out must have shape [depth, height, width, 4]"); + NVDR_CHECK( dy.sizes().size() == 4 && dy.size(0) == p.depth && dy.size(1) == p.height && dy.size(2) == p.width && dy.size(3) == 4, "dy must have shape [depth, height, width, 4]"); + if (enable_db) + NVDR_CHECK(ddb.sizes().size() == 4 && ddb.size(0) == p.depth && ddb.size(1) == p.height && ddb.size(2) == p.width && ddb.size(3) == 4, "ddb must have shape [depth, height, width, 4]"); + + // Ensure gradients are contiguous. + torch::Tensor dy_ = dy.contiguous(); + torch::Tensor ddb_; + if (enable_db) + ddb_ = ddb.contiguous(); + + // Populate parameters. + p.numTriangles = tri.size(0); + p.numVertices = p.instance_mode ? pos.size(1) : pos.size(0); + p.pos = pos.data_ptr(); + p.tri = tri.data_ptr(); + p.out = out.data_ptr(); + p.dy = dy_.data_ptr(); + p.ddb = enable_db ? ddb_.data_ptr() : NULL; + + // Set up pixel position to clip space x, y transform. + p.xs = 2.f / (float)p.width; + p.xo = 1.f / (float)p.width - 1.f; + p.ys = 2.f / (float)p.height; + p.yo = 1.f / (float)p.height - 1.f; + + // Allocate output tensor for position gradients. + torch::Tensor grad = torch::zeros_like(pos); + p.grad = grad.data_ptr(); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.ddb & 15), "ddb input tensor not aligned to float4"); + + // Choose launch parameters. 
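+    // One thread per pixel of the rasterizer output; the Db variant additionally propagates ddb
+    // (gradients w.r.t. the bary-derivative image) into the position gradient, while the plain
+    // variant ignores it.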
+ dim3 blockSize = getLaunchBlockSize(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH, RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_db ? (void*)RasterizeGradKernelDb : (void*)RasterizeGradKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return the gradients. + return grad; +} + +// Version without derivatives. +torch::Tensor rasterize_grad(torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy) +{ + torch::Tensor empty_tensor; + return rasterize_grad_db(pos, tri, out, dy, empty_tensor); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3776134adbd53f9138ef34fbbb2c00eb62883041 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/rasterize_gl.h" +#include + +//------------------------------------------------------------------------ +// Python GL state wrapper methods. + +RasterizeGLStateWrapper::RasterizeGLStateWrapper(bool enableDB, bool automatic_, int cudaDeviceIdx_) +{ + pState = new RasterizeGLState(); + automatic = automatic_; + cudaDeviceIdx = cudaDeviceIdx_; + memset(pState, 0, sizeof(RasterizeGLState)); + pState->enableDB = enableDB ? 1 : 0; + rasterizeInitGLContext(NVDR_CTX_PARAMS, *pState, cudaDeviceIdx_); + releaseGLContext(); +} + +RasterizeGLStateWrapper::~RasterizeGLStateWrapper(void) +{ + setGLContext(pState->glctx); + rasterizeReleaseBuffers(NVDR_CTX_PARAMS, *pState); + releaseGLContext(); + destroyGLContext(pState->glctx); + delete pState; +} + +void RasterizeGLStateWrapper::setContext(void) +{ + setGLContext(pState->glctx); +} + +void RasterizeGLStateWrapper::releaseContext(void) +{ + releaseGLContext(); +} + +//------------------------------------------------------------------------ +// Forward op (OpenGL). + +std::tuple rasterize_fwd_gl(RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + RasterizeGLState& s = *stateWrapper.pState; + + // Check inputs. + NVDR_CHECK_DEVICE(pos, tri); + NVDR_CHECK_CPU(ranges); + NVDR_CHECK_CONTIGUOUS(pos, tri, ranges); + NVDR_CHECK_F32(pos); + NVDR_CHECK_I32(tri, ranges); + + // Check that GL context was created for the correct GPU. + NVDR_CHECK(pos.get_device() == stateWrapper.cudaDeviceIdx, "GL context must must reside on the same device as input tensors"); + + // Determine number of outputs + int num_outputs = s.enableDB ? 2 : 1; + + // Determine instance mode and check input dimensions. 
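+    // Instance mode: pos has shape [minibatch, num_vertices, 4] and every image renders the full
+    // triangle list. Range mode: a single shared [num_vertices, 4] position tensor plus a CPU-side
+    // int32 ranges tensor of (start, count) pairs selecting the triangles drawn for each image.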
+ bool instance_mode = pos.sizes().size() > 2; + if (instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "instance mode - pos must have shape [>0, >0, 4]"); + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "range mode - pos must have shape [>0, 4]"); + NVDR_CHECK(ranges.sizes().size() == 2 && ranges.size(0) > 0 && ranges.size(1) == 2, "range mode - ranges must have shape [>0, 2]"); + } + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Get output shape. + int height = std::get<0>(resolution); + int width = std::get<1>(resolution); + int depth = instance_mode ? pos.size(0) : ranges.size(0); + NVDR_CHECK(height > 0 && width > 0, "resolution must be [>0, >0]"); + + // Get position and triangle buffer sizes in int32/float32. + int posCount = 4 * pos.size(0) * (instance_mode ? pos.size(1) : 1); + int triCount = 3 * tri.size(0); + + // Set the GL context unless manual context. + if (stateWrapper.automatic) + setGLContext(s.glctx); + + // Resize all buffers. + bool changes = false; + rasterizeResizeBuffers(NVDR_CTX_PARAMS, s, changes, posCount, triCount, width, height, depth); + if (changes) + { +#ifdef _WIN32 + // Workaround for occasional blank first frame on Windows. + releaseGLContext(); + setGLContext(s.glctx); +#endif + } + + // Copy input data to GL and render. + const float* posPtr = pos.data_ptr(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.data_ptr(); // This is in CPU memory. + const int32_t* triPtr = tri.data_ptr(); + int vtxPerInstance = instance_mode ? pos.size(1) : 0; + rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, peeling_idx); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({depth, height, width, 4}, opts); + torch::Tensor out_db = torch::empty({depth, height, width, s.enableDB ? 4 : 0}, opts); + float* outputPtr[2]; + outputPtr[0] = out.data_ptr(); + outputPtr[1] = s.enableDB ? out_db.data_ptr() : NULL; + + // Copy rasterized results into CUDA buffers. + rasterizeCopyResults(NVDR_CTX_PARAMS, s, stream, outputPtr, width, height, depth); + + // Done. Release GL context and return. + if (stateWrapper.automatic) + releaseGLContext(); + + return std::tuple(out, out_db); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2257f566623495c7044ea3f532ef00e327477dc7 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp @@ -0,0 +1,718 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
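+// This translation unit implements the texture ops: mipmap construction, forward sampling in
+// nearest / linear / mipmapped / cube-map variants (each specialized for 1-, 2- and 4-wide channel
+// vectors), and the corresponding gradient kernels declared below.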
+ +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/texture.h" +#include + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void MipBuildKernel1 (const TextureKernelParams p); +void MipBuildKernel2 (const TextureKernelParams p); +void MipBuildKernel4 (const TextureKernelParams p); +void TextureFwdKernelNearest1 (const TextureKernelParams p); +void TextureFwdKernelNearest2 (const TextureKernelParams p); +void TextureFwdKernelNearest4 (const TextureKernelParams p); +void TextureFwdKernelLinear1 (const TextureKernelParams p); +void TextureFwdKernelLinear2 (const TextureKernelParams p); +void TextureFwdKernelLinear4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear4 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest1 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest2 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO4 (const TextureKernelParams p); +void MipGradKernel1 (const TextureKernelParams p); +void MipGradKernel2 (const TextureKernelParams p); +void MipGradKernel4 (const TextureKernelParams p); +void TextureGradKernelNearest (const TextureKernelParams p); +void TextureGradKernelLinear (const TextureKernelParams p); +void TextureGradKernelLinearMipmapNearest (const TextureKernelParams p); +void TextureGradKernelLinearMipmapLinear (const TextureKernelParams p); +void TextureGradKernelCubeNearest (const TextureKernelParams p); +void TextureGradKernelCubeLinear (const TextureKernelParams p); +void 
TextureGradKernelCubeLinearMipmapNearest (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapLinear (const TextureKernelParams p); +void TextureGradKernelLinearMipmapNearestBO (const TextureKernelParams p); +void TextureGradKernelLinearMipmapLinearBO (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapNearestBO (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapLinearBO (const TextureKernelParams p); + +//------------------------------------------------------------------------ +// Modeselektor. + +static void set_modes(TextureKernelParams& p, int filter_mode, int boundary_mode, int max_mip_level) +{ + // Mip and filter modes. + p.filterMode = filter_mode; + NVDR_CHECK(p.filterMode >= 0 && p.filterMode < TEX_MODE_COUNT, "filter_mode unsupported"); + p.enableMip = (p.filterMode == TEX_MODE_LINEAR_MIPMAP_NEAREST || p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR); + + // Mip level clamp. + if (p.enableMip) + { + p.mipLevelLimit = max_mip_level; + NVDR_CHECK(p.mipLevelLimit >= -1, "invalid max_mip_level"); + } + + // Boundary mode. + p.boundaryMode = boundary_mode; + NVDR_CHECK(p.boundaryMode >= 0 && p.boundaryMode < TEX_BOUNDARY_MODE_COUNT, "boundary_mode unsupported"); +} + +//------------------------------------------------------------------------ +// Mipmap construction. + +TextureMipWrapper texture_construct_mip(torch::Tensor tex, int max_mip_level, bool cube_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + p.mipLevelLimit = max_mip_level; + p.boundaryMode = cube_mode ? TEX_BOUNDARY_MODE_CUBE : TEX_BOUNDARY_MODE_WRAP; + NVDR_CHECK(p.mipLevelLimit >= -1, "invalid max_mip_level"); + + // Check inputs. + NVDR_CHECK_DEVICE(tex); + NVDR_CHECK_CONTIGUOUS(tex); + NVDR_CHECK_F32(tex); + + // Populate parameters and sanity check tex shape. + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + } + p.texDepth = tex.size(0); + p.texHeight = tex.size(cube_mode ? 2 : 1); + p.texWidth = tex.size(cube_mode ? 3 : 2); + p.channels = tex.size(cube_mode ? 4 : 3); + + // Set texture pointer. + p.tex[0] = tex.data_ptr(); + + // Generate mip offsets and calculate total size. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + + // Allocate and set mip tensor. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor mip = torch::empty({mipTotal}, opts); + float* pmip = mip.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Build mip levels. 
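+    // Each pass downsamples the previous level by 2x per axis, writing into the shared mip tensor at
+    // the offsets computed above; the kernel variant is chosen by channel divisibility (4, then 2,
+    // else scalar). Illustrative example: a 256x256 texture yields mip levels of 128, 64, ..., 1
+    // pixels, i.e. p.mipLevelMax == 8.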
+ for (int i=1; i <= p.mipLevelMax; i++) + { + int2 ms = mipLevelSize(p, i); + int3 sz = make_int3(ms.x, ms.y, p.texDepth); + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT, sz.x, sz.y); + dim3 gridSize = getLaunchGridSize(blockSize, sz.x, sz.y, sz.z * (cube_mode ? 6 : 1)); + p.mipLevelOut = i; + + void* build_func_tbl[3] = { (void*)MipBuildKernel1, (void*)MipBuildKernel2, (void*)MipBuildKernel4 }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(build_func_tbl[channel_div_idx], gridSize, blockSize, args, 0, stream)); + } + + // Return the mip tensor in a wrapper. + TextureMipWrapper mip_wrapper; + mip_wrapper.mip = mip; + mip_wrapper.max_mip_level = max_mip_level; + mip_wrapper.texture_size = tex.sizes().vec(); + mip_wrapper.cube_mode = cube_mode; + return mip_wrapper; +} + +//------------------------------------------------------------------------ +// Forward op. + +torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + bool has_mip_stack = (mip_stack.size() > 0); + torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap. + int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level; + set_modes(p, filter_mode, boundary_mode, max_mip_level); + + // See if we have these tensors or not. + bool has_uv_da = uv_da.defined() && uv_da.nbytes(); + bool has_mip_level_bias = mip_level_bias.defined() && mip_level_bias.nbytes(); + + if (p.enableMip) + { + NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input"); + NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input"); + } + + // Check inputs. + NVDR_CHECK_DEVICE(tex, uv); + NVDR_CHECK_CONTIGUOUS(tex, uv); + NVDR_CHECK_F32(tex, uv); + if (p.enableMip) + { + if (has_mip_stack) + { + TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device"); + nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors"); + nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors"); + } + else + { + NVDR_CHECK_DEVICE(mip_w); + NVDR_CHECK_CONTIGUOUS(mip_w); + NVDR_CHECK_F32(mip_w); + } + if (has_uv_da) + { + NVDR_CHECK_DEVICE(uv_da); + NVDR_CHECK_CONTIGUOUS(uv_da); + NVDR_CHECK_F32(uv_da); + } + if (has_mip_level_bias) + { + NVDR_CHECK_DEVICE(mip_level_bias); + NVDR_CHECK_CONTIGUOUS(mip_level_bias); + NVDR_CHECK_F32(mip_level_bias); + } + } + + // Sanity checks and state setters. 
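+    // Cube-map mode is selected solely by boundary_mode == TEX_BOUNDARY_MODE_CUBE: it switches the
+    // expected texture layout to [minibatch, 6, height, width, channels] and uv to 3-component
+    // direction vectors instead of 2D coordinates.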
+ bool cube_mode = (boundary_mode == TEX_BOUNDARY_MODE_CUBE); + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 2, "uv must have shape [>0, >0, >0, 2]"); + p.texHeight = tex.size(1); + p.texWidth = tex.size(2); + p.channels = tex.size(3); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 3, "uv must have shape [>0, >0, >0, 3] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + p.texHeight = tex.size(2); + p.texWidth = tex.size(3); + p.channels = tex.size(4); + } + NVDR_CHECK(tex.size(0) == 1 || tex.size(0) == uv.size(0), "minibatch size mismatch between inputs tex, uv"); + NVDR_CHECK(p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), "texture size too large"); + p.n = uv.size(0); + p.imgHeight = uv.size(1); + p.imgWidth = uv.size(2); + p.texDepth = tex.size(0); + if (p.enableMip) + { + if (has_uv_da) + { + if (!cube_mode) + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 4, "uv_da must have shape [minibatch_size, height, width, 4]"); + else + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 6, "uv_da must have shape [minibatch_size, height, width, 6] in cube map mode"); + } + if (has_mip_level_bias) + NVDR_CHECK(mip_level_bias.sizes().size() == 3 && mip_level_bias.size(0) == p.n && mip_level_bias.size(1) == p.imgHeight && mip_level_bias.size(2) == p.imgWidth, "mip_level_bias must have shape [minibatch_size, height, width]"); + } + + // Get input pointers. + p.tex[0] = tex.data_ptr(); + p.uv = uv.data_ptr(); + p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr() : NULL; + p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr() : NULL; + + // Allocate output tensor. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({p.n, p.imgHeight, p.imgWidth, p.channels}, opts); + p.out = out.data_ptr(); + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + float* pmip = 0; + if (p.enableMip) + { + if (has_mip_stack) + { + // Custom mip stack supplied. Check that sizes match and assign. 
+ p.mipLevelMax = max_mip_level; + for (int i=1; i <= p.mipLevelMax; i++) + { + torch::Tensor& t = mip_stack[i-1]; + int2 sz = mipLevelSize(p, i); + if (!cube_mode) + NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in custom mip stack"); + else + NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack"); + if (sz.x == 1 && sz.y == 1) + NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack"); + p.tex[i] = t.data_ptr(); + } + } + else + { + // Generate mip offsets, check mipmap size, and set mip data pointer. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size"); + NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "wrapped mip tensor size mismatch"); + pmip = mip_w.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + NVDR_CHECK(!((uintptr_t)p.uv & 7), "uv input tensor not aligned to float2"); + if ((p.channels & 3) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4"); + } + if ((p.channels & 1) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.out & 7), "out output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2"); + } + if (!cube_mode) + NVDR_CHECK(!((uintptr_t)p.uvDA & 15), "uv_da input tensor not aligned to float4"); + else + NVDR_CHECK(!((uintptr_t)p.uvDA & 7), "uv_da input tensor not aligned to float2"); + + // Choose launch parameters for texture lookup kernel. + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + // Choose kernel based on filter mode, cube mode, bias-only mode, and datatype. 
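+    // The table is laid out as filter mode (4) x cube (2) x bias-only (2), with each slot holding the
+    // 1-, 2- and 4-channel specializations, so the index computed below is
+    // ((filterMode + 4*cube + 8*biasOnly) * 3 + channel_div_idx). For example, linear-mipmap-linear
+    // (3) on a cube map (+4) with uv_da supplied and a channel count divisible by 4 selects
+    // (3 + 4) * 3 + 2 = 23, i.e. TextureFwdKernelCubeLinearMipmapLinear4.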
+ void* func_tbl[TEX_MODE_COUNT * 2 * 2 * 3] = { + (void*)TextureFwdKernelNearest1, + (void*)TextureFwdKernelNearest2, + (void*)TextureFwdKernelNearest4, + (void*)TextureFwdKernelLinear1, + (void*)TextureFwdKernelLinear2, + (void*)TextureFwdKernelLinear4, + (void*)TextureFwdKernelLinearMipmapNearest1, + (void*)TextureFwdKernelLinearMipmapNearest2, + (void*)TextureFwdKernelLinearMipmapNearest4, + (void*)TextureFwdKernelLinearMipmapLinear1, + (void*)TextureFwdKernelLinearMipmapLinear2, + (void*)TextureFwdKernelLinearMipmapLinear4, + (void*)TextureFwdKernelCubeNearest1, + (void*)TextureFwdKernelCubeNearest2, + (void*)TextureFwdKernelCubeNearest4, + (void*)TextureFwdKernelCubeLinear1, + (void*)TextureFwdKernelCubeLinear2, + (void*)TextureFwdKernelCubeLinear4, + (void*)TextureFwdKernelCubeLinearMipmapNearest1, + (void*)TextureFwdKernelCubeLinearMipmapNearest2, + (void*)TextureFwdKernelCubeLinearMipmapNearest4, + (void*)TextureFwdKernelCubeLinearMipmapLinear1, + (void*)TextureFwdKernelCubeLinearMipmapLinear2, + (void*)TextureFwdKernelCubeLinearMipmapLinear4, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + (void*)TextureFwdKernelLinearMipmapNearestBO1, + (void*)TextureFwdKernelLinearMipmapNearestBO2, + (void*)TextureFwdKernelLinearMipmapNearestBO4, + (void*)TextureFwdKernelLinearMipmapLinearBO1, + (void*)TextureFwdKernelLinearMipmapLinearBO2, + (void*)TextureFwdKernelLinearMipmapLinearBO4, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO1, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO2, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO4, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO1, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO2, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO4, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; // Cube variant. + if (p.enableMip && !has_uv_da) + func_idx += TEX_MODE_COUNT * 2; // Bias-only variant. + func_idx = func_idx * 3 + channel_div_idx; // Choose vector size. + + // Launch kernel. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Return output tensor. + return out; +} + +// Version without mipmaps. +torch::Tensor texture_fwd(torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode) +{ + torch::Tensor empty_tensor; + std::vector empty_vector; + return texture_fwd_mip(tex, uv, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple > texture_grad_linear_mipmap_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + bool has_mip_stack = (mip_stack.size() > 0); + torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap. + int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level; + set_modes(p, filter_mode, boundary_mode, max_mip_level); + + // See if we have these tensors or not. 
+ bool has_uv_da = uv_da.defined() && uv_da.nbytes(); + bool has_mip_level_bias = mip_level_bias.defined() && mip_level_bias.nbytes(); + + if (p.enableMip) + { + NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input"); + NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input"); + } + + // Check inputs. + NVDR_CHECK_DEVICE(tex, uv); + NVDR_CHECK_CONTIGUOUS(tex, uv); + NVDR_CHECK_F32(tex, uv); + if (p.enableMip) + { + if (has_mip_stack) + { + TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device"); + nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors"); + nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors"); + } + else + { + NVDR_CHECK_DEVICE(mip_w); + NVDR_CHECK_CONTIGUOUS(mip_w); + NVDR_CHECK_F32(mip_w); + } + if (has_uv_da) + { + NVDR_CHECK_DEVICE(uv_da); + NVDR_CHECK_CONTIGUOUS(uv_da); + NVDR_CHECK_F32(uv_da); + } + if (has_mip_level_bias) + { + NVDR_CHECK_DEVICE(mip_level_bias); + NVDR_CHECK_CONTIGUOUS(mip_level_bias); + NVDR_CHECK_F32(mip_level_bias); + } + } + + // Sanity checks and state setters. + bool cube_mode = (boundary_mode == TEX_BOUNDARY_MODE_CUBE); + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 2, "uv must have shape [>0, >0, >0, 2]"); + p.texHeight = tex.size(1); + p.texWidth = tex.size(2); + p.channels = tex.size(3); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 3, "uv must have shape [>0, >0, >0, 3] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + p.texHeight = tex.size(2); + p.texWidth = tex.size(3); + p.channels = tex.size(4); + } + NVDR_CHECK(tex.size(0) == 1 || tex.size(0) == uv.size(0), "minibatch size mismatch between inputs tex, uv"); + NVDR_CHECK(p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), "texture size too large"); + p.n = uv.size(0); + p.imgHeight = uv.size(1); + p.imgWidth = uv.size(2); + p.texDepth = tex.size(0); + if (p.enableMip) + { + if (has_uv_da) + { + if (!cube_mode) + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 4, "uv_da must have shape [minibatch_size, height, width, 4]"); + else + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 6, "uv_da must have shape [minibatch_size, height, width, 6] in cube map mode"); + } + if (has_mip_level_bias) + NVDR_CHECK(mip_level_bias.sizes().size() == 3 && mip_level_bias.size(0) == p.n && mip_level_bias.size(1) == p.imgHeight && mip_level_bias.size(2) == p.imgWidth, "mip_level_bias must have shape [minibatch_size, height, width]"); + } + NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) == p.n && dy.size(1) == p.imgHeight && dy.size(2) == p.imgWidth && dy.size(3) == p.channels, "dy must have 
shape [minibatch_size, height, width, channels]"); + + // Get contiguous version of dy. + torch::Tensor dy_ = dy.contiguous(); + + // Get input pointers. + p.tex[0] = tex.data_ptr(); + p.uv = uv.data_ptr(); + p.dy = dy_.data_ptr(); + p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr() : NULL; + p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr() : NULL; + + // Allocate output tensor for tex gradient. + torch::Tensor grad_tex = torch::zeros_like(tex); + p.gradTex[0] = grad_tex.data_ptr(); + + // Allocate output tensor for uv gradient. + torch::Tensor grad_uv; + torch::Tensor grad_uv_da; + torch::Tensor grad_mip_level_bias; + if (p.filterMode != TEX_MODE_NEAREST) + { + grad_uv = torch::empty_like(uv); + p.gradUV = grad_uv.data_ptr(); + + // Gradients for things affecting mip level. + if (p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + // Allocate output tensor for uv_da gradient. + if (has_uv_da) + { + grad_uv_da = torch::empty_like(uv_da); + p.gradUVDA = grad_uv_da.data_ptr(); + } + + // Allocate output tensor for mip_level_bias gradient. + if (has_mip_level_bias) + { + grad_mip_level_bias = torch::empty_like(mip_level_bias); + p.gradMipLevelBias = grad_mip_level_bias.data_ptr(); + } + } + } + + // Choose kernel variants based on channel count. + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + torch::Tensor grad_mip; + std::vector grad_mip_stack; + float* pmip = 0; + float* pgradMip = 0; + if (p.enableMip) + { + if (has_mip_stack) + { + // Custom mip stack supplied. Check that sizes match, assign, construct gradient tensors. + p.mipLevelMax = max_mip_level; + for (int i=1; i <= p.mipLevelMax; i++) + { + torch::Tensor& t = mip_stack[i-1]; + int2 sz = mipLevelSize(p, i); + if (!cube_mode) + NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in mip stack"); + else + NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack"); + if (sz.x == 1 && sz.y == 1) + NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack"); + + torch::Tensor g = torch::zeros_like(t); + grad_mip_stack.push_back(g); + + p.tex[i] = t.data_ptr(); + p.gradTex[i] = g.data_ptr(); + } + } + else + { + // Generate mip offsets and get space for temporary mip gradients. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size"); + NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "mip tensor size mismatch"); + grad_mip = torch::zeros_like(mip_w); + pmip = (float*)mip_w.data_ptr(); + pgradMip = grad_mip.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + { + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients. + } + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. 
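+    // (For example, when p.channels is divisible by 4 the kernels use float4 loads/stores,
+    // so the tex/mip/gradient pointers checked below must be 16-byte aligned; a channel
+    // count divisible only by 2 falls back to float2 and needs 8-byte alignment.)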
+ if (!cube_mode) + { + NVDR_CHECK(!((uintptr_t)p.uv & 7), "uv input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradUV & 7), "grad_uv output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.uvDA & 15), "uv_da input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradUVDA & 15), "grad_uv_da output tensor not aligned to float4"); + } + else + { + NVDR_CHECK(!((uintptr_t)p.uvDA & 7), "uv_da input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradUVDA & 7), "grad_uv_da output tensor not aligned to float2"); + } + if ((p.channels & 3) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + { + NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 15), "grad_tex output tensor not aligned to float4"); + } + NVDR_CHECK(!((uintptr_t)p.dy & 15), "dy input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pgradMip & 15), "internal mip gradient tensor not aligned to float4"); + } + if ((p.channels & 1) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + { + NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 7), "grad_tex output tensor not aligned to float2"); + } + NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pgradMip & 7), "internal mip gradient tensor not aligned to float2"); + } + + // Choose launch parameters for main gradient kernel. + void* args[] = {&p}; + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + void* func_tbl[TEX_MODE_COUNT * 2 * 2] = { + (void*)TextureGradKernelNearest, + (void*)TextureGradKernelLinear, + (void*)TextureGradKernelLinearMipmapNearest, + (void*)TextureGradKernelLinearMipmapLinear, + (void*)TextureGradKernelCubeNearest, + (void*)TextureGradKernelCubeLinear, + (void*)TextureGradKernelCubeLinearMipmapNearest, + (void*)TextureGradKernelCubeLinearMipmapLinear, + NULL, + NULL, + (void*)TextureGradKernelLinearMipmapNearestBO, + (void*)TextureGradKernelLinearMipmapLinearBO, + NULL, + NULL, + (void*)TextureGradKernelCubeLinearMipmapNearestBO, + (void*)TextureGradKernelCubeLinearMipmapLinearBO, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; // Cube variant. + if (p.enableMip && !has_uv_da) + func_idx += TEX_MODE_COUNT * 2; // Bias-only variant. + + // Launch main gradient kernel. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Launch kernel to pull gradients from mip levels. Don't do this if mip stack was supplied - individual level gradients are already there. + if (p.enableMip && !has_mip_stack) + { + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT, p.texWidth, p.texHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.texWidth, p.texHeight, p.texDepth * (cube_mode ? 
6 : 1));
+        int sharedBytes = blockSize.x * blockSize.y * p.channels * sizeof(float);
+
+        void* mip_grad_func_tbl[3] = { (void*)MipGradKernel1, (void*)MipGradKernel2, (void*)MipGradKernel4 };
+        NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(mip_grad_func_tbl[channel_div_idx], gridSize, blockSize, args, sharedBytes, stream));
+    }
+
+    // Return output tensors.
+    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >(grad_tex, grad_uv, grad_uv_da, grad_mip_level_bias, grad_mip_stack);
+}
+
+// Version for nearest filter mode.
+torch::Tensor texture_grad_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode)
+{
+    torch::Tensor empty_tensor;
+    std::vector<torch::Tensor> empty_vector;
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode);
+    return std::get<0>(result);
+}
+
+// Version for linear filter mode.
+std::tuple<torch::Tensor, torch::Tensor> texture_grad_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode)
+{
+    torch::Tensor empty_tensor;
+    std::vector<torch::Tensor> empty_vector;
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode);
+    return std::tuple<torch::Tensor, torch::Tensor>(std::get<0>(result), std::get<1>(result));
+}
+
+// Version for linear-mipmap-nearest mode.
+std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > texture_grad_linear_mipmap_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode)
+{
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode, boundary_mode);
+    return std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >(std::get<0>(result), std::get<1>(result), std::get<4>(result));
+}
+
+//------------------------------------------------------------------------
diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h b/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e389582e65d5df91f4273b8959969fa6dbe1b37
--- /dev/null
+++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "torch_common.inl"
+
+//------------------------------------------------------------------------
+// Python GL state wrapper.
+
+class RasterizeGLState;
+class RasterizeGLStateWrapper
+{
+public:
+    RasterizeGLStateWrapper (bool enableDB, bool automatic, int cudaDeviceIdx);
+    ~RasterizeGLStateWrapper (void);
+
+    void setContext (void);
+    void releaseContext (void);
+
+    RasterizeGLState* pState;
+    bool automatic;
+    int cudaDeviceIdx;
+};
+
+//------------------------------------------------------------------------
+// Python CudaRaster state wrapper.
+ +namespace CR { class CudaRaster; } +class RasterizeCRStateWrapper +{ +public: + RasterizeCRStateWrapper (int cudaDeviceIdx); + ~RasterizeCRStateWrapper (void); + + CR::CudaRaster* cr; + int cudaDeviceIdx; +}; + +//------------------------------------------------------------------------ +// Mipmap wrapper to prevent intrusion from Python side. + +class TextureMipWrapper +{ +public: + torch::Tensor mip; + int max_mip_level; + std::vector texture_size; // For error checking. + bool cube_mode; // For error checking. +}; + + +//------------------------------------------------------------------------ +// Antialias topology hash wrapper to prevent intrusion from Python side. + +class TopologyHashWrapper +{ +public: + torch::Tensor ev_hash; +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/run_sample.sh b/extensions/nvdiffrast/run_sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..3758865c3359c12da203fb34360f8caa2824e8ef --- /dev/null +++ b/extensions/nvdiffrast/run_sample.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +function print_help { + echo "Usage: `basename $0` [--build-container] " + echo "" + echo "Option --build-container will build the Docker container based on" + echo "docker/Dockerfile and tag the image with gltorch:latest." + echo "" + echo "Example: `basename $0` samples/torch/envphong.py" +} + +build_container=0 +sample="" +while [[ "$#" -gt 0 ]]; do + case $1 in + --build-container) build_container=1;; + -h|--help) print_help; exit 0 ;; + --*) echo "Unknown parameter passed: $1"; exit 1 ;; + *) sample="$1"; shift; break; + esac + shift +done + +rest=$@ + +# Build the docker container +if [ "$build_container" = "1" ]; then + docker build --tag gltorch:latest -f docker/Dockerfile . +fi + +if [ ! -f "$sample" ]; then + echo + echo "No python sample given or file '$sample' not found. Exiting." + exit 1 +fi + +image="gltorch:latest" + +echo "Using container image: $image" +echo "Running command: $sample $rest" + +# Run a sample with docker +docker run --rm -it --gpus all --user $(id -u):$(id -g) \ + -v `pwd`:/app --workdir /app -e TORCH_EXTENSIONS_DIR=/app/tmp $image python3 $sample $rest diff --git a/extensions/nvdiffrast/setup copy.py b/extensions/nvdiffrast/setup copy.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f9dede9649583be8fdd2ba6aa6c3aab184ed54 --- /dev/null +++ b/extensions/nvdiffrast/setup copy.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +import nvdiffrast +import setuptools +import os + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="nvdiffrast", + version=nvdiffrast.__version__, + author="Samuli Laine", + author_email="slaine@nvidia.com", + description="nvdiffrast - modular primitives for high-performance differentiable rendering", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/NVlabs/nvdiffrast", + packages=setuptools.find_packages(), + package_data={ + 'nvdiffrast': [ + 'common/*.h', + 'common/*.inl', + 'common/*.cu', + 'common/*.cpp', + 'common/cudaraster/*.hpp', + 'common/cudaraster/impl/*.cpp', + 'common/cudaraster/impl/*.hpp', + 'common/cudaraster/impl/*.inl', + 'common/cudaraster/impl/*.cu', + 'lib/*.h', + 'torch/*.h', + 'torch/*.inl', + 'torch/*.cpp', + 'tensorflow/*.cu', + ] + (['lib/*.lib'] if os.name == 'nt' else []) + }, + include_package_data=True, + install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/extensions/nvdiffrast/setup.py b/extensions/nvdiffrast/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..507cb06f18fbc948e81fd7791f87489d8c35347b --- /dev/null +++ b/extensions/nvdiffrast/setup.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
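+
+# Note: unlike the stock nvdiffrast setup script (kept above as "setup copy.py"), this
+# variant also registers the CUDA/C++ sources as a prebuilt nvdiffrast.torch._C extension
+# via CUDAExtension/BuildExtension, so the torch plugin is compiled at install time instead
+# of being JIT-compiled on first use.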
+ +import nvdiffrast +import setuptools +import os +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="nvdiffrast", + version=nvdiffrast.__version__, + author="Samuli Laine", + author_email="slaine@nvidia.com", + description="nvdiffrast - modular primitives for high-performance differentiable rendering", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/NVlabs/nvdiffrast", + packages=setuptools.find_packages(), + # package_data={ + # 'nvdiffrast': [ + # 'common/*.h', + # 'common/*.inl', + # 'common/*.cu', + # 'common/*.cpp', + # 'common/cudaraster/*.hpp', + # 'common/cudaraster/impl/*.cpp', + # 'common/cudaraster/impl/*.hpp', + # 'common/cudaraster/impl/*.inl', + # 'common/cudaraster/impl/*.cu', + # 'lib/*.h', + # 'torch/*.h', + # 'torch/*.inl', + # 'torch/*.cpp', + # 'tensorflow/*.cu', + # ] + (['lib/*.lib'] if os.name == 'nt' else []) + # }, + # include_package_data=True, + ext_modules=[ + CUDAExtension( + name="nvdiffrast.torch._C", + sources=[ + 'nvdiffrast/common/cudaraster/impl/Buffer.cpp', + 'nvdiffrast/common/cudaraster/impl/CudaRaster.cpp', + 'nvdiffrast/common/cudaraster/impl/RasterImpl_.cu', + 'nvdiffrast/common/cudaraster/impl/RasterImpl.cpp', + 'nvdiffrast/common/common.cpp', + 'nvdiffrast/common/rasterize.cu', + 'nvdiffrast/common/interpolate.cu', + 'nvdiffrast/common/texture_.cu', + 'nvdiffrast/common/texture.cpp', + 'nvdiffrast/common/antialias.cu', + 'nvdiffrast/torch/torch_bindings.cpp', + 'nvdiffrast/torch/torch_rasterize.cpp', + 'nvdiffrast/torch/torch_interpolate.cpp', + 'nvdiffrast/torch/torch_texture.cpp', + 'nvdiffrast/torch/torch_antialias.cpp', + ], + extra_compile_args={ + 'cxx': ['-DNVDR_TORCH'], + 'nvcc': ['-DNVDR_TORCH', '-lineinfo'], + }, + ) + ], + cmdclass={ + 'build_ext': BuildExtension + }, + install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/gitattributes b/gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..f0248c86648a63bb0d6b267bbcf068b83b128902 --- /dev/null +++ b/gitattributes @@ -0,0 +1,61 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text 
+*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_5.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_6.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_7.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/flower_8.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/monkey_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_4.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_5.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_6.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_7.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/paopao_8.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_1.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_2.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_3.png filter=lfs diff=lfs merge=lfs -text +assets/example_multi_image/SpongeBob_4.png filter=lfs diff=lfs merge=lfs -text +wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text +wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8fd76ec57f63c18e00fb3ccf51bfef441063f999 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 + +torch==2.4.0 +torchvision==0.19.0 +pillow==10.4.0 +imageio==2.36.1 +imageio-ffmpeg==0.5.1 +tqdm==4.67.1 +easydict==1.13 +opencv-python-headless==4.10.0.84 +scipy==1.14.1 +rembg==2.0.60 +onnxruntime==1.20.1 +trimesh==4.5.3 +xatlas==0.0.9 +pyvista==0.44.2 +pymeshfix==0.17.0 +igraph==0.11.8 +git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +xformers==0.0.27.post2 +spconv-cu120==2.3.6 +transformers==4.46.3 +gradio_litmodel3d==0.0.1 +pydantic==2.10.6 +einops==0.8.1 +huggingface_hub==0.25.0 +lpips==0.1.4 +open3d==0.19.0 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 
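+# Prebuilt cp310 / linux_x86_64 wheels for the rendering extensions, hosted on the
+# JeffreyXiang/TRELLIS Space (the same wheels are tracked under wheels/ per the gitattributes above):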
+https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true +https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true \ No newline at end of file diff --git a/trellis/__init__.py b/trellis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd --- /dev/null +++ b/trellis/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis/__pycache__/__init__.cpython-310.pyc b/trellis/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23e2fd5458d5ea6efe46fe6ef3d794c9f51d5559 Binary files /dev/null and b/trellis/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c302e12b2c530c3cc0e844be83551bcc5814e4a --- /dev/null +++ b/trellis/models/__init__.py @@ -0,0 +1,88 @@ +import importlib + +__attributes = { + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + 'SLatEncoder': 'structured_latent_vae', + 'SLatGaussianDecoder': 'structured_latent_vae', + 'SLatMeshDecoder': 'structured_latent_vae', + 'SLatFlowModel': 'structured_latent_flow', + 'ModulatedMultiViewCond': 'sparse_structure_flow', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file), strict=False) + + return model + +def save_finetuned_model(model, output_dir: str): + """ + Save a fine-tuned model's state_dict as safetensors with a timestamp. + + Args: + model: The model to be saved. + output_dir: The directory where the model's state_dict will be saved. 
+ The file will be saved as f'{output_dir}/{timestamp}.safetensors'. + """ + from safetensors.torch import save_file + import os + from datetime import datetime + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_file(model.state_dict(), f"{output_dir}/{timestamp}.safetensors") + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel, ModulatedMultiViewCond + from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatMeshDecoder + from .structured_latent_flow import SLatFlowModel diff --git a/trellis/models/__pycache__/__init__.cpython-310.pyc b/trellis/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d5d11874115c1ee31886d315263c1495d44f5b Binary files /dev/null and b/trellis/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..202e2449ca45640ed8caf16011e835fe16555f6f Binary files /dev/null and b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc b/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bfef55b5aa9a1c02f3e2f354d5f4e2149a49d76 Binary files /dev/null and b/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d0e095cd6f99f10575fe0c83645a74c6aa58e3c Binary files /dev/null and b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc differ diff --git a/trellis/models/heads/camera_head.py b/trellis/models/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..176d76fb5baeb3a42fa3675a1d1fb14010f2904d --- /dev/null +++ b/trellis/models/heads/camera_head.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vggt.layers import Mlp +from vggt.layers.block import Block +from vggt.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. 
+ ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. 
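+            # The encoding has target_dim = 9 entries: 3 translation, 4 quaternion and 2
+            # FoV/focal terms (see activate_pose). Each iteration predicts a residual that
+            # is accumulated onto the running estimate below.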
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/trellis/models/heads/dpt_head.py b/trellis/models/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..390b6dab2ee2e9e5ed274b92a0d418a91261a21f --- /dev/null +++ b/trellis/models/heads/dpt_head.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. 
+ """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
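+
+        Note:
+            Chunking only splits the frame dimension S; every chunk re-uses the same
+            aggregated token tensors, so the concatenated result matches processing all
+            frames at once while bounding peak memory.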
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.view(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
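+        # (The fused map lives on the patch grid of size H/patch_size x W/patch_size;
+        # scaling by patch_size and dividing by down_ratio restores the requested output
+        # resolution.)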
+ out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, 
features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
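+
+    If the requested output would contain more than the INT_MAX guard below (~1.6e9 elements
+    counted over batch, channels and output pixels), the input is split along the batch
+    dimension, interpolated chunk by chunk, and concatenated back together.
+
+    Example (illustrative):
+        >>> y = custom_interpolate(torch.rand(2, 3, 64, 64), size=(128, 128))
+        >>> y.shape
+        torch.Size([2, 3, 128, 128])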
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) \ No newline at end of file diff --git a/trellis/models/heads/head_act.py b/trellis/models/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..5b33511408cac0dbbe75825b382b47e58ed42588 --- /dev/null +++ b/trellis/models/heads/head_act.py @@ -0,0 +1,126 @@ + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. 
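+
+    With the default "norm_exp" activation the first C-1 channels are treated as a direction
+    with a log-encoded length, i.e. pts3d = (xyz / ||xyz||) * expm1(||xyz||); the last channel
+    is turned into a confidence map by conf_activation (e.g. "expp1" gives 1 + exp(conf)).
+
+    Example (illustrative):
+        >>> pts3d, conf = activate_head(torch.randn(1, 4, 32, 32))
+        >>> pts3d.shape, conf.shape
+        (torch.Size([1, 32, 32, 3]), torch.Size([1, 32, 32]))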
+ + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/trellis/models/heads/utils.py b/trellis/models/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..26a22ef30e26bb340f34dcd083333aaaf205c2d9 --- /dev/null +++ b/trellis/models/heads/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. 
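+
+    The frequencies follow omega_k = 1 / omega_0 ** (2k / embed_dim) for k = 0 .. embed_dim/2 - 1;
+    the output concatenates sin(pos * omega) and cos(pos * omega) along the last dimension.
+
+    Example (illustrative):
+        >>> emb = make_sincos_pos_embed(64, torch.arange(10.0))
+        >>> emb.shape
+        torch.Size([10, 64])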
+ """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid \ No newline at end of file diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..80982f1571d804ac3587b4fea8c6e4858325129f --- /dev/null +++ b/trellis/models/sparse_structure_flow.py @@ -0,0 +1,303 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16 +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock, ModulatedTransformerCrossBlock_woT +from ..modules.spatial import patchify, unpatchify + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. 
+ dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + if use_fp16: + self.dtype = torch.float16 + elif use_bf16: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = self.pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. 
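+
+        Note (added for clarity): only `self.blocks` is converted; the timestep embedder,
+        the precomputed positional embedding, and the input/output linear layers keep
+        float32 parameters, and `forward` casts activations to `self.dtype` only around
+        the transformer blocks.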
+ """ + self.use_fp16 = True + self.use_bf16 = False + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to bfloat16. + """ + self.use_fp16 = False + self.use_bf16 = True + self.dtype = torch.bfloat16 + self.blocks.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.use_bf16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + if isinstance(cond, list): + for i in range(len(cond)): + cond_tmp = cond[i].type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond_tmp) + else: + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h + +class ModulatedMultiViewCond(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
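+
+    As used here (see `forward` below), the block distills multi-view features into a
+    fixed set of learned condition tokens: `multiview_cond_tokens` is repeated per batch
+    and refined by four `ModulatedTransformerCrossBlock_woT` blocks, each cross-attending
+    to the aggregated tokens from intermediate layers 4, 11, 17 and 23 concatenated with
+    the reshaped image conditioning features.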
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + num_init_tokens: int = 4096, + dtype: Optional[torch.dtype] = torch.float32, + use_fp16: bool = False, + ): + super().__init__() + self.cond_blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock_woT( + channels, + ctx_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + use_checkpoint=use_checkpoint, + use_rope=use_rope, + share_mod=share_mod, + qk_rms_norm=qk_rms_norm, + qk_rms_norm_cross=qk_rms_norm_cross, + ) + for _ in range(4) + ]) + self.use_fp16 = use_fp16 + if use_fp16: + self.dtype = torch.float16 + else: + self.dtype = dtype + self.multiview_cond_tokens = nn.Parameter(torch.randn(1, num_init_tokens, channels).to(dtype)) + nn.init.normal_(self.multiview_cond_tokens, std=1e-6) + self.intermediate_layer_idx = [4, 11, 17, 23] + if use_fp16: + self.convert_to_fp16() + + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.cond_blocks.apply(convert_module_to_f16) + self.multiview_cond_tokens = nn.Parameter(self.multiview_cond_tokens.data.to(self.dtype)) + def forward(self, aggregated_tokens_list: List, image_cond: torch.Tensor): + + b = aggregated_tokens_list[0].shape[0] + patch_start_idx = 5 + idx = 0 + cond = self.multiview_cond_tokens.repeat(b, 1, 1) + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + # x = x.reshape(b, -1, 2048) + torch.cat([image_cond.reshape(b, -1, 1024), image_cond.reshape(b, -1, 1024)],dim=-1) + x = torch.cat([x.reshape(b, -1, 2048), image_cond.reshape(b, -1, 1024)],dim=-1).to(self.dtype) + cond = self.cond_blocks[idx](cond, x) + idx = idx + 1 + return cond \ No newline at end of file diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0e1f0da68a699e4d44fa337a40b77e1ea4021eba --- /dev/null +++ b/trellis/models/sparse_structure_vae.py @@ -0,0 +1,394 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. 
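+
+    `norm_type` is either "group" (GroupNorm32 with 32 groups) or "layer"
+    (ChannelLayerNorm32); any other value raises a ValueError.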
+ """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + +class ResBlock2d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv2d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv2d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool2d(x, 2) + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = 
out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_bf16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 + if use_fp16: + self.dtype = torch.float16 + elif use_bf16: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_bf16 = True + self.dtype = torch.bfloat16 + self.blocks.apply(convert_module_to_bf16) + self.middle_block.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
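+
+        Note (added for clarity): the torso here is `self.blocks` and `self.middle_block`;
+        `forward` runs `input_layer` and `out_layer` in the input dtype and only casts the
+        intermediate features to `self.dtype`.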
+ """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_bf16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 + if use_fp16: + self.dtype = torch.float16 + elif use_bf16: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to bfloat16. + """ + self.use_bf16 = True + self.dtype = torch.bfloat16 + self.blocks.apply(convert_module_to_bf16) + self.middle_block.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.use_fp16 = False + self.use_bf16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..e6468aa4f65415b301eb5d44b8babb7c5e78d3ce --- /dev/null +++ b/trellis/models/structured_latent_flow.py @@ -0,0 +1,342 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + +class SparseResBlock3dwoT(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, 
self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + if use_fp16: + self.dtype = torch.float16 + elif use_bf16: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" + assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) + self.input_blocks = nn.ModuleList([]) + for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): + self.input_blocks.extend([ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + downsample=True, + ) + ) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + 
qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_blocks = nn.ModuleList([]) + for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend([ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + else: + self.convert_to_fp32() + + + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.use_bf16 = False + self.dtype = torch.float16 + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to bfloat16. + """ + self.use_fp16 = False + self.use_bf16 = True + self.dtype = torch.bfloat16 + self.input_blocks.apply(convert_module_to_bf16) + self.blocks.apply(convert_module_to_bf16) + self.out_blocks.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.use_bf16 = False + self.dtype = torch.float32 + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + + if isinstance(cond, list): + for i in range(len(cond)): + cond_tmp = cond[i].type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond_tmp) + else: + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip 
in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5778511aeb082e11bf86543e6dbbe9f0dacb9a8 --- /dev/null +++ b/trellis/models/structured_latent_vae/__init__.py @@ -0,0 +1,3 @@ +from .encoder import SLatEncoder +from .decoder_gs import SLatGaussianDecoder +from .decoder_mesh import SLatMeshDecoder diff --git a/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..522a8c1d39a85fe580c4e8cff169c44fe71a7a6f Binary files /dev/null and b/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94c7ce79d4149bec801650c5acda79572c96ba50 Binary files /dev/null and b/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff121ce35266506a24172b4d8eb2925484732f1c Binary files /dev/null and b/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3015c09fa91b24e0dbc90b056834349d4a84aa42 Binary files /dev/null and b/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cca13828f447ca049d502fad46f48b314ec5501 Binary files /dev/null and b/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d53b7945f35804d2669d02d1d0e799a3b840424d --- /dev/null +++ b/trellis/models/structured_latent_vae/base.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. 
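+
+    Yields one (attn_mode, window_size, shift_sequence, shift_window, serialize_mode)
+    tuple per block; the shift and serialize entries alternate with the block index so
+    that consecutive blocks attend over shifted or re-serialized windows.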
+ """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + if use_fp16: + self.dtype = torch.float16 + elif use_bf16: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.use_bf16 = False + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to bfloat16. + """ + self.use_fp16 = False + self.use_bf16 = True + self.dtype = torch.bfloat16 + self.blocks.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.use_fp16 = False + self.use_bf16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..64af7ee785dc0877624a20ccea9bd160d313dccf --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,126 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_bf16=use_bf16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer('offset_perturbation', perturbation) + + def _calc_layout(self) -> None: + self.layout = { + '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': 
self.rep_config['num_gaussians'] * 4}, + '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], + scaling_bias = self.rep_config['scaling_bias'], + opacity_bias = self.rep_config['opacity_bias'], + scaling_activation = self.rep_config['scaling_activation'] + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == '_xyz': + offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) + offset = offset * self.rep_config['lr'][k] + if self.rep_config['perturb_offset']: + offset = offset + self.offset_perturbation + offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + feats = feats * self.rep_config['lr'][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..958dc3a73b4aa367162c850707bb393163b79e5a --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,197 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. 
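+
+    Each application doubles the resolution of the sparse grid
+    (`out_resolution = resolution * 2`) via `sp.SparseSubdivide` before the
+    output convolutions are applied.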
+ """ + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32 + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. + """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + mesh_extractor: str = "fc", + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_bf16=use_bf16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + if mesh_extractor == "mc": + try: + from ...representations.mesh import SparseFeatures2MCMesh + self.mesh_extractor = SparseFeatures2MCMesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + except ImportError: + raise ValueError("SparseFeatures2MCMesh is not available. 
Please install the 'mc2mesh' extra.") + elif mesh_extractor == "fc": + self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + else: + raise ValueError(f"Invalid mesh extractor {mesh_extractor}") + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4 + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8 + ) + ]) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.use_bf16 = False + self.dtype = torch.float16 + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_bf16(self) -> None: + """ + Convert the torso of the model to bfloat16. + """ + self.use_fp16 = False + self.use_bf16 = True + self.dtype = torch.bfloat16 + super().convert_to_bf16() + self.upsample.apply(convert_module_to_bf16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.use_bf16 = False + self.dtype = torch.float32 + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. 
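+                Each batch element is meshed independently by the configured
+                `mesh_extractor` ("fc" or "mc").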
+ + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6eb07f66cec833ba0ea228b77f3645431ed0d1 --- /dev/null +++ b/trellis/models/structured_latent_vae/encoder.py @@ -0,0 +1,76 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_bf16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_bf16=use_bf16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + elif use_bf16: + self.convert_to_bf16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z diff --git a/trellis/modules/__pycache__/norm.cpython-310.pyc b/trellis/modules/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f7287dfe3cf23e0e81b792837b4313725c31b72 Binary files /dev/null and b/trellis/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/spatial.cpython-310.pyc b/trellis/modules/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b154fb6e5f4de8c8d060947a4b9b0d2d801f20ca Binary files /dev/null and b/trellis/modules/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/utils.cpython-310.pyc b/trellis/modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd9a6303f0af67ed8395d0da57694dc282a4d8a6 
Binary files /dev/null and b/trellis/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e594957f8e1d0f4f5dda00528b54333c5fa671c --- /dev/null +++ b/trellis/modules/attention/__init__.py @@ -0,0 +1,38 @@ +from typing import * + +# BACKEND = 'xformers' +BACKEND = 'flash_attn' + +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a07011334ff04709b1a0f6760439376bd74535 Binary files /dev/null and b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..278ec68f3b43303ee95413eb56ab5886795a30f7 Binary files /dev/null and b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..423bcbe7586ebee9be62c3234e757d746d007bf2 Binary files /dev/null and b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ebf6380a78906d4c6e969c63223fb7b398e5a7 --- /dev/null +++ b/trellis/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. 
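+
+    Returns:
+        torch.Tensor: A [N, L, H, C] tensor of attention outputs.
+
+    Note (added for clarity): with the flash_attn backend this packed layout is passed
+    directly to `flash_attn_qkvpacked_func`; the other backends unbind it into separate
+    q, k, v tensors first.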
+ """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py new file mode 100644 index 
0000000000000000000000000000000000000000..dbe6235c27134f0477e48d3e12de3068c6a500ef --- /dev/null +++ b/trellis/modules/attention/modules.py @@ -0,0 +1,146 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads 
= num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f --- /dev/null +++ b/trellis/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. 
+ """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726756c16dcfe0f04de0d2ea5bdce499fa220160 --- /dev/null +++ b/trellis/modules/sparse/__init__.py @@ -0,0 +1,102 @@ +from typing import * + +BACKEND = 'spconv' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ed807f72966edfb0f5326cb2471f319509edc26e Binary files /dev/null and b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43d040474e123ecb5dd8fddd2f7197eb0356e235 Binary files /dev/null and b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a96a59de3c903bc055104d73b2d38b01cef55ca Binary files /dev/null and b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731a3b564306a0cac9022a1c9ad7e94a9cbba817 Binary files /dev/null and b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc b/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4573decac983f6f5b0eea585e17c9a61980958d Binary files /dev/null and b/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23c93b4939b088b91bd8c505addebc20b1f7409 Binary files /dev/null and b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32b3c2c837c613e41755ac4c85f9ed057a6f5bfb --- /dev/null +++ b/trellis/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f593d4c1e17d1ad6c00e90b05dd019a90ed590da Binary files /dev/null and b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e865e9c460bfabd40f858962360334dbf1bd25 Binary files /dev/null and b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db6e497005c2cb7d7912fd7ba1325929e826c7ed Binary files /dev/null and b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc 
b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47fdd114317f1e73d6ec0be5bccd8d6bc9961a26 Binary files /dev/null and b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bf32dc2364b42eec3090693994a0115b1f29d15 Binary files /dev/null and b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40 --- /dev/null +++ b/trellis/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. 
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = 
[k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2fe782b0947700e308e9ec0325e7e91c84e3c2 --- /dev/null +++ b/trellis/modules/sparse/attention/modules.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. 
import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, 
k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26 --- /dev/null +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. 
+ """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..cd642c5252e29a3a5e59fad7ed3880b7b00bcf9a --- /dev/null +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,135 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. 
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5 --- /dev/null +++ b/trellis/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. 
+ - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - 
seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + 
self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): 
+ idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be anything you want to cache. + Registration and retrieval of the cache are based on the current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a per-batch tensor to the rows of a sparse tensor along the batch dimension. + + Args: + input (SparseTensor): Sparse tensor providing the batch layout. + other (torch.Tensor): Per-batch tensor to broadcast. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a per-batch tensor to a sparse tensor along the batch dimension, then perform an elementwise operation. + + Args: + input (SparseTensor): Sparse tensor to operate on. + other (torch.Tensor): Per-batch tensor to broadcast. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + dim (int): Dimension along which to concatenate. Defaults to 0. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind.
+ """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de36b35899cbfcdbed8f40b69712e80e2c8e449e --- /dev/null +++ b/trellis/modules/sparse/conv/__init__.py @@ -0,0 +1,21 @@ +from .. import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c28690836369c624e802ccd4261d1003088521c Binary files /dev/null and b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de10562e71f2d8868f0f6d6f46b5418700c25a89 Binary files /dev/null and b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..524bcd4a845b2d6bd090a5f74bc8859978727528 --- /dev/null +++ b/trellis/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . 
import SPCONV_ALGO + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 
0000000000000000000000000000000000000000..1d612582d4b31f90aca3c00b693bbbc2550dc62c --- /dev/null +++ b/trellis/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + + diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd --- /dev/null +++ b/trellis/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..f200098dd82011a3aeee1688b9eb17018fa78295 --- /dev/null +++ b/trellis/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . 
import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d --- /dev/null +++ b/trellis/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7121473f335b307e2f7ea5f05c964d3aec0440 --- /dev/null +++ b/trellis/modules/sparse/spatial.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. 
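+    Coordinates are floor-divided by `factor`; voxels that collapse onto the same
+    cell have their features merged, and the original coordinates, layout and
+    index map are cached so that a matching SparseUpsample can restore them.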
+ Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats, + reduce='mean' + ) + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. 
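+    Note: unlike SparseUpsample, no downsample cache or explicit factor is used
+    here; every occupied voxel is subdivided into 2**DIM children, doubling the
+    resolution along each axis.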
+ """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6160f08743d3a9c228bab36d007892740cafabb Binary files /dev/null and b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62021de1246d666327d9d3f5dfc3bb4f68ddd168 Binary files /dev/null and b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..107fc8f1ed75174d48e8f9641a8691e980d621c7 Binary files /dev/null and b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0 --- /dev/null +++ b/trellis/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, 
use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8416559f39acbed9e5996e9891c97f95c80c8f --- /dev/null +++ b/trellis/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
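+    The conditioning vector `mod` is projected (or pre-chunked when
+    `share_mod=True`) into six per-channel vectors: shift/scale/gate for the
+    self-attention branch and for the MLP branch; the cross-attention branch is
+    left unmodulated.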
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/trellis/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. 
+ + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6614cf67a229fc59aae86938bdbd6000f8aa0def Binary files /dev/null and b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf7d94a799eaf030be44fcf8136207b6eda905db Binary files /dev/null and b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f842476dad14630b99b4a11ee852805887ebba9 Binary files /dev/null and b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..c37eb7ed92f4aacfc9e974a63b247589d95977da --- /dev/null +++ b/trellis/modules/transformer/blocks.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. 
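+    Uses fixed sinusoidal frequencies 1 / 10000**(k / freq_dim), shared across
+    the `in_channels` coordinate axes; the per-axis sin/cos embeddings are
+    concatenated and zero-padded up to `channels`.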
+ """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..c85d8d551b2bf9e45fb86bd98539967a4cec6665 --- /dev/null +++ b/trellis/modules/transformer/modulated.py @@ -0,0 +1,280 @@ +from typing import * +import torch +import torch.nn as nn +import torch.utils +import torch.utils.checkpoint +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + # h = torch.utils.checkpoint.checkpoint(self.self_attn, h) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + +class ModulatedPosedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + +class ModulatedTransformerCrossBlock_woT(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) \ No newline at end of file diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79aad7b08b1cb00db06d60d8215c7156926e80ac --- /dev/null +++ b/trellis/modules/utils.py @@ -0,0 +1,75 @@ +import torch.nn as nn +from ..modules import sparse as sp +import torch + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + 
nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +BF16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + +def convert_module_to_bf16(l): + """ + Convert primitive modules to bfloat16. + """ + if isinstance(l, BF16_MODULES): + for p in l.parameters(): + p.data = p.data.type(torch.bfloat16) + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81cf31f44711ad94d5b0c6310d5752ef9547d1c2 --- /dev/null +++ b/trellis/pipelines/__init__.py @@ -0,0 +1,23 @@ +from . import samplers +from .trellis_image_to_3d import TrellisImageTo3DPipeline, TrellisVGGTTo3DPipeline + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. 
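+
+    The function reads `pipeline.json` (locally, or via
+    `huggingface_hub.hf_hub_download` for hub paths) and dispatches to the
+    pipeline class named in its `name` field.
+
+    Example (the path is illustrative):
+        >>> pipeline = from_pretrained("path/to/pretrained_pipeline")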
+ """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) diff --git a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b505f96842e09d26f38d9441e724e5f774165762 Binary files /dev/null and b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/base.cpython-310.pyc b/trellis/pipelines/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b36097d7794c20939ba974eae2d34b65a1a37893 Binary files /dev/null and b/trellis/pipelines/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6993693d05d395fb94b38a1b1cee135661f8c54 Binary files /dev/null and b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/trellis_image_to_ss.cpython-310.pyc b/trellis/pipelines/__pycache__/trellis_image_to_ss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc009fce20fde99a75dc3b633620c7a9bfffc63a Binary files /dev/null and b/trellis/pipelines/__pycache__/trellis_image_to_ss.cpython-310.pyc differ diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..30a2dabf298606c20830917c811f9da9dd7f4f51 --- /dev/null +++ b/trellis/pipelines/base.py @@ -0,0 +1,73 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + self.t_scheduler = 'uniform' + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + if 'slat_flow_model' in self.models: + self.slat_flow_model = self.models['slat_flow_model'] + if 'sparse_structure_flow_model' in self.models: + self.sparse_structure_flow_model = self.models['sparse_structure_flow_model'] + if 'sparse_structure_vggt_cond' in self.models: + self.sparse_structure_vggt_cond = self.models['sparse_structure_vggt_cond'] + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. 
+ """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = { + k: models.from_pretrained(f"{path}/{v}") + for k, v in args['models'].items() + } + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) + diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7958d5491b03526bf6ffa1529a21f3b91e1aa0 --- /dev/null +++ b/trellis/pipelines/samplers/__init__.py @@ -0,0 +1,3 @@ +from .base import Sampler +from .flow_euler_old import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler +from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler, LatentMatchSampler, LatentMatchGuidanceIntervalSampler \ No newline at end of file diff --git a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..050e5e93c51f8c0220ae144762d2ced1b9a9af9a Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1250f59798af8e9277b342bae5f15132d7f795af Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ddc31e608eee86ba6b390eec45f586bc3718397 Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cdea0a51d12ef248d67b240a74a9a8c6d699c59 Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1ef50294691e3fab01f0399bb6cfbafa73a3832 Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc differ diff --git 
a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4bffb043c3992d2fbeeebb815738e8de8807e8e Binary files /dev/null and b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf --- /dev/null +++ b/trellis/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..5701b25f5d7a2197612eb256f8ee13e8c489da1f --- /dev/null +++ b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,12 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..56ce2794291fab862003f5b5eba9f146e7377719 --- /dev/null +++ b/trellis/pipelines/samplers/flow_euler.py @@ -0,0 +1,798 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin +import math +from trellis.modules.spatial import patchify, unpatchify +from trellis.utils import render_utils, postprocessing_utils +from trellis.utils import loss_utils +import trellis.modules.sparse as sp +import torch.nn.functional as F + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. 
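+
+    The sampler assumes the flow-matching interpolant
+        x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps
+    and integrates the predicted velocity with explicit Euler steps,
+    x_{t_prev} = x_t - (t - t_prev) * v, from t = 1 down to t = 0.
+
+    Example (sketch; `model`, `noise` and `cond` are assumed to come from the
+    surrounding pipeline, and the sigma_min value is illustrative):
+        >>> sampler = FlowEulerSampler(sigma_min=1e-5)
+        >>> out = sampler.sample(model, noise, cond, steps=25, verbose=False)
+        >>> latent = out.samples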
+ """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_x_t(self, x_0, t, eps): + assert x_0.shape == eps.shape + return (1-t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * eps + # return (1-t) * x_0 + t * eps + self.sigma_min * (1-t) * eps + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _xstart_to_v(self, x_0, x_t, t): + assert x_0.shape == x_t.shape + return (x_t - (1 - self.sigma_min) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + param = kwargs.pop("parameterization", "v") + if param == "v": + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + elif param == "x0": + pred_x_0 = self._inference_model(model, x_t, t, cond, **kwargs) + pred_v = self._xstart_to_v(x_0=pred_x_0, x_t=x_t, t=t) + return pred_x_0, None, pred_v + + def _get_model_gt(self, x_0, t, noise): + gt_x_t = self._xstart_to_x_t(x_0, t, noise) + gt_v = self._xstart_to_v(x_0, gt_x_t, t) + return gt_x_t, gt_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + def sample_once_opt( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. 
+ """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + def sample_once_opt_delta_v( + self, + model, + slat_decoder_gs, + slat_decoder_mesh, + dreamsim_model, + learning_rate, + input_images, + extrinsics, + intrinsics, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + torch.cuda.empty_cache() + with torch.no_grad(): + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_v_opt_feat = torch.nn.Parameter(pred_v.feats.detach().clone()) + optimizer = torch.optim.Adam([pred_v_opt_feat], betas=(0.5, 0.9), lr=learning_rate) + pred_v_opt = sp.SparseTensor(feats=pred_v_opt_feat, coords=pred_v.coords) + total_steps = 5 + input_images = F.interpolate(input_images, size=(259, 259), mode='bilinear', align_corners=False) + with tqdm(total=total_steps, disable=True, desc='Appearance (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + pred_x_0, _ = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v_opt) + pred_gs = slat_decoder_gs(pred_x_0) + # pred_mesh = slat_decoder_mesh(pred_x_0) + rend_gs = render_utils.render_frames(pred_gs[0], extrinsics, intrinsics, {'resolution': 259, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color'] + # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)['color'] + rend_gs = torch.stack(rend_gs, dim=0) + loss_gs = loss_utils.l1_loss(rend_gs, input_images) + (1 - loss_utils.ssim(rend_gs, input_images)) + loss_utils.lpips(rend_gs, input_images) + dreamsim_model(rend_gs, input_images).mean() + # loss_gs = (1 - loss_utils.ssim(rend_gs, input_images)) + loss_utils.lpips(rend_gs, input_images) + dreamsim_model(rend_gs, input_images).mean() + # loss_mesh = loss_utils.l1_loss(rend_mesh, input_images) + 0.2 * (1 - loss_utils.ssim(rend_mesh, input_images)) + 0.2 * loss_utils.lpips(rend_mesh, input_images) + loss = loss_gs + 0.2 * loss_utils.l1_loss(pred_v_opt_feat, pred_v.feats) + loss.backward() + optimizer.step() + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + + pred_x_prev = x_t - (t - t_prev) * pred_v_opt.detach() + torch.cuda.empty_cache() + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + def sample_opt( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. 
+ - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once_opt(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + def sample_opt_delta_v( + self, + model, + slat_decoder_gs, + slat_decoder_mesh, + dreamsim_model, + apperance_learning_rate, + start_t, + input_images, + extrinsics, + intrinsics, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + # def cosine_anealing(step, total_steps, start_lr, end_lr): + # return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + for i, (t, t_prev) in enumerate(tqdm(t_pairs, desc="Sampling", disable=not verbose)): + if t > start_t: + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + else: + # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5) + learning_rate = apperance_learning_rate + out = self.sample_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. 
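+
+        Example (illustrative sketch only; ``my_flow_model`` and ``my_cond``
+        are placeholders, not objects defined in this repository):
+
+            import torch
+            sampler = FlowEulerSampler(sigma_min=0.0)
+            noise = torch.randn(1, 8, 16, 16, 16)   # placeholder latent shape
+            out = sampler.sample(my_flow_model, noise, cond=my_cond,
+                                 steps=25, rescale_t=3.0)
+            samples = out.samples        # final sample
+            trajectory = out.pred_x_t    # per-step states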
+ """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + + +class LatentMatchSampler(FlowEulerSampler): + """ + Generate samples from a Bridge Matching model using Euler sampling. + This sampler is designed for Latent Bridge Matching (LBM), where + the target (x_1) for training is assumed to be sampled from a Gaussian distribution, + and the source (x_0) for inference is also typically a Gaussian noise. + + Args: + sigma_bridge: The sigma parameter for the Bridge Matching stochastic interpolant. + This controls the amount of stochasticity in the SDE (LBM paper Eq 1). + + """ + def __init__( + self, + sigma_bridge: float = 0.1, + **kwargs + ): + # Call parent constructor with a dummy sigma_min. + # sigma_min is specific to Flow Matching's interpolant, which we override. + super().__init__(sigma_min=0.0, **kwargs) + self.sigma_bridge = sigma_bridge + + # Override _xstart_to_x_t for Bridge Matching's stochastic interpolant + # This method is used to generate gt_x_t for training. + def _xstart_to_x_t(self, x_0: torch.Tensor, t: float, noise: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: + """ + Calculates x_t according to the Bridge Matching stochastic interpolant. + This function is used during training to generate noisy samples x_t from + paired x_0 and x_1 samples. The 'x_1' argument is crucial for Bridge Matching. + + Args: + x_0: The source latent tensor (e.g., from a data distribution or Gaussian). + t: The current timestep (float between 0 and 1). + eps: A random noise tensor (epsilon). + x_1: The target latent tensor. Required for Bridge Matching. + + Returns: + The interpolated latent tensor x_t. + """ + # LBM interpolant formula: x_t = (1-t)x_0 + t*x_1 + sigma_bridge*sqrt(t*(1-t))*epsilon + return (1 - t) * x_0 + t * x_1 + self.sigma_bridge * math.sqrt(t * (1 - t)) * noise + + # This method is used to calculate gt_v for training. + def _xstart_to_v(self, x_0: torch.Tensor, x_t: torch.Tensor, t: float, x_1: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Calculates the ground truth drift (v) that the model should predict for Bridge Matching. + This function is used in the training objective to define the target for the model. + + Args: + x_0: The source latent tensor. + x_t: The interpolated latent tensor at time t. + t: The current timestep (float between 0 and 1). + x_1: The target latent tensor. Required for Bridge Matching. + + Returns: + The target drift tensor v. + """ + if x_1 is None: + # This branch should ideally not be hit during _get_model_gt for LBM. + raise ValueError("For Bridge Matching's target drift calculation, x_1 (target latent) must be provided.") + assert x_t.shape == x_1.shape, "x_t and x_1 shapes must match." + # LBM drift formula: v = (x_1 - x_t) / (1 - t) + # Add a small epsilon to (1-t) to prevent division by zero if t is exactly 1. + epsilon_t = 1e-5 # Small epsilon for numerical stability + return (x_t - x_0) / (t + epsilon_t) + + # Override _get_model_gt to provide ground truth for Bridge Matching training. 
+ # In this simplified case, x_1 is sampled from a Gaussian distribution. + def _get_model_gt(self, x_0: torch.Tensor, t: float, x_1: torch.Tensor): + """ + Calculates ground truth x_t and v_target for Bridge Matching training purposes. + In this simplified case, x_1 is sampled from a Gaussian distribution. + + Args: + x_0: The source latent tensor (e.g., from a data distribution, or another Gaussian). + t: The current timestep. + noise: A random noise tensor (epsilon). + + Returns: + A tuple (gt_x_t, gt_v). + """ + # Sample x_1 from a Gaussian distribution with the same shape as x_0 + # This simulates the target distribution being Gaussian. + if isinstance(x_0, sp.SparseTensor): + noise = sp.SparseTensor( + feats=torch.randn_like(x_0.feats).to(x_0.feats.device), + coords=x_0.coords, + ) + else: + noise = torch.randn_like(x_0).to(x_0.device) + # For Bridge Matching, _xstart_to_x_t needs x_1 + gt_x_t = self._xstart_to_x_t(x_0, t, noise, x_1=x_1) + gt_v = self._xstart_to_v(x_0, gt_x_t, t, x_1=x_1) + return gt_x_t, gt_v + + # Override sample_once to include the stochastic term for SDE integration. + @torch.no_grad() + def sample_once( + self, + model, + x_t: torch.Tensor, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ) -> edict: + """ + Performs a single Euler step to sample x_{t_next} from x_t for Bridge Matching. + The model is assumed to predict the drift 'v' as per LBM's formulation. + + Args: + model: The model to sample from (should be trained for Bridge Matching). + x_t: The [N x C x ...] tensor of current latent inputs at time t. + t: The current timestep. + t_next: The next timestep in the forward integration sequence (t+dt). + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + An edict containing: + - 'pred_x_prev': The estimated latent tensor at t_next. + - 'pred_x_0': A prediction of x_0 (may be None as direct derivation is complex in LBM). + - 'pred_eps': A prediction of eps (may be None). + """ + # Get model's prediction of the drift (v) + # We use the parent's _get_model_prediction. Its _v_to_xstart_eps uses sigma_min, + # which is a dummy value here. For LBM, pred_v is the main output. + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + + # Calculate time step difference (dt) + dt = t - t_prev # This is the forward step size + + # Sample noise for the stochastic part of the SDE + # The SDE for LBM is dx_t = v(x_t, t) dt + sigma dB_t + # For Euler, dB_t approx sqrt(dt) * Z, where Z ~ N(0,I) + # noise_increment = sp.SparseTensor( + # feats=torch.randn_like(x_t.feats).to(x_t.feats.device), + # coords=x_t.coords, + # ) + # if isinstance(x_t, sp.SparseTensor): + # noise_increment = sp.SparseTensor( + # feats=torch.randn_like(x_t.feats).to(x_t.feats.device), + # coords=x_t.coords, + # ) + # else: + # noise_increment = torch.randn_like(x_t).to(x_t.device) + # noise_increment = noise_increment * self.sigma_bridge * torch.sqrt(torch.tensor(max(0.0, dt), device=x_t.device)) + # pred_x_prev = x_t - (t - t_prev) * pred_v - noise_increment + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + +class FlowMatchingSampler(FlowEulerSampler): + """ + Implementation of Flow Matching using Euler sampling. + Inherits from FlowEulerSampler and modifies key methods for flow matching. 
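+
+    Notes:
+        - With the default ``sigma_min=0`` the implied interpolant is
+          ``x_t = (1 - t) * x_0 + t * eps``; ``_v_to_xstart_eps`` recovers
+          ``eps = x_t + (1 - t) * v`` and then converts back to ``x_0``.
+        - ``sample`` uses ``x_1`` only for its shape: it draws Gaussian noise
+          of the same shape and integrates t from 1 to 0 with Euler steps.
+        - ``_get_model_gt`` is currently a stub and does not provide training
+          targets.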
+ """ + def __init__(self, sigma_min: float = 0.0): + super().__init__(sigma_min=sigma_min) + + def _compute_velocity(self, x_t: torch.Tensor, x_0: torch.Tensor, t: float) -> torch.Tensor: + return ((1 - self.sigma_min) * x_t - x_0 ) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _get_model_gt(self, x_1: torch.Tensor, t: float, x_0: torch.Tensor = None): + # TODO: Implement this method + pass + # """ + # Get ground truth for training. + # Args: + # x_1: Target endpoint + # t: Time point + # noise: Initial noise to use as x_0 + # """ + # x_t = (1 - t) * x_0 + t * x_1 + # v = self._compute_velocity(x_t, x_0, t) + # eps = x_t + (1 - t) * v # Convert velocity to noise + # return x_t, eps, v + + def _v_to_xstart_eps(self, x_t: torch.Tensor, t: float, v: torch.Tensor): + """Convert velocity to x_0 and noise predictions""" + eps = x_t + (1 - t) * v + x_0 = self._eps_to_xstart(x_t, t, eps) + return x_0, eps + + @torch.no_grad() + def sample( + self, + model, + x_1: torch.Tensor, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Generate samples by following the flow from noise to x_1. + Args: + model: The model to sample from + x_1: Target endpoint + cond: Conditional information + steps: Number of sampling steps + rescale_t: Time rescaling factor + verbose: Whether to show progress bar + **kwargs: Additional model arguments + Returns: + Dictionary containing sampling trajectory and predictions + """ + # Initialize with noise as x_0 + noise = torch.randn_like(x_1) + current_x = noise + + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list(zip(t_seq[:-1], t_seq[1:])) + + ret = edict({ + "samples": None, + "pred_x_t": [], + "pred_x_0": [] + }) + + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, current_x, t, t_prev, cond, **kwargs) + current_x = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + + ret.samples = current_x + return ret + + def sample_once( + self, + model, + x_t: torch.Tensor, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ) -> Dict: + """ + Sample x_{t-1} from the model using Euler method. + Args: + model: The model to sample from + x_t: Current state + t: Current time + t_prev: Next time step + cond: Conditional information + **kwargs: Additional model arguments + Returns: + Dictionary containing predictions + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t + (t_prev - t) * pred_v + return edict({ + "pred_x_prev": pred_x_prev, + "pred_x_0": pred_x_0, + "pred_eps": pred_eps + }) + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. 
+ verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + + def sample_opt( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample_opt(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + + def sample_opt_delta_v( + self, + model, + slat_decoder_gs, + slat_decoder_mesh, + dreamsim_model, + apperance_learning_rate, + start_t, + input_images, + extrinsics, + intrinsics, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. 
+ cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + + +class LatentMatchGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, LatentMatchSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/pipelines/samplers/flow_euler_old.py b/trellis/pipelines/samplers/flow_euler_old.py new file mode 100644 index 0000000000000000000000000000000000000000..64a459ab7623a20e1214980d4fff656b5581be2d --- /dev/null +++ b/trellis/pipelines/samplers/flow_euler_old.py @@ -0,0 +1,471 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin +import trellis.modules.sparse as sp +from trellis.modules.spatial import patchify, unpatchify + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. 
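+
+    The forward interpolant assumed by this sampler is
+    ``x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps``.
+    By default the model is queried in the "v" parameterization, and each
+    Euler step advances the state as ``x_{t_prev} = x_t - (t - t_prev) * v``.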
+ """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_x_t(self, x_0, t, eps): + assert x_0.shape == eps.shape + return (1-t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * eps + + def _xstart_to_x_t(self, x_0, t, eps): + assert x_0.shape == eps.shape + return (1-t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * eps + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _xstart_to_v(self, x_0, x_t, t): + assert x_0.shape == x_t.shape + return (x_t - (1 - self.sigma_min) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t.to(torch.float32), t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + param = kwargs.pop("parameterization", "v") + if param == "v": + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + elif param == "x0": + pred_x_0 = self._inference_model(model, x_t, t, cond, **kwargs) + pred_v = self._xstart_to_v(x_0=pred_x_0, x_t=x_t, t=t) + return pred_x_0, None, pred_v + + def _get_model_gt(self, x_0, t, noise): + gt_x_t = self._xstart_to_x_t(x_0, t, noise) + gt_v = self._xstart_to_v(x_0, gt_x_t, t) + return gt_x_t, gt_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + @torch.no_grad() + def sample_once_featurevolume( + self, + model, + cond_model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. 
+ """ + if isinstance(cond, sp.SparseTensor): + t_tmp = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=x_t.dtype) + t_embed = model.t_embedder(t_tmp).to(x_t.dtype) + for block in cond_model: + cond = block(cond, t_embed) + if model.pe_mode == "ape": + cond = cond + model.pos_embedder(cond.coords[:, 1:]).to(x_t.dtype) + if 'neg_cond' in kwargs.keys(): + neg_cond = kwargs['neg_cond'] + for block in cond_model: + neg_cond = block(neg_cond, t_embed) + if model.pe_mode == "ape": + neg_cond = neg_cond + model.pos_embedder(neg_cond.coords[:, 1:]).to(x_t.dtype) + kwargs['neg_cond'] = neg_cond + else: + for block in cond_model: + cond = block(cond) + cond = patchify(cond, model.patch_size) + cond = cond.view(*cond.shape[:2], -1).permute(0, 2, 1).contiguous() + cond = cond + model.pos_emb[None].type(model.dtype) + if 'neg_cond' in kwargs.keys(): + neg_cond = kwargs['neg_cond'] + for block in cond_model: + neg_cond = block(neg_cond) + neg_cond = patchify(neg_cond, model.patch_size) + neg_cond = neg_cond.view(*neg_cond.shape[:2], -1).permute(0, 2, 1).contiguous() + neg_cond = neg_cond + model.pos_emb[None].type(model.dtype) + kwargs['neg_cond'] = neg_cond + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + @torch.no_grad() + def sample_featurevolume( + self, + model, + cond_model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once_featurevolume(model, cond_model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. 
+ """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowMatchingSampler(FlowEulerSampler): + """ + Implementation of Flow Matching using Euler sampling. + Inherits from FlowEulerSampler and modifies key methods for flow matching. + """ + def __init__(self, sigma_min: float = 0.0): + super().__init__(sigma_min=sigma_min) + + def _compute_velocity(self, x_t: torch.Tensor, x_0: torch.Tensor, t: float) -> torch.Tensor: + return ((1 - self.sigma_min) * x_t - x_0 ) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _get_model_gt(self, x_1: torch.Tensor, t: float, x_0: torch.Tensor = None): + # TODO: Implement this method + pass + # """ + # Get ground truth for training. + # Args: + # x_1: Target endpoint + # t: Time point + # noise: Initial noise to use as x_0 + # """ + # x_t = (1 - t) * x_0 + t * x_1 + # v = self._compute_velocity(x_t, x_0, t) + # eps = x_t + (1 - t) * v # Convert velocity to noise + # return x_t, eps, v + + def _v_to_xstart_eps(self, x_t: torch.Tensor, t: float, v: torch.Tensor): + """Convert velocity to x_0 and noise predictions""" + eps = x_t + (1 - t) * v + x_0 = self._eps_to_xstart(x_t, t, eps) + return x_0, eps + + @torch.no_grad() + def sample( + self, + model, + x_1: torch.Tensor, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Generate samples by following the flow from noise to x_1. + Args: + model: The model to sample from + x_1: Target endpoint + cond: Conditional information + steps: Number of sampling steps + rescale_t: Time rescaling factor + verbose: Whether to show progress bar + **kwargs: Additional model arguments + Returns: + Dictionary containing sampling trajectory and predictions + """ + # Initialize with noise as x_0 + noise = torch.randn_like(x_1) + current_x = noise + + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list(zip(t_seq[:-1], t_seq[1:])) + + ret = edict({ + "samples": None, + "pred_x_t": [], + "pred_x_0": [] + }) + + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, current_x, t, t_prev, cond, **kwargs) + current_x = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + + ret.samples = current_x + return ret + + def sample_once( + self, + model, + x_t: torch.Tensor, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ) -> Dict: + """ + Sample x_{t-1} from the model using Euler method. 
+ Args: + model: The model to sample from + x_t: Current state + t: Current time + t_prev: Next time step + cond: Conditional information + **kwargs: Additional model arguments + Returns: + Dictionary containing predictions + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t + (t_prev - t) * pred_v + return edict({ + "pred_x_prev": pred_x_prev, + "pred_x_0": pred_x_0, + "pred_eps": pred_eps + }) + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + + @torch.no_grad() + def sample_featurevolume( + self, + model, + cond_model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. 
+ cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample_featurevolume(model, cond_model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..7074a4d5fea20a8f799416aa6571faca4f9eea06 --- /dev/null +++ b/trellis/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,15 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + return super()._inference_model(model, x_t, t, cond, **kwargs) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2f41879a80237e767a1122c7e638018fb913b4 --- /dev/null +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -0,0 +1,843 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torchvision import transforms +from PIL import Image +import trimesh +import os +import random +import open3d as o3d +import trellis.modules.sparse as sp +from trellis.models.sparse_structure_vae import * +from contextlib import contextmanager + +import sys +sys.path.append("wheels/vggt") +from wheels.vggt.vggt.models.vggt import VGGT +from typing import * +from scipy.spatial.transform import Rotation + +def export_point_cloud(xyz, color): + # Convert tensors to numpy arrays if needed + if isinstance(xyz, torch.Tensor): + xyz = xyz.detach().cpu().numpy() + if isinstance(color, torch.Tensor): + color = color.detach().cpu().numpy() + + color = (color * 255).astype(np.uint8) + + # Create point cloud using trimesh + point_cloud = trimesh.PointCloud(vertices=xyz, colors=color) + + return point_cloud + +def normalize_trimesh(mesh): + # Calculate the mesh centroid and bounding box extents + centroid = mesh.centroid + # Determine the scale based on the largest extent to fit into unit cube + # Normalizing: Center and scale the vertices + mesh.vertices -= centroid + + extents = mesh.extents + scale = max(extents) + mesh.vertices /= scale + + return mesh + +def random_sample_rotation(rotation_factor: float = 1.0) -> np.ndarray: + # angle_z, angle_y, angle_x + euler = np.random.rand(3) * np.pi * 2 / rotation_factor # (0, 2 * pi / rotation_range) + rotation = Rotation.from_euler('zyx', euler).as_matrix() + return rotation + +from scipy.ndimage import binary_dilation +def voxelize_trimesh(mesh, resolution=(64, 64, 64), stride=4): + """ + Voxelize a given trimesh object with the specified resolution, incorporating 4x anti-aliasing. + First voxelizes at a 4x resolution and then downsamples to the target resolution. 
+ + Args: + mesh (trimesh.Trimesh): The input trimesh object to be voxelized. + resolution (tuple): The voxel grid resolution as (x, y, z). Default is (64, 64, 64). + + Returns: + np.ndarray: A boolean numpy array representing the voxel grid where True indicates + the presence of the mesh in that voxel and False otherwise. + """ + target_density = max(resolution) + target_edge_length = 1.0 / target_density + max_edge_for_subdivision = target_edge_length / 2 + + # Calculate the higher resolution for 4x anti-aliasing + anti_aliasing_density = target_density * stride + anti_aliasing_edge_length = 1.0 / anti_aliasing_density + anti_aliasing_max_edge_for_subdivision = anti_aliasing_edge_length / 2 + + # Get the vertices and faces of the mesh + vertices = mesh.vertices + faces = mesh.faces + + # Subdivide the mesh for the higher resolution voxelization + try: + new_vertices, new_faces = trimesh.remesh.subdivide_to_size( + vertices, faces, anti_aliasing_max_edge_for_subdivision + ) + subdivided_mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces) + except Exception as e: + print(f"Unexpected error during mesh subdivision for anti-aliasing: {e}") + raise + + # Voxelize the subdivided mesh at the higher resolution + try: + high_res_voxel_grid = subdivided_mesh.voxelized( + pitch=anti_aliasing_edge_length, method="binvox", exact=True + ) + except: + print("Voxelization using 'binvox' method failed for anti-aliasing") + high_res_voxel_grid = subdivided_mesh.voxelized(pitch=anti_aliasing_edge_length) + print("Falling back to default voxelization method for anti-aliasing.") + high_res_boolean_array = high_res_voxel_grid.matrix.astype(bool) + + x_stride, y_stride, z_stride = [int(anti_aliasing_density / target_density)] * 3 + downsampled_shape = ( + high_res_boolean_array.shape[0] // x_stride, + high_res_boolean_array.shape[1] // y_stride, + high_res_boolean_array.shape[2] // z_stride + ) + downsampled_array = np.zeros(downsampled_shape, dtype=bool) + + # Use NumPy's strided tricks to efficiently access sub-cubes for downsampling + shape = (downsampled_shape[0], downsampled_shape[1], downsampled_shape[2], x_stride, y_stride, z_stride) + strides = (x_stride * high_res_boolean_array.strides[0], + y_stride * high_res_boolean_array.strides[1], + z_stride * high_res_boolean_array.strides[2], + high_res_boolean_array.strides[0], + high_res_boolean_array.strides[1], + high_res_boolean_array.strides[2]) + sub_cubes = np.lib.stride_tricks.as_strided(high_res_boolean_array, shape=shape, strides=strides) + downsampled_array = np.any(sub_cubes, axis=(3, 4, 5)) + + return downsampled_array + +def get_occupied_coordinates(voxel_grid): + # Find the indices of occupied voxels + occupied_indices = np.argwhere(voxel_grid) + + coords = torch.tensor(occupied_indices, dtype=torch.int8) # Use float for scaling operations + + # Add a leading dimension for batch size or any additional data associations + coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords + 1], dim=1) + + # Move to GPU if required + coords = coords.to('cuda:0') + + return coords + +from .base import Pipeline +from . import samplers +from ..modules import sparse as sp + + +class TrellisImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + slat_sampler (samplers.Sampler): The sampler for the structured latent. 
+ slat_normalization (dict): The normalization parameters for the structured latent. + image_cond_model (str): The name of the image conditioning model. + """ + default_image_resolution = 518 + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + slat_sampler: samplers.Sampler = None, + slat_normalization: dict = None, + image_cond_model: str = None, + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.slat_sampler = slat_sampler + self.sparse_structure_sampler_params = {} + self.slat_sampler_params = {} + self.slat_normalization = slat_normalization + self._init_image_cond_model(image_cond_model) + + @staticmethod + def from_pretrained(path: str) -> "TrellisImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) + new_pipeline = TrellisImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) + new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + + new_pipeline.slat_normalization = args['slat_normalization'] + + new_pipeline._init_image_cond_model(args['image_cond_model']) + + return new_pipeline + + def _init_image_cond_model(self, name: str): + """ + Initialize the image conditioning model. + """ + try: + dinov2_model = torch.hub.load(os.path.join(torch.hub.get_dir(), 'facebookresearch_dinov2_main'), name, source='local',pretrained=True) + except: + dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) + dinov2_model.eval() + self.models['image_cond_model'] = dinov2_model + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + self.image_cond_model_transform = transform + + def preprocess_image(self, input: Image.Image, resolution=518, no_background=True, recenter=True) -> Image.Image: + """ + Preprocess the input image using BiRefNet for background removal. + Includes padding to maintain aspect ratio when resizing to 518x518. 
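+
+        Args:
+            input (Image.Image): Input image; a meaningful existing alpha
+                channel is reused instead of running BiRefNet.
+            resolution (int): Side length of the output image (default 518).
+            no_background (bool): If True, RGB channels are masked by the
+                alpha channel so the background is removed.
+            recenter (bool): If True, center the crop on the foreground
+                bounding box; otherwise crop around the image center.
+
+        Returns:
+            Image.Image: An RGBA image of size ``resolution x resolution``
+            (falls back to the RGB input if no foreground is detected).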
+ """ + # if has alpha channel, use it directly + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, -1] + if not np.all(alpha == 255): + has_alpha = True + + if has_alpha: + output = input + else: + input = input.convert('RGB') + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + + # Load BiRefNet model if not already loaded + if getattr(self, 'birefnet_model', None) is None: + self._lazy_load_birefnet() + + # Get mask using BiRefNet + mask = self._get_birefnet_mask(input) + + # Convert input to RGBA and apply mask + input_rgba = input.convert('RGBA') + input_array = np.array(input_rgba) + input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel + output = Image.fromarray(input_array) + + # Process the output image + output_np = np.array(output) + alpha = output_np[:, :, 3] + + # Find bounding box of non-transparent pixels + bbox = np.argwhere(alpha > 0.8 * 255) + if len(bbox) == 0: # Handle case where no foreground is detected + return input.convert('RGB') + + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.2) + # size = int(size * 1.1) + height, width = alpha.shape + if not recenter: + center = [width / 2, height / 2] + size = max(bbox[2] - bbox[0], + bbox[3] - bbox[1], + (bbox[2] - width / 2) * 2, + (width / 2 - bbox[0]) * 2, + (height / 2 - bbox[1]) * 2, + (bbox[3] - height / 2) * 2) + + + # Calculate and apply crop bbox + if not no_background: + if height > width: + center[0] = width / 2 + if center[1] < width / 2: + center[1] = width / 2 + elif center[1] > height - width / 2: + center[1] = height - width / 2 + else: + center[1] = height / 2 + if center[0] < height / 2: + center[0] = height / 2 + elif center[0] > width - height / 2: + center[0] = width - height / 2 + + size = min(center[0], center[1], input.width - center[0], input.height - center[1], size) * 2 + + bbox = ( + int(center[0] - size // 2), + int(center[1] - size // 2), + int(center[0] + size // 2), + int(center[1] + size // 2) + ) + + # Ensure bbox is within image bounds + bbox = ( + max(0, bbox[0]), + max(0, bbox[1]), + min(output.width, bbox[2]), + min(output.height, bbox[3]) + ) + + output = output.crop(bbox) + + # Add padding to maintain aspect ratio + width, height = output.size + if width > height: + new_height = width + padding = (width - height) // 2 + padded_output = Image.new('RGBA', (width, new_height), (0, 0, 0, 0)) + padded_output.paste(output, (0, padding)) + else: + new_width = height + padding = (height - width) // 2 + padded_output = Image.new('RGBA', (new_width, height), (0, 0, 0, 0)) + padded_output.paste(output, (padding, 0)) + + # Resize padded image to target size + # padded_output = padded_output.resize((resolution, resolution), Image.Resampling.LANCZOS) + padded_output = torch.from_numpy(np.array(padded_output).astype(np.float32)) / 255 + padded_output = F.interpolate(padded_output.unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False)[0].permute(1, 2, 0) + + # Final processing + output = padded_output.cpu().numpy() + if no_background: + output = np.dstack(( + output[:, :, :3] * (output[:, :, 3:4] > 0.8), # RGB channels premultiplied by alpha + output[:, :, 3] # Original alpha channel + )) + output = 
Image.fromarray((output * 255).astype(np.uint8), mode='RGBA') + + return output + + def _lazy_load_birefnet(self): + """Lazy loading of the BiRefNet model""" + from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, AutoModelForImageSegmentation + self.birefnet_model = AutoModelForImageSegmentation.from_pretrained( + 'weights/BiRefNet', + trust_remote_code=True + ).to(self.device) + self.birefnet_model.eval() + + def _get_birefnet_mask(self, image: Image.Image) -> np.ndarray: + """Get object mask using BiRefNet""" + image_size = (1024, 1024) + transform_image = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + input_images = transform_image(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + preds = self.birefnet_model(input_images)[-1].sigmoid().cpu() + + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image.size) + mask_np = np.array(mask) + + return (mask_np > 128).astype(np.uint8) + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, list[Image.Image]], w_layernorm=True) -> torch.Tensor: + """ + Encode the image. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image to encode + + Returns: + torch.Tensor: The encoded features. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False) + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.image_cond_model_transform(image).to(self.device) + features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] + if w_layernorm: + features = F.layer_norm(features, features.shape[-1:]) + return features + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + cond = self.encode_image(image) + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + num_samples: int = 1, + sampler_params: dict = {}, + noise: torch.Tensor = None, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. 
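+            noise (torch.Tensor): Optional pre-drawn latent noise; if None, a
+                standard normal tensor matching the flow model resolution is
+                drawn internally.
+
+        Returns:
+            torch.Tensor: Integer coordinates of shape [M, 4] for the occupied
+            voxels of the decoded occupancy grid (batch index followed by the
+            three spatial indices).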
+ """ + # Sample occupancy latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + if noise is None: + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + # Decode occupancy latent + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + + return coords + def encode_slat( + self, + slat: sp.SparseTensor, + ): + ret = {} + slat = self.models['slat_encoder'](slat, sample_posterior=False) + ret['slat'] = slat + return ret + + @torch.no_grad() + def decode_slat( + self, + slat: sp.SparseTensor, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + ) -> dict: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + dict: The decoded structured latent. + """ + ret = {} + ret['slat'] = slat + if 'gaussian' in formats: + ret['gaussian'] = self.models['slat_decoder_gs'](slat) + if 'mesh' in formats: + ret['mesh'] = self.models['slat_decoder_mesh'](slat) + if 'radiance_field' in formats: + ret['radiance_field'] = self.models['slat_decoder_rf'](slat) + return ret + + def sample_slat( + self, + cond: dict, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + flow_model = self.models['slat_flow_model'] + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.slat_sampler_params, **sampler_params} + slat = self.slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + return slat + + def get_input(self, batch_data): + std = torch.tensor(self.slat_normalization['std'])[None].to(self.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(self.device) + + images = batch_data['source_image'] + cond = self.encode_image(images) + if random.random() > 0.5: + cond = torch.zeros_like(cond) + + target_feats = batch_data['target_feats'] + target_coords = batch_data['target_coords'] + targets = sp.SparseTensor(target_feats, target_coords).to(self.device) + targets = (targets - mean) / std + + noise = sp.SparseTensor( + feats=torch.randn_like(target_feats).to(self.device), + coords=target_coords.to(self.device), + ) + return targets, cond, noise + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + return self.slat_flow_model(x, t, cond) + + @contextmanager + def inject_sampler_multi_image( + self, + sampler_name: str, + num_images: int, + num_steps: int, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + """ + Inject a sampler with multiple images as condition. + + Args: + sampler_name (str): The name of the sampler to inject. 
+ num_images (int): The number of images to condition on. + num_steps (int): The number of steps to run the sampler for. + """ + sampler = getattr(self, sampler_name) + setattr(sampler, f'_old_inference_model', sampler._inference_model) + + if mode == 'stochastic': + if num_images > num_steps: + print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. " + "This may lead to performance degradation.\033[0m") + + cond_indices = (np.arange(num_steps) % num_images).tolist() + def _new_inference_model(self, model, x_t, t, cond, **kwargs): + cond_idx = cond_indices.pop(0) + cond_i = cond[cond_idx:cond_idx+1] + return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) + + elif mode =='multidiffusion': + from .samplers import FlowEulerSampler + def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + return pred + + else: + raise ValueError(f"Unsupported mode: {mode}") + + sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) + + yield + + sampler._inference_model = sampler._old_inference_model + delattr(sampler, f'_old_inference_model') + + @torch.no_grad() + def run_multi_image( + self, + images: List[Image.Image], + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + """ + Run the pipeline with multiple images as condition + + Args: + images (List[Image.Image]): The multi-view images of the assets + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. 
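+            seed (int): Random seed for reproducibility.
+            formats (List[str]): Output representations to decode
+                ('mesh', 'gaussian', 'radiance_field').
+            mode (str): How the per-image conditions are combined during
+                sampling: 'stochastic' cycles through them across steps,
+                'multidiffusion' averages their predictions at every step.
+
+        Returns:
+            dict: Outputs of ``decode_slat`` for the requested formats.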
+ """ + if preprocess_image: + images = [self.preprocess_image(image) for image in images] + cond = self.get_cond(images) + cond['neg_cond'] = cond['neg_cond'][:1] + torch.manual_seed(seed) + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') + with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params, noise) + slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode): + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) + + @torch.no_grad() + def run( + self, + image: Image.Image, + ref_image: Image.Image = None, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh'], + preprocess_image: bool = True, + init_mesh: trimesh.Trimesh = None, + coords: torch.Tensor = None, + normalize_init_mesh: bool = False, + init_resolution: int = 62, + init_stride: int = 4 + ) -> dict: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. 
+ """ + if preprocess_image: + image = self.preprocess_image(image) + if ref_image is not None: + cond = self.encode_image([image, ref_image]) + neg_cond = torch.zeros_like(cond[0:1]) + sparse_cond = slat_cond = { + 'cond': 0.5 * cond[0:1] + 0.5 * cond[1:2], + 'neg_cond': neg_cond, + } + else: + sparse_cond = slat_cond = self.get_cond([image]) + + torch.manual_seed(seed) + if init_mesh is not None: + mesh_o3d = o3d.geometry.TriangleMesh() + mesh_o3d.vertices = o3d.utility.Vector3dVector(init_mesh.vertices) + mesh_o3d.triangles = o3d.utility.Vector3iVector(init_mesh.faces) + if normalize_init_mesh: + vertices = np.asarray(mesh_o3d.vertices) + init_mesh = normalize_trimesh(init_mesh) + center = (vertices.max(axis=0) + vertices.min(axis=0)) / 2 + vertices = vertices - center + diag = np.linalg.norm(vertices.max(axis=0) - vertices.min(axis=0)) + vertices = vertices / diag + mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices) + + vertices = np.clip(np.asarray(mesh_o3d.vertices), -0.5 + 1e-6, 0.5 - 1e-6) + mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices) + + voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds( + mesh_o3d, + voxel_size=1/64, + min_bound=(-0.5, -0.5, -0.5), + max_bound=(0.5, 0.5, 0.5) + ) + + voxel_indices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) + coords = torch.cat([torch.zeros(len(voxel_indices), 1), torch.tensor(voxel_indices)], dim=1).int().to(self.device) + elif coords is not None: + coords = coords + else: + coords = self.sample_sparse_structure(sparse_cond, num_samples, sparse_structure_sampler_params) + slat = self.sample_slat(slat_cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) + + def configure_optimizers(self): + params = list(self.slat_flow_model.parameters()) + opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0) + return opt + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): + def get_ss_cond(self, image_cond: torch.Tensor, aggregated_tokens_list: List, num_samples: int) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + cond = self.sparse_structure_vggt_cond(aggregated_tokens_list, image_cond) + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + @torch.no_grad() + def vggt_feat(self, image: Union[torch.Tensor, list[Image.Image]]) -> List: + """ + Encode the image. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image to encode + + Returns: + torch.Tensor: The encoded features. 
+ """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False) + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=self.VGGT_dtype): + # Predict attributes including cameras, depth maps, and point maps. + aggregated_tokens_list, _ = self.VGGT_model.aggregator(image[None]) + + return aggregated_tokens_list, image + + def run( + self, + image: Union[torch.Tensor, list[Image.Image]], + coords: torch.Tensor = None, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh'], + preprocess_image: bool = True, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + + torch.manual_seed(seed) + aggregated_tokens_list, _ = self.vggt_feat(image) + b, n, _, _ = aggregated_tokens_list[0].shape + image_cond = self.encode_image(image).reshape(b, n, -1, 1024) + + # if coords is None: + ss_flow_model = self.models['sparse_structure_flow_model'] + ss_cond = self.get_ss_cond(image_cond[:, :, 5:], aggregated_tokens_list, num_samples) + # Sample structured latent + ss_sampler_params = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params} + reso = ss_flow_model.resolution + ss_noise = torch.randn(num_samples, ss_flow_model.in_channels, reso, reso, reso).to(self.device) + ss_slat = self.sparse_structure_sampler.sample( + ss_flow_model, + ss_noise, + **ss_cond, + **ss_sampler_params, + verbose=True + ).samples + + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(ss_slat)>0)[:, [0, 2, 3, 4]].int() + + cond = { + 'cond': image_cond.reshape(n, -1, 1024), + 'neg_cond': torch.zeros_like(image_cond.reshape(n, -1, 1024))[:1], + } + + slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode): + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) + @staticmethod + def from_pretrained(path: str) -> "TrellisVGGTTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. 
+ """ + pipeline = super(TrellisVGGTTo3DPipeline, TrellisVGGTTo3DPipeline).from_pretrained(path) + new_pipeline = TrellisVGGTTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + new_pipeline.VGGT_dtype = torch.float32 + # VGGT_model = VGGT() + # VGGT_model_weight = torch.load("weights/VGGT_weight/object_vggt_model.pt", map_location=torch.device('cpu')) + # VGGT_model.load_state_dict(VGGT_model_weight) + VGGT_model = VGGT.from_pretrained("Stable-X/vggt-object-v0-1") + new_pipeline.VGGT_model = VGGT_model.to(new_pipeline.device) + del new_pipeline.VGGT_model.depth_head + del new_pipeline.VGGT_model.track_head + del new_pipeline.VGGT_model.camera_head + del new_pipeline.VGGT_model.point_head + new_pipeline.VGGT_model.eval() + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) + new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + + new_pipeline.slat_normalization = args['slat_normalization'] + + new_pipeline._init_image_cond_model(args['image_cond_model']) + + return new_pipeline \ No newline at end of file diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72f1616ed82a2ba281eb518714d66555a13b8a0f --- /dev/null +++ b/trellis/renderers/__init__.py @@ -0,0 +1,33 @@ +import importlib + +__attributes = { + 'GSplatRenderer': 'gsplat_renderer', + 'GaussianRenderer': 'gaussian_render', + 'MeshRenderer': 'mesh_renderer', + 'OctreeRenderer': 'octree_renderer' +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh_renderer import MeshRenderer + from .gsplat_renderer import GSplatRenderer + from .gaussian_render import GaussianRenderer + from .octree_renderer import OctreeRenderer \ No newline at end of file diff --git a/trellis/renderers/__pycache__/__init__.cpython-310.pyc b/trellis/renderers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2ad76fd451195e5881a48a6ffb8ef97effca1f6 Binary files /dev/null and b/trellis/renderers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc b/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c77d3c9f306cbb2883bebbf48428313b892eecd Binary files /dev/null and b/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc differ diff --git a/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc b/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56c182a03e1cb8448861f30b66e6cfa82968b02d Binary files /dev/null and 
b/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc differ diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py new file mode 100644 index 0000000000000000000000000000000000000000..be4d6255c1b706baf57649214edd2be7ff480420 --- /dev/null +++ b/trellis/renderers/gaussian_render.py @@ -0,0 +1,242 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from easydict import EasyDict as edict +import numpy as np +from ..representations.gaussian import Gaussian +import torch.nn.functional as F +from easydict import EasyDict as edict + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, need_depth = False): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'GaussianRasterizer' not in globals(): + from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + kernel_size = pipe.kernel_size + subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + kernel_size=kernel_size, + subpixel_offset=subpixel_offset, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. 
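+    # (For reference: when computed in Python, the 3D covariance is Sigma = L @ L.T with
+    # L = R @ diag(scales) as built by build_scaling_rotation(), and strip_symmetric() keeps the
+    # six unique entries of the symmetric matrix that the rasterizer consumes; see
+    # trellis/representations/gaussian/general_utils.py.)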
+ scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = pc.get_features + colors_precomp = None + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp + ) + if need_depth: + p_hom = torch.cat([pc.get_xyz, torch.ones_like(pc.get_xyz[...,:1])], -1).unsqueeze(-1) + p_view = torch.matmul(viewpoint_camera.world_view_transform.transpose(0,1), p_hom) + p_view = p_view[...,:3,:] + depth = p_view.squeeze()[...,2:3] + depth = depth.repeat(1,3) + depth_map = rasterizer( + means3D = means3D, + means2D = means2D, + shs = None, + colors_precomp = depth, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp)[0] + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return edict({"render": rendered_image, + "depth": depth_map if need_depth else None, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii}) + + +class GaussianRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.pipe = edict({ + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + need_depth: bool = False, + ) -> edict: + """ + Render the gausssian. 
+ + Args: + gaussian : gaussianmodule + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.float().contiguous(), + "projection_matrix": perspective.T.float().contiguous(), + "full_proj_transform": (perspective @ view).T.float().contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, need_depth=need_depth) + + if ssaa > 1: + render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if need_depth: + render_ret.depth = F.interpolate(render_ret.depth[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + if need_depth: + return edict({ + 'color': render_ret['render'], + 'depth': render_ret['depth'], + }) + else: + return edict({ + 'color': render_ret['render'], + }) diff --git a/trellis/renderers/gsplat_renderer.py b/trellis/renderers/gsplat_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..770089ec00fa5ddb998abb8f82cdd9f8c643dc40 --- /dev/null +++ b/trellis/renderers/gsplat_renderer.py @@ -0,0 +1,108 @@ +import gsplat as gs +import numpy as np +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict + + +class GSplatRenderer: + def __init__(self, rendering_options={}) -> None: + self.pipe = edict({ + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False, + "use_mip_gaussian": True + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + + resolution = self.rendering_options["resolution"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], + dtype=torch.float32, + device="cuda" + ) + + height = resolution * ssaa + width = resolution * ssaa + + # Set 
up background color + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], + dtype=torch.float32, + device="cuda" + ) + + Ks_scaled = intrinsics.clone() + Ks_scaled[0, 0] *= width + Ks_scaled[1, 1] *= height + Ks_scaled[0, 2] *= width + Ks_scaled[1, 2] *= height + Ks_scaled = Ks_scaled.unsqueeze(0) + + near_plane = 0.01 + far_plane = 1000.0 + + # Rasterize with gsplat + render_colors, render_alphas, meta = gs.rasterization( + means=gaussian.get_xyz, + quats=F.normalize(gaussian.get_rotation, dim=-1), + scales=gaussian.get_scaling / intrinsics[0, 0], + opacities=gaussian.get_opacity.squeeze(-1), + colors=colors_overwrite.unsqueeze(0) if colors_overwrite is not None else torch.sigmoid( + gaussian.get_features.squeeze(1)).unsqueeze(0), + viewmats=extrinsics.unsqueeze(0), + Ks=Ks_scaled, + width=width, + height=height, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=3.0, + eps2d=0.3, + render_mode="RGB", + backgrounds=self.bg_color.unsqueeze(0), + camera_model="pinhole" + ) + + rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1) + + # Apply supersampling if needed + if ssaa > 1: + rendered_image = F.interpolate( + rendered_image[None], + size=(resolution, resolution), + mode='bilinear', + align_corners=False, + antialias=True + ).squeeze() + + return edict({'color': rendered_image}) \ No newline at end of file diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6ef3237c7f26c7b4c6a23aa7cf5c5edefa21e7 --- /dev/null +++ b/trellis/renderers/mesh_renderer.py @@ -0,0 +1,153 @@ +import torch +try: + import nvdiffrast.torch as dr +except : + print("nvdiffrast are not installed. Please install them to use the mesh renderer.") +from easydict import EasyDict as edict +from ..representations.mesh import MeshExtractResult +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. 
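+
+    Example (a minimal sketch; the rendering options and camera tensors are assumptions):
+        renderer = MeshRenderer({"resolution": 512, "near": 0.1, "far": 10.0, "ssaa": 2})
+        out = renderer.render(mesh, extrinsics, intrinsics, return_types=["color", "normal", "depth"])
+        color = out["color"]   # [3, H, W]
+        depth = out["depth"]   # [H, W], camera-space depth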
+ """ + def __init__(self, rendering_options={}, device='cuda'): + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1 + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device = device + + def render( + self, + mesh : MeshExtractResult, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["color", "normal", "nocs", "depth"] + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "mask", "depth", "normal", "color", "nocs" + + Returns: + edict based on return_types containing: + color (torch.Tensor): [3, H, W] rendered color image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image in camera space + mask (torch.Tensor): [H, W] rendered mask image + nocs (torch.Tensor): [3, H, W] rendered NOCS coordinates + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) + ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + RT = extrinsics.unsqueeze(0) + full_proj = (perspective @ extrinsics).unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces_int = mesh.faces.int() + rast, _ = dr.rasterize( + self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) + + out_dict = edict() + for type in return_types: + img = None + try: + if type == "mask": + img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] + elif type == "normal": + # Transform face normals to camera space + rotation = RT[..., :3, :3] # [1, 3, 3] + face_normals = mesh.face_normal.view(1, -1, 3) # [1, N, 3] + camera_space_normals = torch.matmul(face_normals, rotation.transpose(-1, -2)) + camera_space_normals = F.normalize(camera_space_normals, dim=-1) + + img = dr.interpolate( + camera_space_normals.reshape(1, -1, 3), rast, + torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) + )[0] + # normalize norm pictures to [0,1] range + img = (-img + 1) / 2 + elif type == "color": + img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "nocs": + img = dr.interpolate(vertices[..., :3].contiguous(), rast, faces_int)[0] + img = img + 0.5 + + if ssaa > 1: + if type == 'color': + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='nearest') + img = img.squeeze() + else: + img = 
img.permute(0, 3, 1, 2).squeeze() + except Exception as e: + print(f"Error rendering {type}: {str(e)}") + # Return a blank image of appropriate shape in case of error + if type in ['normal', 'color', 'nocs', 'depth']: + img = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + else: + img = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + + out_dict[type] = img + + return out_dict diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..e933c9a55dbaa6177daa833c54fb70e08f81dc54 --- /dev/null +++ b/trellis/renderers/octree_renderer.py @@ -0,0 +1,301 @@ +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'OctreeTrivecRasterizer' not in globals(): + from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
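+    # (For degree-0 spherical harmonics the conversion is rgb = SH_C0 * f_dc + 0.5; see the SH_C0_0
+    # constant used by Gaussian.get_color in trellis/representations/gaussian/gaussian_model.py.
+    # Here the conversion is left to the rasterizer unless colors_overwrite is provided.)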
+ colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions = positions, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + ret['distloss'] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + opacities = opacities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions = positions, + trivecs = trivecs, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + colors_overwrite = colors_overwrite, + depths = depths, + aabb = octree.aabb, + aux = aux, + halton_sampler = halton_sampler, + ) + ret['percent_depth'] = percent_depth + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + decoupolys_V = decoupolys_V, + decoupolys_g = decoupolys_g, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict({ + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. 
+ + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + return { + 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + } + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + if self.pipe["with_aux"]: + aux = { + 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + print(f"Rendering with resolution {resolution}, near {near}, far {far}, ssaa {ssaa}, bg_color {self.bg_color}, fovx {fovx}, fovy {fovy}") + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + + if ssaa > 1: + render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if hasattr(render_ret, 'percent_depth'): + render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret.rgb, 
+ 'depth': render_ret.depth, + 'alpha': render_ret.alpha, + }) + if self.pipe["with_distloss"] and 'distloss' in render_ret: + ret['distloss'] = render_ret.distloss + if self.pipe["with_aux"]: + ret['aux'] = aux + if hasattr(render_ret, 'percent_depth'): + ret['percent_depth'] = render_ret.percent_depth + return ret \ No newline at end of file diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c027ee0c4e866b8f536e982399caf70e45905cc7 --- /dev/null +++ b/trellis/representations/__init__.py @@ -0,0 +1,3 @@ +from .gaussian import Gaussian +from .mesh import MeshExtractResult +from .octree import DfsOctree as Octree \ No newline at end of file diff --git a/trellis/representations/__pycache__/__init__.cpython-310.pyc b/trellis/representations/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..375729efd5bcdff2e52bb7e672e787e7a9f34e6c Binary files /dev/null and b/trellis/representations/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3de6e180bd732836af876d748255595be2d4d74 --- /dev/null +++ b/trellis/representations/gaussian/__init__.py @@ -0,0 +1 @@ +from .gaussian_model import Gaussian \ No newline at end of file diff --git a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b49bf7d9175d445790ba1370a0296b8a0043c0 Binary files /dev/null and b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342ee620968a18e8df2da9066d78935b0f6a97c2 Binary files /dev/null and b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bc3af5fdc8f1ea1fbfecf0bb46ed229c56e9849 Binary files /dev/null and b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1bfc51de0e79797a1a67c19dcae78dbe729f07a6 --- /dev/null +++ b/trellis/representations/gaussian/gaussian_model.py @@ -0,0 +1,215 @@ +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation + +SH_C0_0 = 0.28209479177387814 + +class Gaussian: + def __init__( + self, + aabb : list, + sh_degree : int = 0, + mininum_kernel_size : float = 0.0, + scaling_bias : float = 0.01, + opacity_bias : float = 0.1, + scaling_activation : str = "exp", + device='cuda' + ): + self.init_params = { + 'aabb': aabb, + 'sh_degree': sh_degree, + 'mininum_kernel_size': mininum_kernel_size, + 'scaling_bias': scaling_bias, + 'opacity_bias': opacity_bias, + 'scaling_activation': 
scaling_activation, + } + + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def convert_to_fp32(self): + self.aabb = self.aabb.float() + if self._xyz is not None: + self._xyz = self._xyz.float() + if self._features_dc is not None: + self._features_dc = self._features_dc.float() + if self._features_rest is not None: + self._features_rest = self._features_rest.float() + if self._scaling is not None: + self._scaling = self._scaling.float() + if self._rotation is not None: + self._rotation = self._rotation.float() + if self._opacity is not None: + self._opacity = self._opacity.float() + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_color(self): + return (SH_C0_0 * self._features_dc.squeeze(dim=1) + 0.5).clip(0, 1) + + @property + def get_features(self): + return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) + + def from_scaling(self, scales): + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) + self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias + + def from_rotation(self, rots): + self._rotation = rots - self.rots_bias[None, :] + + def from_xyz(self, xyz): + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + + def from_features(self, features): + self._features_dc = features + + def 
from_opacity(self, opacities): + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() + scale = torch.log(self.get_scaling).detach().cpu().numpy() + rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + if self.sh_degree > 0: + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + # convert to actual gaussian attributes + xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) + features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + if self.sh_degree > 0: + features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) + scales = 
torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) + rots = torch.tensor(rots, dtype=torch.float, device=self.device) + + # convert to _hidden attributes + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + self._features_dc = features_dc + if self.sh_degree > 0: + self._features_rest = features_extra + else: + self._features_rest = None + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias + self._rotation = rots - self.rots_bias[None, :] + \ No newline at end of file diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d --- /dev/null +++ b/trellis/representations/gaussian/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. 
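+            # delay_rate ramps smoothly from lr_delay_mult (at step 0) up to 1 (at lr_delay_steps)
+            # via a quarter-period sine, after which the log-linear interpolation below is used unscaled.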
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6068e8bad549a7625d8de8b23ed1659814a655 --- /dev/null +++ b/trellis/representations/mesh/__init__.py @@ -0,0 +1,5 @@ +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult +try: + from .mc2mesh import SparseFeatures2MCMesh +except ImportError: + pass \ No newline at end of file diff --git a/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15addeb8a3d731f62e6cf2b68da2a8b5f3f952b9 Binary files /dev/null and b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb331fd6fccc0ce8510ba2261134b29722dbe9af Binary files /dev/null and b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc b/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..460812611dfbe0bef6900741bba44122cfd5d78a Binary files /dev/null and b/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc 
b/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1471dc56f41acb01e4ea455b920e80c23fbedd Binary files /dev/null and b/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc b/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1117f4450a5ea87b2a2a6240edb6c9563b4e9a30 Binary files /dev/null and b/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89876dfe8a48829eb8f9d3b9fc00b9db2521ba24 Binary files /dev/null and b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc differ diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..c420ddccc9f87aef3932c66bf24a4e1c196b578b --- /dev/null +++ b/trellis/representations/mesh/cube2mesh.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from ...modules.sparse import SparseTensor +from easydict import EasyDict as edict +from .utils_cube import * +from .flexicube import FlexiCubes + +import torch +import trimesh +import numpy as np + +# Dependency for mesh cleaning and hole fix +# from ...utils.random_utils import sphere_hammersley_sequence +# import utils3d +# import igraph +# from tqdm import tqdm +# from pymeshfix import _meshfix + +class MeshExtractResult: + def __init__(self, + vertices, + faces, + vertex_attrs=None, + res=64 + ): + self.vertices = vertices + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals(vertices, faces) + self.vertex_normal = self.comput_v_normals(vertices, faces) + self.res = res + self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = 
torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + def to_trimesh(self, transform_pose=False): + """ + Convert the mesh to a trimesh.Trimesh object. + Args: + transform_pose (bool): If True, transform the vertices to change coordinate system + Returns: + trimesh.Trimesh: The converted mesh + """ + # Convert vertices and faces to numpy arrays + vertices = self.vertices.detach().cpu().numpy() + faces = self.faces.detach().cpu().numpy() + + # Apply coordinate transformation if requested + if transform_pose: + transform_matrix = np.array([ + [1, 0, 0], + [0, 0, -1], + [0, 1, 0] + ]) + vertices = vertices @ transform_matrix + + # Also transform the normals if they exist + vertex_normals = self.vertex_normal.detach().cpu().numpy() @ transform_matrix + else: + vertex_normals = self.vertex_normal.detach().cpu().numpy() + + # Prepare vertex colors if they exist + vertex_colors = None + if self.vertex_attrs is not None: + vertex_colors = self.vertex_attrs[:, :3].detach().cpu().numpy() + # Ensure colors are in [0, 255] range + if vertex_colors.max() <= 1.0: + vertex_colors = (vertex_colors * 255).astype(np.uint8) + + # Create the trimesh mesh + mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors + ) + + return mesh + + @staticmethod + def from_trimesh(mesh, device='cuda'): + # Convert scene to mesh if necessary + if hasattr(mesh, 'geometry'): + # If it's a scene, get the first mesh + # Assuming the scene has at least one mesh + mesh_name = list(mesh.geometry.keys())[0] + mesh = mesh.geometry[mesh_name] + + vertices = torch.tensor(mesh.vertices, dtype=torch.float32) + faces = torch.tensor(mesh.faces, dtype=torch.int64) + + vertex_attrs = None + if mesh.visual.vertex_colors is not None: + vertex_attrs = torch.tensor(mesh.visual.vertex_colors, dtype=torch.float32) / 255.0 + print(vertex_attrs) + vertex_attrs = vertex_attrs[:, :3] + else: + vertex_attrs = torch.zeros((vertices.shape[0], 3), dtype=torch.float32) + return MeshExtractResult(vertices, faces, vertex_attrs) + + def to(self, device): + self.vertices = self.vertices.to(device) + self.faces = self.faces.to(device) + if self.vertex_attrs is not None: + self.vertex_attrs = self.vertex_attrs.to(device) + self.face_normal = self.face_normal.to(device) + self.vertex_normal = self.vertex_normal.to(device) + return self + + def subdivide(self): + """ + Subdivide the mesh by splitting each triangle into four smaller triangles. 
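+
+        New vertices are created at edge midpoints and shared between neighbouring faces through an
+        edge-to-index map, so every original face is replaced by exactly four faces. Note that
+        vertex_attrs are re-initialized to zeros (not interpolated) after subdivision.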
+ """ + new_vertices = [] + new_faces = [] + vertex_map = {} + + def get_midpoint(v1, v2): + edge = tuple(sorted((v1, v2))) + if edge not in vertex_map: + midpoint = (self.vertices[v1] + self.vertices[v2]) / 2 + vertex_map[edge] = len(new_vertices) + new_vertices.append(midpoint) + return vertex_map[edge] + + for face in self.faces: + v0, v1, v2 = face + a = get_midpoint(v0.item(), v1.item()) + b = get_midpoint(v1.item(), v2.item()) + c = get_midpoint(v2.item(), v0.item()) + + new_faces.append([v0.item(), a, c]) + new_faces.append([v1.item(), b, a]) + new_faces.append([v2.item(), c, b]) + new_faces.append([a, b, c]) + + new_vertices = torch.stack(new_vertices) + new_faces = torch.tensor(new_faces, dtype=torch.long) + + self.vertices = torch.cat([self.vertices, new_vertices], dim=0) + self.faces = new_faces + self.face_normal = self.comput_face_normals(self.vertices, self.faces) + self.vertex_normal = self.comput_v_normals(self.vertices, self.faces) + self.vertex_attrs = torch.zeros((self.vertices.shape[0], 3), dtype=torch.float32) + +class SparseFeatures2Mesh: + def __init__(self, device="cuda", res=64, use_color=True): + ''' + a model to generate a mesh from sparse features structures using flexicube + ''' + super().__init__() + self.device=device + self.res = res + self.mesh_extractor = FlexiCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + 'sdf': {'shape': (8, 1), 'size': 8}, + 'deform': {'shape': (8, 3), 'size': 8 * 3}, + 'weights': {'shape': (21,), 'size': 21} + } + if self.use_color: + ''' + 6 channel color including normal map + ''' + LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.feats_channels = start + + def get_layout(self, feats : torch.Tensor, name : str): + if name not in self.layouts: + return None + return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) + + def __call__(self, cubefeats : SparseTensor, training=False): + """ + Generates a mesh based on the specified sparse voxel structures. 
+ Args: + cubefeats (SparseTensor): sparse cube features whose per-cube channels carry the SDF, vertex deformations, FlexiCubes weights and (optionally) 6-channel color, laid out as in self.layouts + Returns: + MeshExtractResult: the extracted mesh; when training=True it also carries reg_loss, tsdf_v and tsdf_s + """ + # add the SDF bias to the decoded SDF channel + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) + weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) + if self.use_color: + sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + cube_idx=self.reg_c, + resolution=self.res, + beta=weights_d[:, :12], + alpha=weights_d[:, 12:20], + gamma_f=weights_d[:, 20], + voxelgrid_colors=colors_d, + training=training) + + mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:,:20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + return mesh diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b1177a7301e113a027826ded8cca01f64f22d8 --- /dev/null +++ b/trellis/representations/mesh/flexicube.py @@ -0,0 +1,362 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
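+ +# NOTE: usage sketch, not part of the original FlexiCubes code; it mirrors the call made in +# cube2mesh.SparseFeatures2Mesh.__call__, whose variable names are reused here: +#   fc = FlexiCubes(device="cuda") +#   vertices, faces, L_dev, colors = fc(voxelgrid_vertices=x_nx3, scalar_field=sdf_d, cube_idx=reg_c, +#       resolution=res, beta=weights_d[:, :12], alpha=weights_d[:, 12:20], gamma_f=weights_d[:, 20], +#       voxelgrid_colors=colors_d, training=False)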
+ +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + def __init__(self, device="cuda"): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + + def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, + weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + num_vertices = voxelgrid_vertices.shape[0] + num_cubes = cube_idx.shape[0] + + surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) + if surf_cubes.sum() == 0: + return ( + torch.zeros((0, 3), device=self.device), + torch.zeros((0, 3), dtype=torch.long, device=self.device), + torch.zeros((0), device=self.device), + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + ) + beta, alpha, gamma_f = self._normalize_weights( + beta, alpha, gamma_f, surf_cubes, weight_scale) + + if voxelgrid_colors is not None: + voxelgrid_colors = torch.sigmoid(voxelgrid_colors) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( + scalar_field, cube_idx, surf_cubes + ) + + vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( + voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) + vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( + scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, + vd_idx_map, surf_edges_mask, training, vd_color) + return vertices, faces, L_dev, vertices_color + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - 
torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta is not None: + beta = (torch.tanh(beta) * weight_scale + 1) + else: + beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha is not None: + alpha = (torch.tanh(alpha) * weight_scale + 1) + else: + alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. 
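+ Returns the (E, 2) surface-edge vertex pairs, a per-cube-edge index map (-1 for edges that do not cross the surface), the number of surface cubes sharing each cube edge, and a per-cube-edge boolean surface mask.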
+ """ + occ_n = scalar_field < 0 + all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, scalar_field, cube_idx): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = scalar_field < 0 + occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] + , edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + if voxelgrid_colors is not None: + C = voxelgrid_colors.shape[-1] + surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + # if color is not None: + # vd_color = [] + + 
total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + # if color is not None: + # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) + + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + # if color is not None: + # vd_color = torch.cat(vd_color) + # else: + # vd_color = None + + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + + ''' + interpolate colors use the same method as dual vertices + ''' + if voxelgrid_colors is not None: + vd_color = torch.zeros((total_num_vd, C), device=self.device) + c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + uc_group = self._linear_interp(s_group * alpha_group, c_group) + vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + else: + vd_color = None + + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map, vd_color + + def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, 
vd_idx_map, surf_edges_mask, training, vd_color): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] + gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] + if not training: + mask = (gamma_02 > gamma_13) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 + vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + + if vd_color is not None: + color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 + color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 + color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + vd_color = torch.cat([vd_color, color_center]) + + + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices, vd_color diff --git a/trellis/representations/mesh/mc2mesh.py b/trellis/representations/mesh/mc2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7f88be01f938f664eb35d9d62c1bd32846fca2e8 --- /dev/null +++ b/trellis/representations/mesh/mc2mesh.py @@ -0,0 +1,216 @@ +import torch +from easydict import EasyDict as edict +from typing import Tuple, Optional +from diso import DiffDMC +from .cube2mesh import MeshExtractResult +from .utils_cube import * +from ...modules.sparse import SparseTensor + +class EnhancedMarchingCubes: + def __init__(self, device="cuda"): + self.device = device + self.diffdmc = DiffDMC(dtype=torch.float32) + + def __call__(self, + voxelgrid_vertices: torch.Tensor, + scalar_field: torch.Tensor, + voxelgrid_colors: Optional[torch.Tensor] = None, + training: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Enhanced Marching Cubes implementation using DiffDMC that handles deformations and colors + """ + if scalar_field.dim() == 1: + grid_size = int(round(scalar_field.shape[0] ** (1 
/ 3))) + scalar_field = scalar_field.reshape(grid_size, grid_size, grid_size) + elif scalar_field.dim() > 3: + scalar_field = scalar_field.squeeze() + + if scalar_field.dim() != 3: + raise ValueError(f"Expected 3D array, got shape {scalar_field.shape}") + + # Normalize coordinates for DiffDMC + scalar_field = scalar_field.to(self.device) + + # Get deformation field if provided + deform_field = None + if voxelgrid_vertices is not None: + if voxelgrid_vertices.dim() == 2: + grid_size = int(round(voxelgrid_vertices.shape[0] ** (1 / 3))) + voxelgrid_vertices = voxelgrid_vertices.reshape(grid_size, grid_size, grid_size, 3) + deform_field = voxelgrid_vertices.to(self.device) + + # Run DiffDMC + vertices, faces = self.diffdmc( + scalar_field, + deform_field, + isovalue=0.0 + ) + + # Handle colors if provided + colors = None + if voxelgrid_colors is not None: + voxelgrid_colors = torch.sigmoid(voxelgrid_colors) + if voxelgrid_colors.dim() == 2: + grid_size = int(round(voxelgrid_colors.shape[0] ** (1/3))) + voxelgrid_colors = voxelgrid_colors.reshape(grid_size, grid_size, grid_size, -1) + + grid_positions = vertices.clone() * grid_size + grid_coords = grid_positions.long() + local_coords = grid_positions - grid_coords.float() + + # Clamp coordinates to grid bounds + grid_coords = torch.clamp(grid_coords, 0, voxelgrid_colors.shape[0] - 1) + + # Trilinear interpolation for colors + colors = self._interpolate_color(grid_coords, local_coords, voxelgrid_colors) + + vertices = vertices * 2 - 1 # Normalize vertices to [-1, 1] + vertices /= 2.0 # Normalize vertices to [-0.5, 0.5] + + # Compute deviation loss for training + deviation_loss = torch.tensor(0.0, device=self.device) + if training and deform_field is not None: + # Compute deviation between original and deformed vertices + deviation_loss = self._compute_deviation_loss(vertices, deform_field) + + # faces = faces.flip(dims=[1]) # Maintain consistent face orientation + + return vertices, faces, deviation_loss, colors + + def _interpolate_color(self, grid_coords: torch.Tensor, + local_coords: torch.Tensor, + color_field: torch.Tensor) -> torch.Tensor: + """ + Interpolate colors using trilinear interpolation + Args: + grid_coords: (N, 3) integer grid coordinates + local_coords: (N, 3) fractional positions within grid cells + color_field: (res, res, res, C) color values + """ + x, y, z = local_coords[:, 0], local_coords[:, 1], local_coords[:, 2] + + # Get corner values for each vertex + c000 = color_field[grid_coords[:, 0], grid_coords[:, 1], grid_coords[:, 2]] + c001 = color_field[grid_coords[:, 0], grid_coords[:, 1], + torch.clamp(grid_coords[:, 2] + 1, 0, color_field.shape[2] - 1)] + c010 = color_field[grid_coords[:, 0], + torch.clamp(grid_coords[:, 1] + 1, 0, color_field.shape[1] - 1), + grid_coords[:, 2]] + c011 = color_field[grid_coords[:, 0], + torch.clamp(grid_coords[:, 1] + 1, 0, color_field.shape[1] - 1), + torch.clamp(grid_coords[:, 2] + 1, 0, color_field.shape[2] - 1)] + c100 = color_field[torch.clamp(grid_coords[:, 0] + 1, 0, color_field.shape[0] - 1), + grid_coords[:, 1], grid_coords[:, 2]] + c101 = color_field[torch.clamp(grid_coords[:, 0] + 1, 0, color_field.shape[0] - 1), + grid_coords[:, 1], + torch.clamp(grid_coords[:, 2] + 1, 0, color_field.shape[2] - 1)] + c110 = color_field[torch.clamp(grid_coords[:, 0] + 1, 0, color_field.shape[0] - 1), + torch.clamp(grid_coords[:, 1] + 1, 0, color_field.shape[1] - 1), + grid_coords[:, 2]] + c111 = color_field[torch.clamp(grid_coords[:, 0] + 1, 0, color_field.shape[0] - 1), + 
torch.clamp(grid_coords[:, 1] + 1, 0, color_field.shape[1] - 1), + torch.clamp(grid_coords[:, 2] + 1, 0, color_field.shape[2] - 1)] + + # Interpolate along x axis + c00 = c000 * (1 - x)[:, None] + c100 * x[:, None] + c01 = c001 * (1 - x)[:, None] + c101 * x[:, None] + c10 = c010 * (1 - x)[:, None] + c110 * x[:, None] + c11 = c011 * (1 - x)[:, None] + c111 * x[:, None] + + # Interpolate along y axis + c0 = c00 * (1 - y)[:, None] + c10 * y[:, None] + c1 = c01 * (1 - y)[:, None] + c11 * y[:, None] + + # Interpolate along z axis + colors = c0 * (1 - z)[:, None] + c1 * z[:, None] + + return colors + + def _compute_deviation_loss(self, vertices: torch.Tensor, + deform_field: torch.Tensor) -> torch.Tensor: + """Compute deviation loss for training""" + # Since DiffDMC already handles the deformation, we compute the loss + # based on the magnitude of the deformation field + return torch.mean(deform_field ** 2) + +class SparseFeatures2MCMesh: + def __init__(self, device="cuda", res=128, use_color=True): + super().__init__() + self.device = device + + self.res = res + + self.mesh_extractor = EnhancedMarchingCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + 'sdf': {'shape': (8, 1), 'size': 8}, + 'deform': {'shape': (8, 3), 'size': 8 * 3}, + 'weights': {'shape': (21,), 'size': 21} + } + if self.use_color: + ''' + 6 channel color including normal map + ''' + LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.feats_channels = start + + def get_layout(self, feats: torch.Tensor, name: str): + if name not in self.layouts: + return None + return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name][ + 'shape']) + + def __call__(self, cubefeats: SparseTensor, training=False): + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [self.get_layout(feats, name) + for name in ['sdf', 'deform', 'color', 'weights']] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), + training=training) + + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True) + + if self.use_color: + sdf_d, deform_d, colors_d = (v_attrs_d[..., 0], v_attrs_d[..., 1:4], + v_attrs_d[..., 4:]) + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + voxelgrid_colors=colors_d, + training=training + ) + + mesh = MeshExtractResult(vertices=vertices, faces=faces, + vertex_attrs=colors, res=self.res) + + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:, :20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + + return mesh \ No newline at end of file diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py new file mode 100644 index 
0000000000000000000000000000000000000000..7c02dd7f4133aef487f623c02b11e3075cab0916 --- /dev/null +++ b/trellis/representations/mesh/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, 
-1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], 
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 
11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 
0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], 
+[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], 
+[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] \ No newline at end of file diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py new file mode 100644 index 0000000000000000000000000000000000000000..23913c97bb2d57dfa0384667c69f9860ea0a4155 --- /dev/null +++ b/trellis/representations/mesh/utils_cube.py @@ -0,0 +1,61 @@ +import torch +cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) +cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) +cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) + +def construct_dense_grid(res, device='cuda'): + '''construct a dense grid based on resolution''' + res_v = res + 1 + vertsid = torch.arange(res_v ** 3, 
device=device) + coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() + cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] + cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) + verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) + return verts, cube_fx8 + + +def construct_voxel_grid(coords): + verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) + verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) + cubes = inverse_indices.reshape(-1, 8) + return verts_unique, cubes + + +def cubes_to_verts(num_verts, cubes, value, reduce='mean'): + """ + Args: + cubes [Vx8] verts index for each cube + value [Vx8xM] value to be scattered + Operation: + reduced[cubes[i][j]][k] += value[i][k] + """ + M = value.shape[2] # number of channels + reduced = torch.zeros(num_verts, M, device=cubes.device) + return torch.scatter_reduce(reduced, 0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), reduce=reduce, include_self=False) + +def sparse_cube2verts(coords, feats, training=True): + new_coords, cubes = construct_voxel_grid(coords) + new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) + if training: + con_loss = torch.mean((feats - new_feats[cubes]) ** 2) + else: + con_loss = 0.0 + return new_coords, new_feats, con_loss + + +def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): + F = feats.shape[-1] + dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) + if sdf_init: + dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats + return dense_attrs.reshape(-1, F) + + +def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): + return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) + \ No newline at end of file diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f66a39a5a7498e2e99fe9d94d663796b3bc157b5 --- /dev/null +++ b/trellis/representations/octree/__init__.py @@ -0,0 +1 @@ +from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aae217a4c1ccdf3ffb8ce5cae03d4818363aff7 Binary files /dev/null and b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d902ff45e85e9f26a4362e4f44bf7f0a6ed18b Binary files /dev/null and b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc differ diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py new file mode 100644 index 0000000000000000000000000000000000000000..648f90299a7121b8fc65fd8e282ff90ff8191765 --- /dev/null +++ b/trellis/representations/octree/octree_dfs.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_TRIVEC_CONFIG = { + 'dim': 8, + 'rank': 8, +} + +DEFAULT_VOXEL_CONFIG = { + 'solid': 
False, +} + +DEFAULT_DECOPOLY_CONFIG = { + 'degree': 8, + 'rank': 16, +} + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. + """ + + def __init__( + self, + depth, + aabb=[0,0,0,1,1,1], + sh_degree=2, + primitive='voxel', + primitive_config={}, + device='cuda', + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) + self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) + self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.depth[:, 0] = 1 + + self.data = ['position', 'depth'] + self.param_names = [] + + if primitive == 'voxel': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac'] + self.param_names += ['features_dc', 'features_ac'] + if not primitive_config.get('solid', False): + self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data.append('density') + self.param_names.append('density') + elif primitive == 'gaussian': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac', 'opacity'] + self.param_names += ['features_dc', 'features_ac', 'opacity'] + elif primitive == 'trivec': + self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) + self.density 
= torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['trivec', 'density', 'features_dc', 'features_ac'] + self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] + elif primitive == 'decoupoly': + self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) + self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8 ** self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == 'voxel' and self.primitive_config['solid']: + return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} + if hasattr(self, 'density_shift'): + ret['density_shift'] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + for 
key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i+1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i+1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl==-8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + + # Initialize the new nodes. + new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. 
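+        # Newly allocated descriptors start as all-leaf nodes (leaf num = 8); their child
+        # and data pointers are rebuilt below from cumulative sums of the per-node
+        # non-leaf and leaf counts, respectively.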
+ new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. + leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes + mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes + new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == 'position': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale + elif data == 'depth': + new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) + elif data == 'trivec': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 + coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) + axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) + coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 + new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, 
device=self.device) + merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + old_merge_data_ptr = self.structure[structre_delete, 2] + for data in self.data: + if data == 'position': + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 + elif data == 'depth': + new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) + elif data == 'trivec': + new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + 'subdivide_mask': subdivide_mask, + 'merge_mask': merge_mask, + 'data_valid': data_valid, + 'new_data_idx': new_data_idx, + 'new_data_length': new_data_length, + 'new_data': new_data + } diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis/utils/__pycache__/__init__.cpython-310.pyc b/trellis/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e0e70298e509d245b1136414c365f875e61e067 Binary files /dev/null and b/trellis/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/general_utils.cpython-310.pyc b/trellis/utils/__pycache__/general_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bebee9d6fc3393c74bb6fa8c1edc2c85bb6f75f Binary files /dev/null and b/trellis/utils/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/loss_utils.cpython-310.pyc b/trellis/utils/__pycache__/loss_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..715fb0164872d83570fd0181b37c7e0ad098e1ee Binary files /dev/null and b/trellis/utils/__pycache__/loss_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc b/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6491f77d28ed35075e19b70078649f30ed68ba25 Binary files /dev/null and b/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/random_utils.cpython-310.pyc b/trellis/utils/__pycache__/random_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c073c3a39622357e677185757e72e1e8b95f0f6 Binary files /dev/null and b/trellis/utils/__pycache__/random_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/render_utils.cpython-310.pyc b/trellis/utils/__pycache__/render_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6eeeaca4eabb83d4cf565cdbcbec649cf3268b1 Binary files /dev/null and b/trellis/utils/__pycache__/render_utils.cpython-310.pyc 
differ diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d630f1607d8d84ad1a392407a0a7e74dc066916 --- /dev/null +++ b/trellis/utils/general_utils.py @@ -0,0 +1,412 @@ +import numpy as np +import cv2 +import torch +from scipy.spatial.transform import Rotation as R +import torch.nn.functional as F +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. 
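+    Example (with the default sep='.'):
+        >>> dict_flatten({'a': {'b': 1}, 'c': 2})
+        {'a.b': 1, 'c': 2}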
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + +def rotation2quad(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + if not isinstance(matrix, torch.Tensor): + matrix = torch.tensor(matrix).cuda() + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
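+            # Row k is 4 * q[k] * q (e.g. 4*q0^2 = 1 + trace(R), 4*q0*q1 = m21 - m12);
+            # dividing by 2*q_abs[k] below recovers +/- q, and the row with the largest
+            # q_abs entry (the best-conditioned one) is selected at the end.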
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + +def quad2rotation(q): + """ + Convert quaternion to rotation in batch. Since all operation in pytorch, support gradient passing. + + Args: + quad (tensor, batch_size*4): quaternion. + + Returns: + rot_mat (tensor, batch_size*3*3): rotation. + """ + # bs = quad.shape[0] + # qr, qi, qj, qk = quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3] + # two_s = 2.0 / (quad * quad).sum(-1) + # rot_mat = torch.zeros(bs, 3, 3).to(quad.get_device()) + # rot_mat[:, 0, 0] = 1 - two_s * (qj**2 + qk**2) + # rot_mat[:, 0, 1] = two_s * (qi * qj - qk * qr) + # rot_mat[:, 0, 2] = two_s * (qi * qk + qj * qr) + # rot_mat[:, 1, 0] = two_s * (qi * qj + qk * qr) + # rot_mat[:, 1, 1] = 1 - two_s * (qi**2 + qk**2) + # rot_mat[:, 1, 2] = two_s * (qj * qk - qi * qr) + # rot_mat[:, 2, 0] = two_s * (qi * qk - qj * qr) + # rot_mat[:, 2, 1] = two_s * (qj * qk + qi * qr) + # rot_mat[:, 2, 2] = 1 - two_s * (qi**2 + qj**2) + # return rot_mat + if not isinstance(q, torch.Tensor): + q = torch.tensor(q).cuda() + + norm = torch.sqrt( + q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3] + ) + q = q / norm[:, None] + rot = torch.zeros((q.size(0), 3, 3)).to(q) + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + rot[:, 0, 0] = 1 - 2 * (y * y + z * z) + rot[:, 0, 1] = 2 * (x * y - r * z) + rot[:, 0, 2] = 2 * (x * z + r * y) + rot[:, 1, 0] = 2 * (x * y + r * z) + rot[:, 1, 1] = 1 - 2 * (x * x + z * z) + rot[:, 1, 2] = 2 * (y * z - r * x) + rot[:, 2, 0] = 2 * (x * z - r * y) + rot[:, 2, 1] = 2 * (y * z + r * x) + rot[:, 2, 2] = 1 - 2 * (x * x + y * y) + return rot + +def perform_rodrigues_transformation(rvec): + try: + R, _ = cv2.Rodrigues(rvec) + return R + except cv2.error as e: + return False + +def euler2rot(euler): + r = R.from_euler('xyz', euler, degrees=True) + rotation_matrix = r.as_matrix() + return rotation_matrix + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
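+
+    Example:
+        >>> matrix_to_quaternion(torch.eye(3))
+        tensor([1., 0., 0., 0.])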
+ """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
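+    # two_s = 2 / |q|^2, so the standard formula below also handles non-unit quaternions.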
+ two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) diff --git a/trellis/utils/loss_utils.py b/trellis/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52049f69543f2700bc5525b09cbf2fb25c08aa9e --- /dev/null +++ b/trellis/utils/loss_utils.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpips import LPIPS + + +def smooth_l1_loss(pred, target, beta=1.0): + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta) + return loss.mean() + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def psnr(img1, img2, max_val=1.0): + mse = F.mse_loss(img1, img2) + return 20 * torch.log10(max_val / torch.sqrt(mse)) + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +loss_fn_vgg = None +def lpips(img1, img2, value_range=(0, 1)): + global loss_fn_vgg + if loss_fn_vgg is None: + loss_fn_vgg = LPIPS(net='vgg').cuda().eval() + # normalize to [-1, 1] + img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + return loss_fn_vgg(img1, img2).mean() + + +def normal_angle(pred, gt): + pred = pred * 2.0 - 1.0 + gt = gt * 2.0 - 1.0 + norms = pred.norm(dim=-1) * gt.norm(dim=-1) + cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() + if ang.isnan(): + return -1 + return ang diff 
--git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf59ad4e6af3fe85ff615b7f15ad3642f558d56f --- /dev/null +++ b/trellis/utils/postprocessing_utils.py @@ -0,0 +1,587 @@ +from typing import * +import numpy as np +import torch +import utils3d +import nvdiffrast.torch as dr +from tqdm import tqdm +import trimesh +import trimesh.visual +import xatlas +import pyvista as pv +from pymeshfix import _meshfix +import igraph +import cv2 +from PIL import Image +from .random_utils import sphere_hammersley_sequence +from .render_utils import render_multiview +from ..renderers import GaussianRenderer +from ..representations import Gaussian, MeshExtractResult + + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for (yaw, pitch) in zip(yaws, pitchs): + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda().float() * radius + view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend='cuda') + for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection + ) + face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) + outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + 
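+    # Inner faces are those never hit from any of the sampled views (visibility == 0).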
inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + if verbose: + tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + + ## solve mincut problem + ### construct main graph + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es['weight'] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex('s') + g.add_vertex('t') + + ### connect invisible faces to source + g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### connect outer faces to target + g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### solve mincut + cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) + remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + if verbose: + tqdm.write(f'Mincut solved, start checking the cut') + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + if debug: + tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f'visblity_median: {visblity_median}') + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) + cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] + _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] + cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + 
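+        # Debug color code (assuming RGB): blue = invisible (inner), green = outer,
+        # magenta = mincut candidates, red = components actually removed.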
vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] + utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) + + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + else: + if verbose: + tqdm.write(f'Removed 0 faces by mincut') + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + + return verts, faces + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = True, + simplify_ratio: float = 0.9, + fill_holes: bool = True, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. 
+ """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Remove invisible faces + if fill_holes: + vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = _fill_holes( + vertices, faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces + + +def parametrize_mesh(vertices: np.array, faces: np.array): + """ + Parametrize a mesh to a texture space, using xatlas. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + """ + + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + + +def bake_texture( + vertices: np.array, + faces: np.array, + uvs: np.array, + observations: List[np.array], + masks: List[np.array], + extrinsics: List[np.array], + intrinsics: List[np.array], + texture_size: int = 2048, + near: float = 0.1, + far: float = 10.0, + mode: Literal['fast', 'opt'] = 'opt', + lambda_tv: float = 1e-2, + verbose: bool = False, +): + """ + Bake texture to a mesh from multiple observations. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + uvs (np.array): UV coordinates of the mesh. Shape (V, 2). + observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). + masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). + extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). + intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). + texture_size (int): Size of the texture. + near (float): Near plane of the camera. + far (float): Far plane of the camera. + mode (Literal['fast', 'opt']): Mode of texture baking. + lambda_tv (float): Weight of total variation loss in optimization. + verbose (bool): Whether to print progress. 
+ """ + vertices = torch.tensor(vertices).cuda() + faces = torch.tensor(faces.astype(np.int32)).cuda() + uvs = torch.tensor(uvs).cuda() + observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] + masks = [torch.tensor(m>0).bool().cuda() for m in masks] + views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] + projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] + + if mode == 'fast': + texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() + texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() + rastctx = utils3d.torch.RastContext(backend='cuda') + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + uv_map = rast['uv'][0].detach().flip(0) + mask = rast['mask'][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) + texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # inpaint + mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + elif mode == 'opt': + rastctx = utils3d.torch.RastContext(backend='cuda') + observations = [observations.flip(0) for observations in observations] + masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + _uv.append(rast['uv'].detach()) + _uv_dr.append(rast['uv_dr'].detach()) + + texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + def tv_loss(texture): + return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + + total_steps = 2500 + with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(views)) + uv, uv_dr, observation, mask = 
_uv[selected], _uv_dr[selected], observations[selected], masks[selected] + render = dr.texture(texture, uv, uv_dr)[0] + loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + # annealing + optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size + )['mask'][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + else: + raise ValueError(f'Unknown mode: {mode}') + + return texture + + +def to_glb( + app_rep: Union[Gaussian], + mesh: MeshExtractResult, + simplify: float = 0.95, + fill_holes: bool = True, + fill_holes_max_size: float = 0.04, + texture_size: int = 1024, + debug: bool = False, + verbose: bool = True, +) -> trimesh.Trimesh: + """ + Convert a generated asset to a glb file. + + Args: + app_rep (Union[Gaussian]): Appearance representation. + mesh (MeshExtractResult): Extracted mesh. + simplify (float): Ratio of faces to remove in simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_size (float): Maximum area of a hole to fill. + texture_size (int): Size of the texture. + debug (bool): Whether to print debug information. + verbose (bool): Whether to print progress. + """ + vertices = mesh.vertices.cpu().numpy() + faces = mesh.faces.cpu().numpy() + + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + + # bake texture + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100, only_color=True) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, faces, uvs, + observations, masks, extrinsics, intrinsics, + texture_size=texture_size, mode='opt', + lambda_tv=0.01, + verbose=verbose + ) + texture = Image.fromarray(texture) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + material = trimesh.visual.material.PBRMaterial( + roughnessFactor=1.0, + baseColorTexture=texture, + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) + ) + mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) + return mesh + + +def simplify_gs( + gs: Gaussian, + simplify: float = 0.95, + verbose: bool = True, +): + """ + Simplify 3D Gaussians + NOTE: this function is not used in the current implementation for the unsatisfactory performance. + + Args: + gs (Gaussian): 3D Gaussian. + simplify (float): Ratio of Gaussians to remove in simplification. 
+ """ + if simplify <= 0: + return gs + + # simplify + observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100, only_color=True) + observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] + + # Following https://arxiv.org/pdf/2411.06019 + renderer = GaussianRenderer({ + "resolution": 1024, + "near": 0.8, + "far": 1.6, + "ssaa": 1, + "bg_color": (0,0,0), + }) + new_gs = Gaussian(**gs.init_params) + new_gs._features_dc = gs._features_dc.clone() + new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None + new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) + new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) + new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) + new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) + + start_lr = [1e-4, 1e-3, 5e-3, 0.025] + end_lr = [1e-6, 1e-5, 5e-5, 0.00025] + optimizer = torch.optim.Adam([ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], lr=start_lr[0]) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + _zeta = new_gs.get_opacity.clone().detach().squeeze() + _lambda = torch.zeros_like(_zeta) + _delta = 1e-7 + _interval = 10 + num_target = int((1 - simplify) * _zeta.shape[0]) + + with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: + for i in range(2500): + # prune + if i % 100 == 0: + mask = new_gs.get_opacity.squeeze() > 0.05 + mask = torch.nonzero(mask).squeeze() + new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) + new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) + new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) + new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) + new_gs._features_dc = new_gs._features_dc[mask] + new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None + _zeta = _zeta[mask] + _lambda = _lambda[mask] + # update optimizer state + for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): + stored_state = optimizer.state[param_group['params'][0]] + if 'exp_avg' in stored_state: + stored_state['exp_avg'] = stored_state['exp_avg'][mask] + stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] + del optimizer.state[param_group['params'][0]] + param_group['params'][0] = new_param + optimizer.state[param_group['params'][0]] = stored_state + + opacity = new_gs.get_opacity.squeeze() + + # sparisfy + if i % _interval == 0: + _zeta = _lambda + opacity.detach() + if opacity.shape[0] > num_target: + index = _zeta.topk(num_target)[1] + _m = torch.ones_like(_zeta, dtype=torch.bool) + _m[index] = 0 + _zeta[_m] = 0 + _lambda = _lambda + opacity.detach() - _zeta + + # sample a random view + view_idx = np.random.randint(len(observations)) + observation = observations[view_idx] + extrinsic = extrinsics[view_idx] + intrinsic = intrinsics[view_idx] + + color = renderer.render(new_gs, extrinsic, intrinsic)['color'] + rgb_loss = torch.nn.functional.l1_loss(color, observation) + loss = rgb_loss + \ + _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) + + 
optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update lr + for j in range(len(optimizer.param_groups)): + optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j]) + + pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()}) + pbar.update() + + new_gs._xyz = new_gs._xyz.data + new_gs._rotation = new_gs._rotation.data + new_gs._scaling = new_gs._scaling.data + new_gs._opacity = new_gs._opacity.data + + return new_gs diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b668c277b51f4930991912a80573adc79364028 --- /dev/null +++ b/trellis/utils/random_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0df07ae40f51d4fbbe185082c1221fce9c0759d --- /dev/null +++ b/trellis/utils/render_utils.py @@ -0,0 +1,263 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import MeshRenderer +from ..representations import Octree, Gaussian, MeshExtractResult +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs, device='cuda'): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).to(device) + yaw = torch.tensor(float(yaw)).to(device) + pitch = torch.tensor(float(pitch)).to(device) + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).to(device) * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().to(device), torch.tensor([0, 0, 1]).float().to(device)) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, need_depth=False, opt=False, **kwargs): + if isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get('resolution', 1024) + renderer.rendering_options.near = options.get('near', 1) + renderer.rendering_options.far = 
options.get('far', 100) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + elif isinstance(sample, Gaussian): + # from ..renderers import GSplatRenderer, GaussianRenderer + # renderer = GSplatRenderer() + from ..renderers import GaussianRenderer + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get('resolution', 1024) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 1) + renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) + renderer.pipe.use_mip_gaussian = True + elif isinstance(sample, Octree): + from ..renderers import OctreeRenderer + renderer = OctreeRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + renderer.pipe.primitive = sample.primitive + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): + if not isinstance(sample, MeshExtractResult): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite, need_depth=need_depth) + if 'color' not in rets: rets['color'] = [] + if 'depth' not in rets: rets['depth'] = [] + rets['color'].append(res['color'].clamp(0, 1) if opt else \ + np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + if 'percent_depth' in res: + rets['depth'].append(res['percent_depth'] if opt else res['percent_depth'].detach().cpu().numpy()) + elif 'depth' in res: + rets['depth'].append(res['depth'] if opt else res['depth'].detach().cpu().numpy()) + else: + rets['depth'].append(None) + else: + return_types = kwargs.get('return_types', ["color", "normal", "nocs", "depth", "mask"]) + res = renderer.render(sample, extr, intr, return_types = return_types) + if 'normal' not in rets: rets['normal'] = [] + if 'color' not in rets: rets['color'] = [] + if 'nocs' not in rets: rets['nocs'] = [] + if 'depth' not in rets: rets['depth'] = [] + if 'mask' not in rets: rets['mask'] = [] + if 'color' in return_types: + rets['color'].append(res['color'].clamp(0,1) if opt else \ + np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + rets['normal'].append(res['normal'].clamp(0,1) if opt else \ + np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + rets['nocs'].append(res['nocs'].clamp(0,1) if opt else \ + np.clip(res['nocs'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + rets['depth'].append(res['depth'] if opt else \ + res['depth'].detach().cpu().numpy()) + rets['mask'].append(res['mask'].detach().cpu().numpy().astype(np.uint8)) + return rets + +def render_orth_frames(sample, extrinsics, projections, options={}, colors_overwrite=None, verbose=True, **kwargs): + # Select renderer according to sample type + if isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get('resolution', 1024) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + else: + raise ValueError(f'Unsupported 
sample type: {type(sample)}') + + rets = {} + for j, extr in tqdm(enumerate(extrinsics), desc='Rendering Orthographic', disable=not verbose): + res = renderer.render(sample, extr, None, perspective=projections[j], return_types=["normal", "nocs", "depth"]) + if 'normal' not in rets: + rets['normal'] = [] + if 'color' not in rets: + rets['color'] = [] + if 'nocs' not in rets: + rets['nocs'] = [] + if 'depth' not in rets: + rets['depth'] = [] + rets['normal'].append(np.clip( + res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 + ).astype(np.uint8)) + rets['nocs'].append(np.clip( + res['nocs'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 + ).astype(np.uint8)) + rets['depth'].append(res['depth'].detach().cpu().numpy()) + return rets + +def get_ortho_projection_matrix(left, right, bottom, top, near, far): + """ + 使用 torch 创建正交投影矩阵, 使用标准的正交投影矩阵公式: + [ 2/(r-l) 0 0 -(r+l)/(r-l) ] + [ 0 2/(t-b) 0 -(t+b)/(t-b) ] + [ 0 0 -2/(f-n) -(f+n)/(f-n) ] + [ 0 0 0 1 ] + """ + projection_matrix = torch.zeros((4, 4), dtype=torch.float32) + + projection_matrix[0, 0] = 2.0 / (right - left) + projection_matrix[1, 1] = 2.0 / (top - bottom) + projection_matrix[2, 2] = -2.0 / (far - near) + projection_matrix[3, 3] = 1.0 + + projection_matrix[0, 3] = -(right + left) / (right - left) + projection_matrix[1, 3] = -(top + bottom) / (top - bottom) + projection_matrix[2, 3] = (far + near) / (far - near) + + return projection_matrix + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. 
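# --- Illustrative sketch (not part of the patch) ---------------------------
# Quick numeric check for get_ortho_projection_matrix above: with symmetric
# bounds, points on the left/right and bottom/top planes should land on
# x = -1/+1 and y = -1/+1 in clip space.  The matrix rows are re-stated here
# so the snippet runs standalone.
import torch

def ortho(left, right, bottom, top, near, far):
    P = torch.zeros(4, 4)
    P[0, 0] = 2.0 / (right - left)
    P[1, 1] = 2.0 / (top - bottom)
    P[2, 2] = -2.0 / (far - near)
    P[3, 3] = 1.0
    P[0, 3] = -(right + left) / (right - left)
    P[1, 3] = -(top + bottom) / (top - bottom)
    P[2, 3] = (far + near) / (far - near)          # depth sign convention as in the code above
    return P

s = 0.6                                            # ortho_scale used by render_ortho_video
P = ortho(-s, s, -s, s, 1e-6, 100.0)
corners = torch.tensor([[-s, -s, -1.0, 1.0],       # left / bottom
                        [ s,  s, -1.0, 1.0]])      # right / top
print((P @ corners.T).T[:, :2])                    # -> [[-1., -1.], [ 1.,  1.]]
# ---------------------------------------------------------------------------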
+ return ret + +def render_ortho_video(sample, resolution=512, ssaa=4, bg_color=(0, 0, 0), num_frames=300, r=2, inverse_direction=False, pitch=-1, **kwargs): + if inverse_direction: + yaws = torch.linspace(3.1415, -3.1415, num_frames) + else: + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + if pitch != -1: + pitch = pitch * torch.ones(num_frames) + else: + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitchs = pitch.tolist() + + ortho_scale = 0.6 + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, 40) + + projection = get_ortho_projection_matrix(-ortho_scale, ortho_scale, -ortho_scale, ortho_scale, 1e-6, 100).to(extrinsics[0].device) + projections = [projection] * num_frames + render_results = render_orth_frames(sample, extrinsics, projections, {'resolution': resolution, 'bg_color': bg_color, 'ssaa': ssaa}, **kwargs) + render_results.update({'extrinsics': extrinsics, 'intrinsics': None, 'projections': projections}) + return render_results + + +def render_multiview(sample, resolution=518, ssaa=4, bg_color=(0, 0, 0), num_frames=30, r = 2, fov = 40, random_offset=False, only_color=False, **kwargs): + if random_offset: + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_frames): + y, p = sphere_hammersley_sequence(i, num_frames, offset) + yaws.append(y) + pitchs.append(p) + else: + cams = [sphere_hammersley_sequence(i, num_frames) for i in range(num_frames)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color, 'ssaa': ssaa}, **kwargs) + return res['color'] if only_color else res, extrinsics, intrinsics + +def render_video(sample, resolution=512, ssaa=4, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, + inverse_direction=False, pitch=-1, **kwargs): + if inverse_direction: + yaws = torch.linspace(3.1415, -3.1415, num_frames) + # pitch = 0.25 + 0.5 * torch.sin(torch.linspace(2 * 3.1415, 0, num_frames)) + else: + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + if pitch != -1: + pitch = pitch * torch.ones(num_frames) + else: + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color, 'ssaa': ssaa}, **kwargs) + res.update({'extrinsics': extrinsics, 'intrinsics': intrinsics}) + return res + +def render_condition_images(sample, resolution=512, ssaa=4, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_frames): + y, p = sphere_hammersley_sequence(i, num_frames, offset) + yaws.append(y) + pitchs.append(p) + + fov_min, fov_max = 10, 70 + radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi) + radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi) + k_min = 1 / radius_max**2 + k_max = 1 / radius_min**2 + ks = np.random.uniform(k_min, k_max, (1000000,)) + radius = [1 / np.sqrt(k) for k in ks] + fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius] + fov = [value_in_radians * 180 / np.pi for value_in_radians in fov] + + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, 
pitchs, radius, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color, 'ssaa': ssaa}, **kwargs), extrinsics, intrinsics diff --git a/wheels/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/wheels/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3eb07bb164560ef041dd6f7e8052b57a3790dc Binary files /dev/null and b/wheels/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/wheels/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395461170fbd7f9a30fbca5e079bc98ff6a0c19d Binary files /dev/null and b/wheels/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc b/wheels/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f68afd32f24ee5ebd9854658cca189078b7fa7ca Binary files /dev/null and b/wheels/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc b/wheels/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e12ff55e5e669523e6f21a7d13ee445eb7cc9162 Binary files /dev/null and b/wheels/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc b/wheels/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09d6d85815cc6ed0df6db51183c2e742bdcae0fb Binary files /dev/null and b/wheels/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/camera_head.py b/wheels/vggt/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..176d76fb5baeb3a42fa3675a1d1fb14010f2904d --- /dev/null +++ b/wheels/vggt/vggt/heads/camera_head.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vggt.layers import Mlp +from vggt.layers.block import Block +from vggt.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. 
+ self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. 
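# --- Illustrative sketch (not part of the patch) ---------------------------
# One refinement iteration above boils down to: produce (shift, scale, gate)
# from the current pose estimate, modulate the normalised camera token with
# them, and add the result back as a gated residual before running the trunk.
# Toy-sized standalone re-statement with random tensors; the dimensions are
# made up for illustration.
import torch
import torch.nn as nn

dim = 8
pose_tokens = torch.randn(2, 1, dim)               # [B, S, C] camera tokens
module_input = torch.randn(2, 1, dim)              # embedding of the current pose estimate

modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
adaln_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

shift, scale, gate = modulation(module_input).chunk(3, dim=-1)
modulated = adaln_norm(pose_tokens) * (1 + scale) + shift      # modulate()
pose_tokens = pose_tokens + gate * modulated                   # gated residual update
print(pose_tokens.shape)                                       # torch.Size([2, 1, 8])
# ---------------------------------------------------------------------------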
+ activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/wheels/vggt/vggt/heads/dpt_head.py b/wheels/vggt/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c8af9a499eab6a715971947cda587af8f66960 --- /dev/null +++ b/wheels/vggt/vggt/heads/dpt_head.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. 
+ self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx] + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.view(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
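# --- Illustrative sketch (not part of the patch) ---------------------------
# The per-layer loop above turns transformer tokens back into 2D feature maps
# before the convolutional fusion: drop the camera/register tokens, fold the
# frame dimension into the batch, and reshape the patch tokens onto the patch
# grid.  Standalone shape walk-through with toy sizes.
import torch

B, S, C = 2, 4, 32                     # batch, frames, token dim (toy values)
patch_size, H, W = 14, 224, 280
patch_h, patch_w = H // patch_size, W // patch_size      # 16 x 20 patch grid
patch_start_idx = 5                                       # e.g. camera + register tokens

tokens = torch.randn(B, S, patch_start_idx + patch_h * patch_w, C)

x = tokens[:, :, patch_start_idx:]                 # keep only patch tokens
x = x.reshape(B * S, -1, C)                        # fold frames into the batch
x = x.permute(0, 2, 1).reshape(B * S, C, patch_h, patch_w)
print(x.shape)                                     # torch.Size([8, 32, 16, 20])
# ---------------------------------------------------------------------------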
+ out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, 
features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/wheels/vggt/vggt/heads/head_act.py b/wheels/vggt/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489 --- /dev/null +++ b/wheels/vggt/vggt/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. 
+ + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/wheels/vggt/vggt/heads/track_head.py b/wheels/vggt/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec7199bd185060989c236997f93b93f4fc77825 --- /dev/null +++ b/wheels/vggt/vggt/heads/track_head.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. 
+ """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). + """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker( + query_points=query_points, + fmaps=feature_maps, + iters=iters, + ) + + return coord_preds, vis_scores, conf_scores diff --git a/wheels/vggt/vggt/heads/track_modules/__init__.py b/wheels/vggt/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/wheels/vggt/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/wheels/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/wheels/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e74fc6d03876084a4baf6c8f4268450f90287056 Binary files /dev/null and b/wheels/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/wheels/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e73f47b75db7569a26f10b822ce03b01b6fe654 Binary files /dev/null and b/wheels/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/wheels/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..926ab5a68ef51b0acd1bfcc17943063aef9e5fb7 Binary files /dev/null and b/wheels/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/wheels/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e02959f2cd92c662be3469b57aae189a65d78501 Binary files /dev/null and b/wheels/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/wheels/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef634fe8e79e288d3a55a99415b0896e6e3e70cc Binary files /dev/null and b/wheels/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/heads/track_modules/base_track_predictor.py b/wheels/vggt/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a --- /dev/null +++ b/wheels/vggt/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
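Editor's note: the predictor defined below builds its transformer input by concatenating a flow embedding, MLP-projected correlation features, and the current track features, which is why `transformer_dim` is set to `latent_dim + latent_dim + latent_dim + 4`. A quick sanity check of that arithmetic under the constructor defaults (a sketch for orientation only; the numbers are read off the code that follows):

```python
# Defaults from BaseTrackerPredictor below.
latent_dim, corr_levels, corr_radius = 128, 5, 4

# Correlation features sampled per point: one (2r+1)^2 window per pyramid level,
# then projected by corr_mlp down to latent_dim.
corr_in = corr_levels * (2 * corr_radius + 1) ** 2   # 5 * 81 = 405
corr_out = latent_dim                                # 128

# Flow embedding: sin/cos embedding of the 2D flow (2 * flows_emb_dim channels)
# plus the raw flow appended twice (2 + 2 extra channels).
flows_emb = 2 * (latent_dim // 2) + 4                # 128 + 4 = 132

# Track features carried between iterations.
track_feats = latent_dim                             # 128

transformer_dim = flows_emb + corr_out + track_feats
assert transformer_dim == latent_dim * 3 + 4 == 388
```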
+ +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed +from .modules import Mlp + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N 
x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/wheels/vggt/vggt/heads/track_modules/blocks.py b/wheels/vggt/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7763f4fd8f515662421db192594380dbb574e5 --- /dev/null +++ b/wheels/vggt/vggt/heads/track_modules/blocks.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): 
+ torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. 
Δx, Δy) + self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. + delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/wheels/vggt/vggt/heads/track_modules/modules.py b/wheels/vggt/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4b090ddc4a9db01c8dd3564f9053e1ca9cdde93a --- /dev/null +++ b/wheels/vggt/vggt/heads/track_modules/modules.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
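Editor's note: before the building blocks in modules.py, a small shape check of the correlation step used by `CorrBlock` above. `compute_corr_level` is a batched matrix product between per-point target features and the flattened feature map, scaled by `1/sqrt(C)`; this standalone sketch confirms the shapes with random tensors.

```python
import math
import torch

B, S, N, C, H, W = 1, 2, 8, 16, 32, 32
fmap1 = torch.randn(B, S, N, C)           # target features, one vector per track
fmap2s = torch.randn(B, S, C, H * W)      # flattened feature map per frame

corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)   # (B, S, N, H*W)
assert corrs.shape == (B, S, N, H * W)

# Reshaped to (B, S, N, H, W), each track gets a full-resolution correlation map
# from which CorrBlock samples a (2r+1) x (2r+1) window around its current estimate.
corrs = corrs.view(B, S, N, H, W)
```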
+ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=kernel_size, + padding=1, + padding_mode="zeros", + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3, + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = 
attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/wheels/vggt/vggt/heads/track_modules/utils.py b/wheels/vggt/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51d01d39cdc10388a04dab5db7cf409b31bde766 --- /dev/null +++ b/wheels/vggt/vggt/heads/track_modules/utils.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. 
+ """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. 
It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype + ) + else: + scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. 
+ """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/wheels/vggt/vggt/heads/utils.py b/wheels/vggt/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7af1f68fa0ce0a48d11a708d53aa20aa8f78ba2 --- /dev/null +++ b/wheels/vggt/vggt/heads/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/wheels/vggt/vggt/layers/__init__.py b/wheels/vggt/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/wheels/vggt/vggt/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/wheels/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39acdb9a3f4d1a3d2b20cefaebe3cc4d265e85fb Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbac8a7443686a017c11d4c6e824e0ebe8fd4932 Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/block.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f0c046ae8f1fe274049762fcb9fec9c231162df Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/block.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57614c082d757d8bd4b67e71e130482317eb934a Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e87bfe1e9b030c2b5a8e680e9d7d6887c2c302c3 Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8a8fbb3a6abd33aa9e3dcf969676510a4cf6bd8 Binary files /dev/null and 
b/wheels/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89b41b2e2e69804376f66042c6cc3b6d6eaba0f5 Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d22c4f512354ba846c76286d8bd35095b5cb3696 Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00b818234235900aba39fa4201d4aed19aba332f Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/wheels/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d4c380c58566558b0922549a290c0f73e97505 Binary files /dev/null and b/wheels/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/layers/attention.py b/wheels/vggt/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3089ce0c7493342ef0cf373dfe74a1df2b9563 --- /dev/null +++ b/wheels/vggt/vggt/layers/attention.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
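Editor's note: the `Attention` module defined in this file supports a fused path (`F.scaled_dot_product_attention`) and a manual path (`softmax(q k^T / sqrt(d)) v`). With dropout disabled the two are numerically equivalent; a self-contained check with random tensors (the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

B, heads, N, head_dim = 2, 8, 64, 32
q, k, v = (torch.randn(B, heads, N, head_dim) for _ in range(3))

# Fused kernel; its default scale is 1/sqrt(head_dim), matching self.scale below.
out_fused = F.scaled_dot_product_attention(q, k, v)

# Manual path, mirroring the non-fused branch of Attention.forward.
scale = head_dim ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)
out_manual = attn.softmax(dim=-1) @ v

assert torch.allclose(out_fused, out_manual, atol=1e-5)
```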
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/wheels/vggt/vggt/layers/block.py b/wheels/vggt/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..5f89e4da7121effca97151d1d8429586e422346e --- /dev/null +++ b/wheels/vggt/vggt/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
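Editor's note: before the transformer block below, a usage sketch of the `Attention` layer just defined. The import path mirrors the file layout in this diff and assumes the `vggt` wheel is importable; with `rope=None` the layer behaves as standard multi-head self-attention over a token sequence.

```python
import torch
from vggt.layers.attention import Attention  # path follows wheels/vggt/vggt/layers/attention.py

attn = Attention(dim=768, num_heads=12)      # head_dim = 64
tokens = torch.randn(2, 197, 768)            # (batch, sequence, channels)

with torch.no_grad():
    out = attn(tokens)                       # rope/pos not used in this sketch

print(out.shape)  # torch.Size([2, 197, 768])
```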
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + pos=pos, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos=None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / 
sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + 
sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/wheels/vggt/vggt/layers/drop_path.py b/wheels/vggt/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/wheels/vggt/vggt/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/wheels/vggt/vggt/layers/layer_scale.py b/wheels/vggt/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/wheels/vggt/vggt/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
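Editor's note: `Block` above wires each residual branch as `x + drop_path(layer_scale(branch(norm(x))))`. A minimal sketch of that composition using the `DropPath` defined above and the `LayerScale` defined next; the import paths assume the `vggt` wheel is importable, and the tiny MLP is a stand-in for the real attention/FFN branch.

```python
import torch
import torch.nn as nn
from vggt.layers.drop_path import DropPath      # stochastic depth, identity in eval mode
from vggt.layers.layer_scale import LayerScale  # learnable per-channel gain, init 1e-5

dim = 64
branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())  # stand-in branch
ls = LayerScale(dim, init_values=1e-5)
dp = DropPath(drop_prob=0.1)

x = torch.randn(2, 16, dim)
y = x + dp(ls(branch(x)))          # residual update as used inside Block.forward (training)

dp.eval()
y_eval = x + dp(ls(branch(x)))     # at inference DropPath is the identity
```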
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/wheels/vggt/vggt/layers/mlp.py b/wheels/vggt/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/wheels/vggt/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/wheels/vggt/vggt/layers/patch_embed.py b/wheels/vggt/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/wheels/vggt/vggt/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/wheels/vggt/vggt/layers/rope.py b/wheels/vggt/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df --- /dev/null +++ b/wheels/vggt/vggt/layers/rope.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. 
+ + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/wheels/vggt/vggt/layers/swiglu_ffn.py b/wheels/vggt/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..54fe8e90b7bedf6fbdbf09c6215844e3cc63f857 --- /dev/null +++ b/wheels/vggt/vggt/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
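# Usage sketch (illustrative only): applying the 2D RoPE defined above to a
# grid of patch tokens. PositionGetter produces (y, x) coordinates for an
# H x W patch grid; RotaryPositionEmbedding2D rotates the first half of each
# head's features with the y-frequencies and the second half with the
# x-frequencies. Imports follow how aggregator.py below uses these classes.
import torch
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter

B, heads, H, W, head_dim = 2, 16, 37, 37, 64       # head_dim must be divisible by 4
rope = RotaryPositionEmbedding2D(frequency=100.0)
get_positions = PositionGetter()

tokens = torch.randn(B, heads, H * W, head_dim)
positions = get_positions(B, H, W, device=tokens.device)   # (B, H*W, 2)
rotated = rope(tokens, positions)

assert rotated.shape == tokens.shape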
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/wheels/vggt/vggt/layers/vision_transformer.py b/wheels/vggt/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..120cbe6c26650d212e50aefc497669abdc937467 --- /dev/null +++ b/wheels/vggt/vggt/layers/vision_transformer.py @@ -0,0 +1,407 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . 
import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # tricky but makes it work + self.use_checkpoint = False + # + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: 
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 
1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/wheels/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc b/wheels/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86616472c10ec4bd6f9bf6eb75e53553555f219e Binary files /dev/null and b/wheels/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc b/wheels/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..e60b577117d29848ddf84ab8ab4a4e56c83c397e Binary files /dev/null and b/wheels/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/models/aggregator.py b/wheels/vggt/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..393f9920a24b05eca3eb82f7db8bd024f9c1636e --- /dev/null +++ b/wheels/vggt/vggt/models/aggregator.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple, Union, List, Dict, Any + +from vggt.layers import PatchEmbed +from vggt.layers.block import Block +from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. 
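# Usage sketch (illustrative only): the Aggregator below builds one of the
# DINOv2 backbones defined earlier (vit_large for "dinov2_vitl14_reg") and
# reads "x_norm_patchtokens" from its forward output. This assumes the full
# vggt package from this diff (attention/block modules included) is importable;
# weights are randomly initialized, so the values are meaningless, but the
# shapes illustrate the interface.
import torch
from vggt.layers.vision_transformer import vit_large

backbone = vit_large(img_size=518, patch_size=14, num_register_tokens=4, block_chunks=0)
images = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    out = backbone(images)                     # is_training=True by default -> dict

patch_tokens = out["x_norm_patchtokens"]       # (1, (518 // 14) ** 2, 1024) = (1, 1369, 1024)
cls_token = out["x_norm_clstoken"]             # (1, 1024)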
+ """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) + + # Initialize rotary position embedding if frequency > 0 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.FloatTensor(value).view(1, 1, 3, 1, 1), + persistent=False, + ) + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def forward( + self, + images: torch.Tensor, + ) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). 
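# Usage sketch (illustrative only): running the Aggregator on a dummy
# multi-view batch with randomly initialized weights. With the defaults above
# (depth=24, aa_block_size=1, aa_order=["frame", "global"]) each of the 24
# frame/global block pairs contributes one intermediate of shape
# [B, S, P, 2 * embed_dim], and patch tokens start after 1 camera token plus
# 4 register tokens. img_size=224 is used only to keep the dummy run small.
import torch
from vggt.models.aggregator import Aggregator

B, S, H, W = 1, 2, 224, 224
model = Aggregator(img_size=H, patch_size=14, embed_dim=1024).eval()

images = torch.rand(B, S, 3, H, W)             # values in [0, 1], as forward() above documents
with torch.no_grad():
    output_list, patch_start_idx = model(images)

P = patch_start_idx + (H // 14) * (W // 14)    # special tokens + patch tokens
assert patch_start_idx == 1 + 4
assert len(output_list) == 24
assert output_list[-1].shape == (B, S, P, 2 * 1024)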
+ """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/wheels/vggt/vggt/models/vggt.py b/wheels/vggt/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4dccd7400dc0ab32d2fb4b6457bbf22e81a82c --- /dev/null +++ b/wheels/vggt/vggt/models/vggt.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
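# Usage sketch (illustrative only): slice_expand_and_flatten gives the first
# frame its own copy of a special token (index 0) and every remaining frame a
# shared second copy (index 1), then flattens to (B*S, X, C).
import torch
from vggt.models.aggregator import slice_expand_and_flatten

B, S, X, C = 3, 4, 1, 8
token = torch.randn(1, 2, X, C)                    # e.g. the camera_token parameter
flat = slice_expand_and_flatten(token, B, S)       # (B*S, X, C)

per_frame = flat.view(B, S, X, C)
assert torch.equal(per_frame[0, 0], token[0, 0])   # frame 0 uses the first slot
assert torch.equal(per_frame[0, 1], token[0, 1])   # frames 1..S-1 share the second slot
assert flat.shape == (B * S, X, C)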
+ +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub +from vggt.models.aggregator import Aggregator +from vggt.heads.camera_head import CameraHead +from vggt.heads.dpt_head import DPTHead +from vggt.heads.track_head import TrackHead + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__(self, img_size=518, patch_size=14, embed_dim=1024): + super().__init__() + + self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + self.camera_head = CameraHead(dim_in=2 * embed_dim) + self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") + self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") + self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) + + def forward( + self, + images: torch.Tensor, + query_points: torch.Tensor = None, + ): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + with torch.cuda.amp.autocast(enabled=False): + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points + ) + predictions["track"] = track_list[-1] # track of the 
last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + predictions["images"] = images + + return predictions diff --git a/wheels/vggt/vggt/utils/__pycache__/geometry.cpython-310.pyc b/wheels/vggt/vggt/utils/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e61c82b54fa383fac3dade527ff78dc11d4a00 Binary files /dev/null and b/wheels/vggt/vggt/utils/__pycache__/geometry.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc b/wheels/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6e5f6e226891834ce517213440a967ccde1b8e Binary files /dev/null and b/wheels/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/wheels/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a60ebb45c236b69fda43f8824e386e6d6a4e42d Binary files /dev/null and b/wheels/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc b/wheels/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df292acd61983a3e6e10eae4df14ba2b0f799568 Binary files /dev/null and b/wheels/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc differ diff --git a/wheels/vggt/vggt/utils/geometry.py b/wheels/vggt/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c3fdedf7fa644ba8d06f32f76de12b422169e7 --- /dev/null +++ b/wheels/vggt/vggt/utils/geometry.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 
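# Usage sketch (illustrative only): lifting a depth map to world points with
# the helper above, mirroring how VGGT's predicted depth maps and cameras are
# consumed downstream. With an identity cam-from-world extrinsic and a simple
# pinhole intrinsic, every unprojected point keeps its depth as the z value.
import numpy as np
from vggt.utils.geometry import unproject_depth_map_to_point_map

S, H, W = 1, 48, 64
depth = np.full((S, H, W, 1), 2.0, dtype=np.float32)               # constant depth
extrinsics = np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1))    # cam-from-world = [I | 0]
intrinsics = np.tile(np.array([[100.0, 0.0, W / 2],
                               [0.0, 100.0, H / 2],
                               [0.0, 0.0, 1.0]], dtype=np.float32), (S, 1, 1))

world_points = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
assert world_points.shape == (S, H, W, 3)
assert np.allclose(world_points[..., 2], 2.0)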
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). + """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. 
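# Usage sketch (illustrative only): the camera-frame unprojection above is the
# standard pinhole model, x = (u - cx) * z / fx and y = (v - cy) * z / fy.
# A spot check at a single pixel:
import numpy as np
from vggt.utils.geometry import depth_to_cam_coords_points

H, W = 4, 6
depth = np.full((H, W), 3.0, dtype=np.float32)
K = np.array([[200.0, 0.0, 3.0],
              [0.0, 200.0, 2.0],
              [0.0, 0.0, 1.0]])

cam = depth_to_cam_coords_points(depth, K)         # (H, W, 3)
u, v = 5, 1                                        # pixel column u, row v
expected = [(u - 3.0) * 3.0 / 200.0, (v - 2.0) * 3.0 / 200.0, 3.0]
assert np.allclose(cam[v, u], expected)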
+ + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.permute(0, 2, 1) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix + +def depth_to_cam_coords_points_tensor(depth_map: torch.Tensor, intrinsic: torch.Tensor) -> torch.Tensor: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (torch.Tensor): Depth map of shape (B, H, W). + intrinsic (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). + + Returns: + torch.Tensor: Camera coordinates (B, H, W, 3) + """ + B, H, W = depth_map.shape + + # Intrinsic parameters + fu, fv = intrinsic[:, 0, 0], intrinsic[:, 1, 1] + cu, cv = intrinsic[:, 0, 2], intrinsic[:, 1, 2] + + # Generate grid of pixel coordinates + v, u = torch.meshgrid(torch.arange(W).to(depth_map.device), torch.arange(H).to(depth_map.device)) + + # Unproject to camera coordinates + x_cam = (u[None] - cu[:, None, None]) * depth_map / fu[:, None, None] + y_cam = (v[None] - cv[:, None, None]) * depth_map / fv[:, None, None] + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = torch.stack((x_cam, y_cam, z_cam), dim=-1).float() + + return cam_coords + +def depth_to_world_coords_points_tensor( + depth_map: torch.Tensor, + extrinsic: torch.Tensor, + intrinsic: torch.Tensor, + eps=1e-8, +) -> torch.Tensor: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (torch.Tensor): Depth map of shape (B, H, W, 1). + intrinsic (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). + extrinsic (torch.Tensor): Camera extrinsic matrix of shape (B, 3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + torch.Tensor: World coordinates (B, H, W, 3). 
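# Usage sketch (illustrative only): composing a random SE(3) pose with its
# closed-form inverse from above should give the identity. The rotation block
# comes from a QR factorization just to obtain an orthonormal 3x3 matrix.
import numpy as np
from vggt.utils.geometry import closed_form_inverse_se3

rng = np.random.default_rng(0)
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))   # orthonormal, so R^T = R^-1
t = rng.standard_normal((3, 1))

se3 = np.eye(4)
se3[:3, :3], se3[:3, 3:] = R, t

inv = closed_form_inverse_se3(se3[None])[0]        # batched input, (4, 4) output
assert np.allclose(inv @ se3, np.eye(4), atol=1e-6)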
+ """ + if depth_map is None: + return None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points_tensor(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic) + + R_cam_to_world = cam_to_world_extrinsic[:, :3, :3] + t_cam_to_world = cam_to_world_extrinsic[:, :3, 3] + + B, H, W, _ = cam_coords_points.shape + # Apply the rotation and translation to the camera coordinates + world_coords_points = torch.matmul(cam_coords_points.reshape(B, -1, 3), R_cam_to_world.float().permute(0,2,1)) + t_cam_to_world[:,None] # BxHWx3, Bx3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points.reshape(B, H, W, 3) \ No newline at end of file diff --git a/wheels/vggt/vggt/utils/load_fn.py b/wheels/vggt/vggt/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa765a5dba0892573ab0d111dd1abea3f923109 --- /dev/null +++ b/wheels/vggt/vggt/utils/load_fn.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF + + +def load_and_preprocess_images(image_path_list): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. 
+ + Args: + image_path_list (list): List of paths to image files + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - The function ensures width=518px while maintaining aspect ratio + - Height is adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + images = [] + alphas = [] + shapes = set() + to_tensor = TF.ToTensor() + + # First process all images and collect their shapes + # for image_path in image_path_list: + for img in image_path_list: + + # Open image + # img = Image.open(image_path) + img = img[0] + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + alphas.append(to_tensor(img)[3:]) + # background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + # img = Image.alpha_composite(background, img) + + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + new_width = 518 + + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 + + if new_height > 518: + start_y = (new_height - 518) // 2 + img = img[:, start_y : start_y + 518, :] + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + alphas = torch.stack(alphas) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + alphas = alphas.unsqueeze(0) + + return images, alphas diff --git a/wheels/vggt/vggt/utils/pose_enc.py b/wheels/vggt/vggt/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..2f98b0878cb13451b8cdb80074349cbf2644c5fa --- /dev/null +++ b/wheels/vggt/vggt/utils/pose_enc.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
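# Usage sketch (illustrative only): this load_and_preprocess_images variant has
# been adapted for Gradio gallery input, so each list element is indexed with
# [0] to get the PIL image. As written it expects RGBA inputs (the alpha list
# is stacked unconditionally), resizes frames to width 518 with the height
# rounded to a multiple of 14, and keeps the alpha masks at their original
# resolution because they are extracted before the resize.
from PIL import Image
from vggt.utils.load_fn import load_and_preprocess_images

frames = [
    (Image.new("RGBA", (640, 480), (255, 0, 0, 255)), "view_0"),
    (Image.new("RGBA", (640, 480), (0, 255, 0, 255)), "view_1"),
]
images, alphas = load_and_preprocess_images(frames)

assert images.shape == (2, 3, 392, 518)    # 480 * 518 / 640 = 388.5 -> 392 (multiple of 14)
assert alphas.shape == (2, 1, 480, 640)    # alpha kept at the original image size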
+ +import torch +from .rotation import quat_to_mat, mat_to_quat + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", + build_intrinsics=True, +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. 
The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/wheels/vggt/vggt/utils/rotation.py b/wheels/vggt/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..657583e6915437c824c192d51939990b589a14fa --- /dev/null +++ b/wheels/vggt/vggt/utils/rotation.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). 
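# Usage sketch (illustrative only): round-tripping a camera through the 9-D
# "absT_quaR_FoV" encoding defined above. With a centred principal point the
# decoded extrinsics and intrinsics match the inputs up to floating point error.
import torch
from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 1, 2, 294, 518
extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).clone()
extrinsics[..., 3] = torch.tensor([0.5, -0.2, 1.0])          # camera-from-world translation
intrinsics = torch.tensor([[400.0, 0.0, W / 2],
                           [0.0, 400.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3)

pose_enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
assert pose_enc.shape == (B, S, 9)                           # [T(3), quat(4), fov_h, fov_w]

extr_rec, intr_rec = pose_encoding_to_extri_intri(pose_enc, image_size_hw=(H, W))
assert torch.allclose(extr_rec, extrinsics, atol=1e-5)
assert torch.allclose(intr_rec, intrinsics, atol=1e-2)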
+ Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/wheels/vggt/vggt/utils/visual_track.py b/wheels/vggt/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154 --- /dev/null +++ b/wheels/vggt/vggt/utils/visual_track.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). 
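# Usage sketch (illustrative only): quaternion <-> rotation-matrix round trip
# with the scalar-last (XYZW) convention used above. Inputs are normalized and
# flipped to a non-negative real part, matching what mat_to_quat returns.
import torch
import torch.nn.functional as F
from vggt.utils.rotation import quat_to_mat, mat_to_quat

torch.manual_seed(0)
q = F.normalize(torch.randn(5, 4), dim=-1)
q = torch.where(q[..., 3:4] < 0, -q, q)            # standardize: real part >= 0

R = quat_to_mat(q)                                 # (5, 3, 3)
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(5, 3, 3), atol=1e-5)

q_rec = mat_to_quat(R)
assert torch.allclose(q_rec, q, atol=1e-5)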
+ """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). 
+ """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape # (S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[s].shape = (3, H, W) + H, W = images.shape[2], images.shape[3] + else: + # e.g. images[s].shape = (H, W, 3) + H, W = images.shape[1], images.shape[2] + + # Pre-compute the color for each track i based on first visible position + track_colors_rgb = get_track_colors_by_position( + tracks, # shape (S, N, 2) + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + # We'll accumulate each frame's drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, convert to BGR + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + # Save individual frame + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + # Convert to BGR for OpenCV imwrite + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + # Only create and save the grid image if save_grid is True + if save_grid: + # Calculate grid dimensions + num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division + + # Create a grid of images + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + # Concatenate this row horizontally + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + # If this row has fewer than frames_per_row images, pad with black + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + # Add this row to the grid + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + # Convert back to BGR for OpenCV imwrite + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
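
Taken together, the utilities added in these hunks form a small, self-consistent pipeline: quat_to_mat and mat_to_quat convert between scalar-last (XYZW) quaternions and rotation matrices, pose_encoding_to_extri_intri decodes the (B, S, 9) "absT_quaR_FoV" encoding into [R|t] extrinsics plus pixel-space intrinsics, and visualize_tracks_on_images renders per-frame tracks. The snippet below is a minimal round-trip sketch, not part of this patch; it assumes the vendored package installs as vggt (so rotation.py imports as vggt.utils.rotation) and that the pose-encoding hunk belongs to vggt.utils.pose_enc, a module name this section of the diff does not show.

import math
import torch

# Assumed import paths: rotation.py is added above as wheels/vggt/vggt/utils/rotation.py;
# pose_enc.py is assumed to be the module containing pose_encoding_to_extri_intri.
from vggt.utils.rotation import quat_to_mat, mat_to_quat
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# Quaternion <-> rotation matrix round trip (identity rotation, XYZW order).
quat = torch.tensor([0.0, 0.0, 0.0, 1.0])
R = quat_to_mat(quat)                                   # 3x3 identity matrix
assert torch.allclose(mat_to_quat(R), quat, atol=1e-6)

# Decode a batch of pose encodings: per-view layout [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w].
B, S = 1, 2                                             # one scene, two views
pose_enc = torch.zeros(B, S, 9)
pose_enc[..., 6] = 1.0                                  # identity quaternion (qw = 1)
pose_enc[..., 7] = math.radians(60.0)                   # vertical FoV in radians
pose_enc[..., 8] = math.radians(80.0)                   # horizontal FoV in radians

extrinsics, intrinsics = pose_encoding_to_extri_intri(
    pose_enc, image_size_hw=(480, 640),
    pose_encoding_type="absT_quaR_FoV", build_intrinsics=True,
)
print(extrinsics.shape, intrinsics.shape)               # (1, 2, 3, 4) and (1, 2, 3, 3)

With zero translation and the identity quaternion, every decoded extrinsic is [I | 0], and the intrinsics place the principal point at (W/2, H/2) = (320, 240) with fx = 320 / tan(40 deg) and fy = 240 / tan(30 deg).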